From 37fb6b6cede784a99b45c484134f242cf3ef0308 Mon Sep 17 00:00:00 2001 From: morelos Date: Wed, 11 Jun 2025 09:59:32 -0700 Subject: [PATCH] [ET-VK][Ops] choose_qparams ops skeleton test framework Skeleton framework that is needed to build out the choose_qparams.tensor and choose_qparams_per_token_asymmetric.default operators based on cpu implementation Differential Revision: [D76436870](https://our.internmc.facebook.com/intern/diff/D76436870/) [ghstack-poisoned] --- .../test/op_tests/choose_qparams_test.cpp | 234 ++++++++++++++++++ backends/vulkan/test/op_tests/targets.bzl | 8 + 2 files changed, 242 insertions(+) create mode 100644 backends/vulkan/test/op_tests/choose_qparams_test.cpp diff --git a/backends/vulkan/test/op_tests/choose_qparams_test.cpp b/backends/vulkan/test/op_tests/choose_qparams_test.cpp new file mode 100644 index 00000000000..825811c4df0 --- /dev/null +++ b/backends/vulkan/test/op_tests/choose_qparams_test.cpp @@ -0,0 +1,234 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+
+#include <gtest/gtest.h>
+
+#include <ATen/ATen.h>
+
+#include <executorch/backends/vulkan/runtime/api/api.h>
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+
+#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
+#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
+
+#include <executorch/kernels/quantized/NativeFunctions.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+// Forward declarations of the functions we're testing
+std::tuple<Tensor&, Tensor&> choose_qparams_tensor_out(
+    const Tensor& input,
+    int64_t quant_min,
+    int64_t quant_max,
+    ET_UNUSED double eps,
+    ScalarType dtype,
+    Tensor& scale_out,
+    Tensor& zero_point_out);
+
+std::tuple<Tensor&, Tensor&> choose_qparams_per_token_asymmetric_out(
+    const Tensor& input,
+    ScalarType dtype,
+    Tensor& scale_out,
+    Tensor& zero_point_out);
+
+// Wrapper function for choose_qparams_tensor_out without context
+Tensor& choose_qparams_tensor_out_no_context(
+    const Tensor& input,
+    int64_t quant_min,
+    int64_t quant_max,
+    ET_UNUSED double eps,
+    ScalarType dtype,
+    Tensor& scale_out,
+    Tensor& zero_point_out) {
+  torch::executor::native::choose_qparams_tensor_out(
+      input, quant_min, quant_max, eps, dtype, scale_out, zero_point_out);
+  return scale_out;
+}
+
+// Wrapper function for choose_qparams_per_token_asymmetric_out without context
+Tensor& choose_qparams_per_token_asymmetric_out_no_context(
+    const Tensor& input,
+    ScalarType dtype,
+    Tensor& scale_out,
+    Tensor& zero_point_out) {
+  torch::executor::native::choose_qparams_per_token_asymmetric_out(
+      input, dtype, scale_out, zero_point_out);
+  return scale_out;
+}
+
+// ATen wrapper for choose_qparams_tensor
+std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_aten(
+    const at::Tensor& input,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype) {
+  auto scale_out = at::empty({}, at::device(at::kCPU).dtype(at::kDouble));
+  auto zero_point_out = at::empty({}, at::device(at::kCPU).dtype(at::kLong));
+  double eps = 1e-7;
+
+  // Convert at::ScalarType to executorch::ScalarType
+  ScalarType et_dtype;
+  switch (dtype) {
+    case at::kByte:
+      et_dtype = ScalarType::Byte;
+      break;
+    case at::kChar:
+      et_dtype = ScalarType::Char;
+      break;
+    case at::kShort:
+      et_dtype = ScalarType::Short;
+      break;
+    case at::kInt:
+      et_dtype = ScalarType::Int;
+      break;
+    case at::kLong:
+      et_dtype = ScalarType::Long;
+      break;
+    case at::kFloat:
+      et_dtype = ScalarType::Float;
+      break;
+    case at::kDouble:
+      et_dtype = ScalarType::Double;
+      break;
+    default:
+      throw std::runtime_error("Unsupported dtype");
+  }
+
+  // Use WRAP_TO_ATEN with the wrapper function
+  WRAP_TO_ATEN(choose_qparams_tensor_out_no_context, 5)
+  (input, quant_min, quant_max, eps, et_dtype, scale_out, zero_point_out);
+
+  return {scale_out, zero_point_out};
+}
+
+// ATen wrapper for choose_qparams_per_token_asymmetric
+std::tuple<at::Tensor, at::Tensor> choose_qparams_per_token_asymmetric_aten(
+    const at::Tensor& input,
+    at::ScalarType dtype) {
+  // Calculate output sizes for scale and zero_point tensors
+  std::vector<int64_t> output_sizes;
+  for (int64_t i = 0; i < input.dim() - 1; i++) {
+    output_sizes.push_back(input.size(i));
+  }
+  output_sizes.push_back(1);
+
+  auto scale_out =
+      at::empty(output_sizes, at::device(at::kCPU).dtype(at::kDouble));
+  auto zero_point_out =
+      at::empty(output_sizes, at::device(at::kCPU).dtype(at::kLong));
+
+  // Convert at::ScalarType to executorch::ScalarType
+  ScalarType et_dtype;
+  switch (dtype) {
+    case at::kByte:
+      et_dtype = ScalarType::Byte;
+      break;
+    case at::kChar:
+      et_dtype = ScalarType::Char;
+      break;
+    case at::kShort:
+      et_dtype = ScalarType::Short;
+      break;
+    case at::kInt:
+      et_dtype = ScalarType::Int;
+      break;
+    case at::kLong:
+      et_dtype = ScalarType::Long;
+      break;
+    case at::kFloat:
+      et_dtype = ScalarType::Float;
+      break;
+    case at::kDouble:
+      et_dtype = ScalarType::Double;
+      break;
+    default:
+      throw std::runtime_error("Unsupported dtype");
+  }
+
+  // Use WRAP_TO_ATEN with the wrapper function
+  WRAP_TO_ATEN(choose_qparams_per_token_asymmetric_out_no_context, 2)
+  (input, et_dtype, scale_out, zero_point_out);
+
+  return {scale_out, zero_point_out};
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
+
+//
+// Test functions
+//
+
+// Helper function to get the name
+// of a ScalarType for better error messages
+std::string scalar_type_name(c10::ScalarType dtype) {
+  switch (dtype) {
+    case c10::kLong:
+      return "c10::kLong";
+    case c10::kShort:
+      return "c10::kShort";
+    case c10::kComplexHalf:
+      return "c10::kComplexHalf";
+    case c10::kComplexFloat:
+      return "c10::kComplexFloat";
+    case c10::kComplexDouble:
+      return "c10::kComplexDouble";
+    case c10::kBool:
+      return "c10::kBool";
+    case c10::kQInt8:
+      return "c10::kQInt8";
+    case c10::kQUInt8:
+      return "c10::kQUInt8";
+    case c10::kQInt32:
+      return "c10::kQInt32";
+    case c10::kBFloat16:
+      return "c10::kBFloat16";
+    case c10::kQUInt4x2:
+      return "c10::kQUInt4x2";
+    case c10::kQUInt2x4:
+      return "c10::kQUInt2x4";
+    default:
+      return "Unknown(" + std::to_string(static_cast<int>(dtype)) + ")";
+  }
+}
+
+vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
+  using namespace vkcompute;
+  switch (at_scalartype) {
+    case c10::kFloat:
+      return vkapi::kFloat;
+    case c10::kHalf:
+      return vkapi::kHalf;
+    case c10::kInt:
+      return vkapi::kInt;
+    case c10::kLong:
+      // We don't have inherent vkapi::kLong, use kInt instead
+      return vkapi::kInt;
+    case c10::kChar:
+      return vkapi::kChar;
+    case c10::kByte:
+      return vkapi::kByte;
+    case c10::kDouble:
+      return vkapi::kDouble;
+    case c10::kShort:
+      return vkapi::kShort;
+    case c10::kUInt16:
+      return vkapi::kUInt16;
+    default:
+      VK_THROW(
+          "Unsupported at::ScalarType: ",
+          scalar_type_name(at_scalartype),
+          " (",
+          static_cast<int>(at_scalartype),
+          ")");
+  }
+}
diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl
index f8da9b7e3e9..7eb8201b260 100644
--- a/backends/vulkan/test/op_tests/targets.bzl
+++ b/backends/vulkan/test/op_tests/targets.bzl
@@ -164,5 +164,13 @@ def define_common_targets(is_fbcode = False):
             "//executorch/extension/aten_util:aten_bridge",
         ]
     )
+    define_test_targets(
+        "choose_qparams_test",
+        extra_deps = [
+            "//executorch/kernels/quantized/cpu:op_choose_qparams",
+            "//executorch/extension/tensor:tensor",
+            "//executorch/extension/aten_util:aten_bridge",
+        ]
+    )
     define_test_targets("linear_weight_int4_test")
     define_test_targets("rotary_embedding_test")