Skip to content

Commit 33d84b4

Browse files
author
morelos
committed
[ET-VK][Ops] enabling double support for quantization and dequantization ops
Pull Request resolved: #11553 # Context Since we enabled the possibility of double support in an earlier diff, this diff enables double support for the quantization and dequantization ops. Because there are limitations on how 64-bit types can be supported, the expectation is that IO is downgraded to 32-bit. # Changes We add test cases for double support and make sure to pass in the double dtype only where it is permitted (it is only allowed in buffers), and we also include double variants in the corresponding YAML files for quantization and dequantization. ghstack-source-id: 290819507 @exported-using-ghexport Differential Revision: [D76289197](https://our.internmc.facebook.com/intern/diff/D76289197/)
1 parent 16e22af commit 33d84b4

File tree

8 files changed

+96
-2
lines changed

8 files changed

+96
-2
lines changed

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dequantize_buffer:
1111
OUT_DTYPE:
1212
- VALUE: half
1313
- VALUE: float
14+
- VALUE: double
1415
shader_variants:
1516
- NAME: dequantize_per_tensor_buffer
1617
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,10 @@ void dequantize_per_tensor() {
7070
[[unroll]] for (int i = 0; i < 4; ++i) {
7171
IN_T qvalue = IN_T(intex[i]);
7272
OUT_T value = dequantize_val(qvalue, scale, zero_point);
73-
outtex[i] = value;
73+
$if OUT_DTYPE == "double":
74+
outtex[i] = float(value);
75+
$else:
76+
outtex[i] = value;
7477
}
7578
write_texel(t_out, pos, outtex);
7679
}
@@ -108,7 +111,10 @@ void dequantize_per_token() {
108111
[[unroll]] for (int i = 0; i < 4; ++i) {
109112
IN_T qvalue = IN_T(intex[i]);
110113
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
111-
outtex[i] = value;
114+
$if OUT_DTYPE == "double":
115+
outtex[i] = float(value);
116+
$else:
117+
outtex[i] = value;
112118
}
113119

114120
write_texel(t_out, pos, outtex);

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dequantize_texture:
1111
OUT_DTYPE:
1212
- VALUE: half
1313
- VALUE: float
14+
- VALUE: double
1415
shader_variants:
1516
- NAME: dequantize_per_tensor_texture3d
1617
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ quantize_buffer:
77
IN_DTYPE:
88
- VALUE: half
99
- VALUE: float
10+
- VALUE: double
1011
OUT_DTYPE:
1112
- VALUE: uint8
1213
- VALUE: int8

backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ quantize_texture:
77
IN_DTYPE:
88
- VALUE: half
99
- VALUE: float
10+
- VALUE: double
1011
OUT_DTYPE:
1112
- VALUE: uint8
1213
- VALUE: int8

backends/vulkan/runtime/graph/ops/impl/Quantize.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ void quantize_per_tensor_impl(
196196

197197
// Verify input is a floating point type
198198
VK_CHECK_COND(
199+
graph.dtype_of(input) == vkapi::kDouble ||
199200
graph.dtype_of(input) == vkapi::kFloat ||
200201
graph.dtype_of(input) == vkapi::kHalf);
201202

@@ -222,6 +223,7 @@ void quantize_per_token_impl(
222223

223224
// Verify input is a floating point type
224225
VK_CHECK_COND(
226+
graph.dtype_of(input) == vkapi::kDouble ||
225227
graph.dtype_of(input) == vkapi::kFloat ||
226228
graph.dtype_of(input) == vkapi::kHalf);
227229

backends/vulkan/test/op_tests/dequantize_test.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,12 @@ void test_vulkan_dequantize_per_tensor(
366366
vkcompute::utils::kBuffer,
367367
vkcompute::utils::kBuffer);
368368

369+
// Telling the system to expect a float instead of a double
370+
// since the shader can only return 32bit anyways
371+
if (out_dtype == at::kDouble) {
372+
out_dtype = at::kFloat;
373+
}
374+
369375
// Test with texture storage
370376
test_vulkan_dequantize_per_tensor_impl(
371377
input_sizes,
@@ -400,6 +406,12 @@ void test_vulkan_dequantize_per_token(
400406
vkcompute::utils::kBuffer,
401407
vkcompute::utils::kBuffer);
402408

409+
// Telling the system to expect a float instead of a double
410+
// since the shader can only return 32bit anyways
411+
if (out_dtype == at::kDouble) {
412+
out_dtype = at::kFloat;
413+
}
414+
403415
// Test with texture storage
404416
test_vulkan_dequantize_per_token_impl(
405417
input_sizes,
@@ -793,6 +805,19 @@ TEST(
793805
at::kHalf); // output dtype
794806
}
795807

808+
// Per-tensor dequantization of int32 input to a double output.
// The helper's visible hunk runs the buffer-storage case with the requested
// dtype, then rewrites out_dtype from double to float before the
// texture-storage case, so only the buffer path sees at::kDouble.
TEST(
809+
VulkanDequantizePerTensorTest,
810+
test_vulkan_dequantize_per_tensor_int32_to_double) {
811+
test_vulkan_dequantize_per_tensor(
812+
{2, 4, 3}, // input sizes
813+
0.0001, // scale
814+
100, // zero_point
815+
-2147483648, // quant_min (lower bound of the int32 range)
816+
2147483647, // quant_max (upper bound of the int32 range)
817+
at::kInt, // input dtype
818+
at::kDouble); // output dtype
819+
}
820+
796821
void test_reference_dequantize_per_token(
797822
const std::vector<int>& input_sizes,
798823
const std::vector<float>& scales,
@@ -1281,3 +1306,19 @@ TEST(
12811306
at::kInt, // input dtype
12821307
at::kHalf); // output dtype
12831308
}
1309+
1310+
// Per-token dequantization of int32 input to a double output.
// The helper's visible hunk rewrites out_dtype from double to float before
// the texture-storage case, so only the buffer path sees at::kDouble.
TEST(
1311+
VulkanDequantizePerTokenTest,
1312+
test_vulkan_dequantize_per_token_int32_to_double) {
1313+
// One scale / zero_point per token (4 tokens).
// NOTE(review): the last scale is 0.0, which presumably dequantizes the
// fourth token to all zeros - confirm this is intentional.
std::vector<float> scales = {0.0001, 0.0002, 0.0003, 0.0};
1314+
std::vector<int> zero_points = {100, -100, 50, -50};
1315+
1316+
test_vulkan_dequantize_per_token(
1317+
{2, 2, 8}, // input sizes (2*2=4 tokens)
1318+
scales,
1319+
zero_points,
1320+
-2147483648, // quant_min (lower bound of the int32 range)
1321+
2147483647, // quant_max (upper bound of the int32 range)
1322+
at::kInt, // input dtype
1323+
at::kDouble); // output dtype
1324+
}

backends/vulkan/test/op_tests/quantize_test.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,12 @@ void test_vulkan_quantize_per_tensor(
315315
vkcompute::utils::kBuffer,
316316
vkcompute::utils::kBuffer);
317317

318+
// If the in_dtype is a double, convert to float for texture implementation
319+
// since they don't support 64bit as inputs
320+
if (in_dtype == at::kDouble) {
321+
in_dtype = at::kFloat;
322+
}
323+
318324
// Test with texture storage
319325
test_vulkan_quantize_per_tensor_impl(
320326
input_sizes,
@@ -349,6 +355,12 @@ void test_vulkan_quantize_per_token(
349355
vkcompute::utils::kBuffer,
350356
vkcompute::utils::kBuffer);
351357

358+
// If the in_dtype is a double, convert to float for texture implementation
359+
// since they don't support 64bit as inputs
360+
if (in_dtype == at::kDouble) {
361+
in_dtype = at::kFloat;
362+
}
363+
352364
// Test with texture storage
353365
test_vulkan_quantize_per_token_impl(
354366
input_sizes,
@@ -655,6 +667,19 @@ TEST(
655667
at::kChar); // output dtype
656668
}
657669

670+
// Per-tensor quantization of a double input to int8.
// The helper's visible hunk runs the buffer-storage case with the requested
// dtype, then rewrites in_dtype from double to float before the
// texture-storage case, so only the buffer path sees at::kDouble.
TEST(
671+
VulkanQuantizePerTensorTest,
672+
test_vulkan_quantize_per_tensor_double_to_int8) {
673+
test_vulkan_quantize_per_tensor(
674+
{2, 3}, // input sizes
675+
0.01, // scale
676+
1, // zero_point
677+
-128, // quant_min
678+
127, // quant_max
679+
at::kDouble, // input dtype
680+
at::kChar); // output dtype
681+
}
682+
658683
void test_reference_quantize_per_token(
659684
const std::vector<int>& input_sizes,
660685
const std::vector<float>& pre_scales,
@@ -1066,3 +1091,19 @@ TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) {
10661091
at::kHalf, // input dtype
10671092
at::kChar); // output dtype
10681093
}
1094+
1095+
// Per-token quantization of a double input to int8.
// The helper's visible hunk rewrites in_dtype from double to float before
// the texture-storage case, so only the buffer path sees at::kDouble.
// NOTE(review): the suite name says PerTensor although this exercises the
// per-token helper; the neighbouring half_to_int8 per-token test uses the same
// suite name - consider renaming both together.
TEST(
1096+
VulkanQuantizePerTensorTest,
1097+
test_vulkan_quantize_per_token_double_to_int8) {
1098+
// One scale / zero_point per token (2 tokens).
std::vector<float> scales = {0.1, 0.2};
1099+
std::vector<int> zero_points = {0, 5};
1100+
1101+
test_vulkan_quantize_per_token(
1102+
{2, 2}, // input sizes (2 tokens x 2 elements; one scale/zero_point per token)
1103+
scales,
1104+
zero_points,
1105+
-128, // quant_min
1106+
127, // quant_max
1107+
at::kDouble, // input dtype
1108+
at::kChar); // output dtype
1109+
}

0 commit comments

Comments
 (0)