Skip to content

Commit b0ebaed

Browse files
author
morelos
committed
[ET-VK][Ops] enabling double support for quantization and dequantization ops
Pull Request resolved: #11553 # Context Since we enabled the possibility for double support in an earlier diff, this enables double support for quantization and dequantization. Since there are limitations to how 64bit can be supported, the expectation is that IO is to be downgraded to 32bit. # Changes We create additional test cases for double support and make sure to pass in the double if its permitted (it's only allowed in buffers), and we also make sure to include double variants in the corresponding YAML files for quantization and dequantization. ghstack-source-id: 291142411 @exported-using-ghexport Differential Revision: [D76289197](https://our.internmc.facebook.com/intern/diff/D76289197/)
1 parent 7bd15b9 commit b0ebaed

File tree

8 files changed

+106
-2
lines changed

8 files changed

+106
-2
lines changed

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dequantize_buffer:
1111
OUT_DTYPE:
1212
- VALUE: half
1313
- VALUE: float
14+
- VALUE: double
1415
shader_variants:
1516
- NAME: dequantize_per_tensor_buffer
1617
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,10 @@ void dequantize_per_tensor() {
139139
[[unroll]] for (int i = 0; i < 4; ++i) {
140140
IN_T qvalue = IN_T(intex[i]);
141141
OUT_T value = dequantize_val(qvalue, scale, zero_point);
142-
outtex[i] = value;
142+
$if OUT_DTYPE == "double":
143+
outtex[i] = float(value);
144+
$else:
145+
outtex[i] = value;
143146
}
144147
write_texel(t_out, pos, outtex);
145148
}
@@ -177,7 +180,10 @@ void dequantize_per_token() {
177180
[[unroll]] for (int i = 0; i < 4; ++i) {
178181
IN_T qvalue = IN_T(intex[i]);
179182
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
180-
outtex[i] = value;
183+
$if OUT_DTYPE == "double":
184+
outtex[i] = float(value);
185+
$else:
186+
outtex[i] = value;
181187
}
182188

183189
write_texel(t_out, pos, outtex);

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dequantize_texture:
1111
OUT_DTYPE:
1212
- VALUE: half
1313
- VALUE: float
14+
- VALUE: double
1415
shader_variants:
1516
- NAME: dequantize_per_tensor_texture3d
1617
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ quantize_buffer:
77
IN_DTYPE:
88
- VALUE: half
99
- VALUE: float
10+
- VALUE: double
1011
OUT_DTYPE:
1112
- VALUE: uint8
1213
- VALUE: int8

backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ quantize_texture:
77
IN_DTYPE:
88
- VALUE: half
99
- VALUE: float
10+
- VALUE: double
1011
OUT_DTYPE:
1112
- VALUE: uint8
1213
- VALUE: int8

backends/vulkan/runtime/graph/ops/impl/Quantize.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ void quantize_per_tensor_impl(
188188

189189
// Verify input is a floating point type
190190
VK_CHECK_COND(
191+
graph.dtype_of(input) == vkapi::kDouble ||
191192
graph.dtype_of(input) == vkapi::kFloat ||
192193
graph.dtype_of(input) == vkapi::kHalf);
193194

@@ -214,6 +215,7 @@ void quantize_per_token_impl(
214215

215216
// Verify input is a floating point type
216217
VK_CHECK_COND(
218+
graph.dtype_of(input) == vkapi::kDouble ||
217219
graph.dtype_of(input) == vkapi::kFloat ||
218220
graph.dtype_of(input) == vkapi::kHalf);
219221

backends/vulkan/test/op_tests/dequantize_test.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,12 @@ void test_vulkan_dequantize_per_tensor(
366366
vkcompute::utils::kBuffer,
367367
vkcompute::utils::kBuffer);
368368

369+
// Telling the system to expect a float instead of a double
370+
// since the shader can only return 32bit anyways
371+
if (out_dtype == at::kDouble) {
372+
out_dtype = at::kFloat;
373+
}
374+
369375
// Test with texture storage
370376
test_vulkan_dequantize_per_tensor_impl(
371377
input_sizes,
@@ -400,6 +406,12 @@ void test_vulkan_dequantize_per_token(
400406
vkcompute::utils::kBuffer,
401407
vkcompute::utils::kBuffer);
402408

409+
// Telling the system to expect a float instead of a double
410+
// since the shader can only return 32bit anyways
411+
if (out_dtype == at::kDouble) {
412+
out_dtype = at::kFloat;
413+
}
414+
403415
// Test with texture storage
404416
test_vulkan_dequantize_per_token_impl(
405417
input_sizes,
@@ -793,6 +805,19 @@ TEST(
793805
at::kHalf); // output dtype
794806
}
795807

808+
TEST(
809+
VulkanDequantizePerTensorTest,
810+
test_vulkan_dequantize_per_tensor_int32_to_double) {
811+
test_vulkan_dequantize_per_tensor(
812+
{2, 4, 3}, // input sizes
813+
0.0001, // scale
814+
100, // zero_point
815+
-2147483648, // quant_min
816+
2147483647, // quant_max
817+
at::kInt, // input dtype
818+
at::kDouble); // output dtype
819+
}
820+
796821
void test_reference_dequantize_per_token(
797822
const std::vector<int>& input_sizes,
798823
const std::vector<float>& scales,
@@ -1288,3 +1313,19 @@ TEST(
12881313
at::kInt, // input dtype
12891314
at::kHalf); // output dtype
12901315
}
1316+
1317+
TEST(
1318+
VulkanDequantizePerTokenTest,
1319+
test_vulkan_dequantize_per_token_int32_to_double) {
1320+
std::vector<float> scales = {0.0001, 0.0002, 0.0003, 0.0};
1321+
std::vector<int> zero_points = {100, -100, 50, -50};
1322+
1323+
test_vulkan_dequantize_per_token(
1324+
{2, 2, 8}, // input sizes (2*2=4 tokens)
1325+
scales,
1326+
zero_points,
1327+
-2147483648, // quant_min
1328+
2147483647, // quant_max
1329+
at::kInt, // input dtype
1330+
at::kDouble); // output dtype
1331+
}

backends/vulkan/test/op_tests/quantize_test.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,12 @@ void test_vulkan_quantize_per_tensor(
315315
vkcompute::utils::kBuffer,
316316
vkcompute::utils::kBuffer);
317317

318+
// If the in_dtype is a double, convert to float for texture implementation
319+
// since they don't support 64bit as inputs
320+
if (in_dtype == at::kDouble) {
321+
in_dtype = at::kFloat;
322+
}
323+
318324
// Test with texture storage
319325
test_vulkan_quantize_per_tensor_impl(
320326
input_sizes,
@@ -349,6 +355,12 @@ void test_vulkan_quantize_per_token(
349355
vkcompute::utils::kBuffer,
350356
vkcompute::utils::kBuffer);
351357

358+
// If the in_dtype is a double, convert to float for texture implementation
359+
// since they don't support 64bit as inputs
360+
if (in_dtype == at::kDouble) {
361+
in_dtype = at::kFloat;
362+
}
363+
352364
// Test with texture storage
353365
test_vulkan_quantize_per_token_impl(
354366
input_sizes,
@@ -655,6 +667,24 @@ TEST(
655667
at::kChar); // output dtype
656668
}
657669

670+
TEST(
671+
VulkanQuantizePerTensorTest,
672+
test_vulkan_quantize_per_tensor_double_to_int8) {
673+
if (!vkcompute::api::context()
674+
->adapter_ptr()
675+
->has_full_int8_buffers_support()) {
676+
GTEST_SKIP();
677+
}
678+
test_vulkan_quantize_per_tensor(
679+
{2, 3}, // input sizes
680+
0.01, // scale
681+
1, // zero_point
682+
-128, // quant_min
683+
127, // quant_max
684+
at::kDouble, // input dtype
685+
at::kChar); // output dtype
686+
}
687+
658688
void test_reference_quantize_per_token(
659689
const std::vector<int>& input_sizes,
660690
const std::vector<float>& pre_scales,
@@ -1075,3 +1105,24 @@ TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) {
10751105
at::kHalf, // input dtype
10761106
at::kChar); // output dtype
10771107
}
1108+
1109+
TEST(
1110+
VulkanQuantizePerTensorTest,
1111+
test_vulkan_quantize_per_token_double_to_int8) {
1112+
if (!vkcompute::api::context()
1113+
->adapter_ptr()
1114+
->has_full_int8_buffers_support()) {
1115+
GTEST_SKIP();
1116+
}
1117+
std::vector<float> scales = {0.1, 0.2};
1118+
std::vector<int> zero_points = {0, 5};
1119+
1120+
test_vulkan_quantize_per_token(
1121+
{2, 2}, // input sizes (2*2=4 tokens)
1122+
scales,
1123+
zero_points,
1124+
-128, // quant_min
1125+
127, // quant_max
1126+
at::kDouble, // input dtype
1127+
at::kChar); // output dtype
1128+
}

0 commit comments

Comments
 (0)