Skip to content

Commit e9b7bea

Browse files
authored
[ET-VK][Ops] enabling double support for quantization and dequantization ops
Differential Revision: D76289197 Pull Request resolved: #11553
1 parent 6dab7cc commit e9b7bea

File tree

8 files changed

+116
-2
lines changed

8 files changed

+116
-2
lines changed

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dequantize_buffer:
1111
OUT_DTYPE:
1212
- VALUE: half
1313
- VALUE: float
14+
- VALUE: double
1415
shader_variants:
1516
- NAME: dequantize_per_tensor_buffer
1617
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,10 @@ void dequantize_per_tensor() {
139139
[[unroll]] for (int i = 0; i < 4; ++i) {
140140
IN_T qvalue = IN_T(intex[i]);
141141
OUT_T value = dequantize_val(qvalue, scale, zero_point);
142-
outtex[i] = value;
142+
$if OUT_DTYPE == "double":
143+
outtex[i] = float(value);
144+
$else:
145+
outtex[i] = value;
143146
}
144147
write_texel(t_out, pos, outtex);
145148
}
@@ -177,7 +180,10 @@ void dequantize_per_token() {
177180
[[unroll]] for (int i = 0; i < 4; ++i) {
178181
IN_T qvalue = IN_T(intex[i]);
179182
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
180-
outtex[i] = value;
183+
$if OUT_DTYPE == "double":
184+
outtex[i] = float(value);
185+
$else:
186+
outtex[i] = value;
181187
}
182188

183189
write_texel(t_out, pos, outtex);

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dequantize_texture:
1111
OUT_DTYPE:
1212
- VALUE: half
1313
- VALUE: float
14+
- VALUE: double
1415
shader_variants:
1516
- NAME: dequantize_per_tensor_texture3d
1617
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ quantize_buffer:
77
IN_DTYPE:
88
- VALUE: half
99
- VALUE: float
10+
- VALUE: double
1011
OUT_DTYPE:
1112
- VALUE: uint8
1213
- VALUE: int8

backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ quantize_texture:
77
IN_DTYPE:
88
- VALUE: half
99
- VALUE: float
10+
- VALUE: double
1011
OUT_DTYPE:
1112
- VALUE: uint8
1213
- VALUE: int8

backends/vulkan/runtime/graph/ops/impl/Quantize.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ void quantize_per_tensor_impl(
188188

189189
// Verify input is a floating point type
190190
VK_CHECK_COND(
191+
graph.dtype_of(input) == vkapi::kDouble ||
191192
graph.dtype_of(input) == vkapi::kFloat ||
192193
graph.dtype_of(input) == vkapi::kHalf);
193194

@@ -214,6 +215,7 @@ void quantize_per_token_impl(
214215

215216
// Verify input is a floating point type
216217
VK_CHECK_COND(
218+
graph.dtype_of(input) == vkapi::kDouble ||
217219
graph.dtype_of(input) == vkapi::kFloat ||
218220
graph.dtype_of(input) == vkapi::kHalf);
219221

backends/vulkan/test/op_tests/dequantize_test.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,12 @@ void test_vulkan_dequantize_per_tensor(
366366
vkcompute::utils::kBuffer,
367367
vkcompute::utils::kBuffer);
368368

369+
// Telling the system to expect a float instead of a double
370+
// since the shader can only return 32bit anyways
371+
if (out_dtype == at::kDouble) {
372+
out_dtype = at::kFloat;
373+
}
374+
369375
// Test with texture storage
370376
test_vulkan_dequantize_per_tensor_impl(
371377
input_sizes,
@@ -400,6 +406,12 @@ void test_vulkan_dequantize_per_token(
400406
vkcompute::utils::kBuffer,
401407
vkcompute::utils::kBuffer);
402408

409+
// Telling the system to expect a float instead of a double
410+
// since the shader can only return 32bit anyways
411+
if (out_dtype == at::kDouble) {
412+
out_dtype = at::kFloat;
413+
}
414+
403415
// Test with texture storage
404416
test_vulkan_dequantize_per_token_impl(
405417
input_sizes,
@@ -793,6 +805,24 @@ TEST(
793805
at::kHalf); // output dtype
794806
}
795807

808+
TEST(
809+
VulkanDequantizePerTensorTest,
810+
test_vulkan_dequantize_per_tensor_int8_to_double) {
811+
if (!vkcompute::api::context()
812+
->adapter_ptr()
813+
->has_full_int8_buffers_support()) {
814+
GTEST_SKIP();
815+
}
816+
test_vulkan_dequantize_per_tensor(
817+
{2, 3}, // input sizes
818+
0.05, // scale
819+
10, // zero_point
820+
-128, // quant_min
821+
127, // quant_max
822+
at::kChar, // input dtype
823+
at::kDouble); // output dtype
824+
}
825+
796826
void test_reference_dequantize_per_token(
797827
const std::vector<int>& input_sizes,
798828
const std::vector<float>& scales,
@@ -1288,3 +1318,24 @@ TEST(
12881318
at::kInt, // input dtype
12891319
at::kHalf); // output dtype
12901320
}
1321+
1322+
TEST(
1323+
VulkanDequantizePerTokenTest,
1324+
test_vulkan_dequantize_per_token_int8_to_double) {
1325+
if (!vkcompute::api::context()
1326+
->adapter_ptr()
1327+
->has_full_int8_buffers_support()) {
1328+
GTEST_SKIP();
1329+
}
1330+
std::vector<float> scales = {0.05, 0.001};
1331+
std::vector<int> zero_points = {10, -5};
1332+
1333+
test_vulkan_dequantize_per_token(
1334+
{2, 2}, // input sizes (2 tokens)
1335+
scales,
1336+
zero_points,
1337+
-128, // quant_min
1338+
127, // quant_max
1339+
at::kChar, // input dtype
1340+
at::kDouble); // output dtype
1341+
}

backends/vulkan/test/op_tests/quantize_test.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,12 @@ void test_vulkan_quantize_per_tensor(
315315
vkcompute::utils::kBuffer,
316316
vkcompute::utils::kBuffer);
317317

318+
// If the in_dtype is a double, convert to float for texture implementation
319+
// since they don't support 64bit as inputs
320+
if (in_dtype == at::kDouble) {
321+
in_dtype = at::kFloat;
322+
}
323+
318324
// Test with texture storage
319325
test_vulkan_quantize_per_tensor_impl(
320326
input_sizes,
@@ -349,6 +355,12 @@ void test_vulkan_quantize_per_token(
349355
vkcompute::utils::kBuffer,
350356
vkcompute::utils::kBuffer);
351357

358+
// If the in_dtype is a double, convert to float for texture implementation
359+
// since they don't support 64bit as inputs
360+
if (in_dtype == at::kDouble) {
361+
in_dtype = at::kFloat;
362+
}
363+
352364
// Test with texture storage
353365
test_vulkan_quantize_per_token_impl(
354366
input_sizes,
@@ -655,6 +667,24 @@ TEST(
655667
at::kChar); // output dtype
656668
}
657669

670+
TEST(
671+
VulkanQuantizePerTensorTest,
672+
test_vulkan_quantize_per_tensor_double_to_int8) {
673+
if (!vkcompute::api::context()
674+
->adapter_ptr()
675+
->has_full_int8_buffers_support()) {
676+
GTEST_SKIP();
677+
}
678+
test_vulkan_quantize_per_tensor(
679+
{2, 3}, // input sizes
680+
0.01, // scale
681+
1, // zero_point
682+
-128, // quant_min
683+
127, // quant_max
684+
at::kDouble, // input dtype
685+
at::kChar); // output dtype
686+
}
687+
658688
void test_reference_quantize_per_token(
659689
const std::vector<int>& input_sizes,
660690
const std::vector<float>& pre_scales,
@@ -1075,3 +1105,24 @@ TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) {
10751105
at::kHalf, // input dtype
10761106
at::kChar); // output dtype
10771107
}
1108+
1109+
TEST(
1110+
VulkanQuantizePerTensorTest,
1111+
test_vulkan_quantize_per_token_double_to_int8) {
1112+
if (!vkcompute::api::context()
1113+
->adapter_ptr()
1114+
->has_full_int8_buffers_support()) {
1115+
GTEST_SKIP();
1116+
}
1117+
std::vector<float> scales = {0.1, 0.2};
1118+
std::vector<int> zero_points = {0, 5};
1119+
1120+
test_vulkan_quantize_per_token(
1121+
{2, 2}, // input sizes (2*2=4 tokens)
1122+
scales,
1123+
zero_points,
1124+
-128, // quant_min
1125+
127, // quant_max
1126+
at::kDouble, // input dtype
1127+
at::kChar); // output dtype
1128+
}

0 commit comments

Comments
 (0)