From 0c8fb18d0ad51d34da54cfcdf04d0f754f8d4872 Mon Sep 17 00:00:00 2001 From: morelos Date: Wed, 18 Jun 2025 08:23:26 -0700 Subject: [PATCH] [ET-VK][Ops] enabling double support for quantization and dequantization ops Pull Request resolved: https://github.com/pytorch/executorch/pull/11553 # Context Since we enabled the possibility for double support in an earlier diff, this enables double support for quantization and dequantization. Since there are limitations to how 64bit can be supported, the expectation is that IO is to be downgraded to 32bit. # Changes We create additional test cases for double support and make sure to pass in the double if its permitted (it's only allowed in buffers), and we also make sure to include double variants in the corresponding YAML files for quantization and dequantization. ghstack-source-id: 291249593 @exported-using-ghexport Differential Revision: [D76289197](https://our.internmc.facebook.com/intern/diff/D76289197/) --- .../graph/ops/glsl/dequantize_buffer.yaml | 1 + .../graph/ops/glsl/dequantize_texture.glsl | 10 +++- .../graph/ops/glsl/dequantize_texture.yaml | 1 + .../graph/ops/glsl/quantize_buffer.yaml | 1 + .../graph/ops/glsl/quantize_texture.yaml | 1 + .../runtime/graph/ops/impl/Quantize.cpp | 2 + .../vulkan/test/op_tests/dequantize_test.cpp | 51 +++++++++++++++++++ .../vulkan/test/op_tests/quantize_test.cpp | 51 +++++++++++++++++++ 8 files changed, 116 insertions(+), 2 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml index 4e434935356..fb0d2ee61bf 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml @@ -11,6 +11,7 @@ dequantize_buffer: OUT_DTYPE: - VALUE: half - VALUE: float + - VALUE: double shader_variants: - NAME: dequantize_per_tensor_buffer MODE: per_tensor diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl index cfc61dd1816..801f4a2f6a2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl @@ -139,7 +139,10 @@ void dequantize_per_tensor() { [[unroll]] for (int i = 0; i < 4; ++i) { IN_T qvalue = IN_T(intex[i]); OUT_T value = dequantize_val(qvalue, scale, zero_point); - outtex[i] = value; + $if OUT_DTYPE == "double": + outtex[i] = float(value); + $else: + outtex[i] = value; } write_texel(t_out, pos, outtex); } @@ -177,7 +180,10 @@ void dequantize_per_token() { [[unroll]] for (int i = 0; i < 4; ++i) { IN_T qvalue = IN_T(intex[i]); OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); - outtex[i] = value; + $if OUT_DTYPE == "double": + outtex[i] = float(value); + $else: + outtex[i] = value; } write_texel(t_out, pos, outtex); diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml index fc8c18468ed..7d19a543a03 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml @@ -11,6 +11,7 @@ dequantize_texture: OUT_DTYPE: - VALUE: half - VALUE: float + - VALUE: double shader_variants: - NAME: dequantize_per_tensor_texture3d MODE: per_tensor diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml index 90af2590936..4d95d610314 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml @@ -7,6 +7,7 @@ quantize_buffer: IN_DTYPE: - VALUE: half - VALUE: float + - VALUE: double OUT_DTYPE: - VALUE: uint8 - VALUE: int8 diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml index 042eb0f8196..65002ce26b6 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml @@ -7,6 +7,7 @@ quantize_texture: IN_DTYPE: - VALUE: half - VALUE: float + - VALUE: double OUT_DTYPE: - VALUE: uint8 - VALUE: int8 diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp index 35712d59fb9..49277b4d718 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp @@ -188,6 +188,7 @@ void quantize_per_tensor_impl( // Verify input is a floating point type VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kDouble || graph.dtype_of(input) == vkapi::kFloat || graph.dtype_of(input) == vkapi::kHalf); @@ -214,6 +215,7 @@ void quantize_per_token_impl( // Verify input is a floating point type VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kDouble || graph.dtype_of(input) == vkapi::kFloat || graph.dtype_of(input) == vkapi::kHalf); diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp index 1ec0602a4f2..6c604076c41 100644 --- a/backends/vulkan/test/op_tests/dequantize_test.cpp +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -366,6 +366,12 @@ void test_vulkan_dequantize_per_tensor( vkcompute::utils::kBuffer, vkcompute::utils::kBuffer); + // Telling the system to expect a float instead of a double + // since the shader can only return 32bit anyways + if (out_dtype == at::kDouble) { + out_dtype = at::kFloat; + } + // Test with texture storage test_vulkan_dequantize_per_tensor_impl( input_sizes, @@ -400,6 +406,12 @@ void test_vulkan_dequantize_per_token( vkcompute::utils::kBuffer, vkcompute::utils::kBuffer); + // Telling the system to expect a float instead of a double + // since the shader can only return 32bit anyways + if (out_dtype == at::kDouble) { + out_dtype = at::kFloat; + } + // Test with texture storage test_vulkan_dequantize_per_token_impl( input_sizes, @@ -793,6 +805,24 @@ TEST( at::kHalf); // output dtype } +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int8_to_double) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor( + {2, 3}, // input sizes + 0.05, // scale + 10, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kDouble); // output dtype +} + void test_reference_dequantize_per_token( const std::vector& input_sizes, const std::vector& scales, @@ -1288,3 +1318,24 @@ TEST( at::kInt, // input dtype at::kHalf); // output dtype } + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int8_to_double) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.05, 0.001}; + std::vector zero_points = {10, -5}; + + test_vulkan_dequantize_per_token( + {2, 2}, // input sizes (2 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kDouble); // output dtype +} diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp index 7ea98b14fb2..150bda6989e 100644 --- a/backends/vulkan/test/op_tests/quantize_test.cpp +++ b/backends/vulkan/test/op_tests/quantize_test.cpp @@ -315,6 +315,12 @@ void test_vulkan_quantize_per_tensor( vkcompute::utils::kBuffer, vkcompute::utils::kBuffer); + // If the in_dtype is a double, convert to float for texture implementation + // since they don't support 64bit as inputs + if (in_dtype == at::kDouble) { + in_dtype = at::kFloat; + } + // Test with texture storage test_vulkan_quantize_per_tensor_impl( input_sizes, @@ -349,6 +355,12 @@ void test_vulkan_quantize_per_token( vkcompute::utils::kBuffer, vkcompute::utils::kBuffer); + // If the in_dtype is a double, convert to float for texture implementation + // since they don't support 64bit as inputs + if (in_dtype == at::kDouble) { + in_dtype = at::kFloat; + } + // Test with texture storage test_vulkan_quantize_per_token_impl( input_sizes, @@ -655,6 +667,24 @@ TEST( at::kChar); // output dtype } +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_tensor_double_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_quantize_per_tensor( + {2, 3}, // input sizes + 0.01, // scale + 1, // zero_point + -128, // quant_min + 127, // quant_max + at::kDouble, // input dtype + at::kChar); // output dtype +} + void test_reference_quantize_per_token( const std::vector& input_sizes, const std::vector& pre_scales, @@ -1075,3 +1105,24 @@ TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) { at::kHalf, // input dtype at::kChar); // output dtype } + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_token_double_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.1, 0.2}; + std::vector zero_points = {0, 5}; + + test_vulkan_quantize_per_token( + {2, 2}, // input sizes (2*2=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kDouble, // input dtype + at::kChar); // output dtype +}