Skip to content

Commit 1c3e74f

Browse files
authored
[ET-VK][Ops] choose_qparams ops skeleton test framework
Differential Revision: D76436870 Pull Request resolved: #11554
1 parent 83d4735 commit 1c3e74f

File tree

2 files changed

+125
-0
lines changed

2 files changed

+125
-0
lines changed
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <gtest/gtest.h>
10+
11+
#include <ATen/ATen.h>
12+
13+
#include <executorch/backends/vulkan/runtime/api/api.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
15+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
16+
17+
#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
18+
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
19+
20+
#include "test_utils.h"
21+
22+
#include <cassert>
23+
#include <iostream>
24+
25+
namespace torch {
26+
namespace executor {
27+
namespace native {
28+
29+
// Forward declarations of the functions we're testing
30+
std::tuple<Tensor&, Tensor&> choose_qparams_tensor_out(
31+
const Tensor& input,
32+
int64_t quant_min,
33+
int64_t quant_max,
34+
ET_UNUSED double eps,
35+
ScalarType dtype,
36+
Tensor& scale_out,
37+
Tensor& zero_point_out);
38+
39+
std::tuple<Tensor&, Tensor&> choose_qparams_per_token_asymmetric_out(
40+
const Tensor& input,
41+
ScalarType dtype,
42+
Tensor& scale_out,
43+
Tensor& zero_point_out);
44+
45+
// Wrapper function for choose_qparams_tensor_out without context
46+
Tensor& choose_qparams_tensor_out_no_context(
47+
const Tensor& input,
48+
int64_t quant_min,
49+
int64_t quant_max,
50+
ET_UNUSED double eps,
51+
ScalarType dtype,
52+
Tensor& scale_out,
53+
Tensor& zero_point_out) {
54+
torch::executor::native::choose_qparams_tensor_out(
55+
input, quant_min, quant_max, eps, dtype, scale_out, zero_point_out);
56+
return scale_out;
57+
}
58+
59+
// Wrapper function for choose_qparams_per_token_asymmetric_out without context
60+
Tensor& choose_qparams_per_token_asymmetric_out_no_context(
61+
const Tensor& input,
62+
ScalarType dtype,
63+
Tensor& scale_out,
64+
Tensor& zero_point_out) {
65+
torch::executor::native::choose_qparams_per_token_asymmetric_out(
66+
input, dtype, scale_out, zero_point_out);
67+
return scale_out;
68+
}
69+
70+
// ATen wrapper for choose_qparams_tensor
71+
std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_aten(
72+
const at::Tensor& input,
73+
int64_t quant_min,
74+
int64_t quant_max,
75+
at::ScalarType dtype) {
76+
auto scale_out = at::empty({}, at::device(at::kCPU).dtype(at::kDouble));
77+
auto zero_point_out = at::empty({}, at::device(at::kCPU).dtype(at::kLong));
78+
double eps = 1e-7;
79+
80+
ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype);
81+
82+
// Use WRAP_TO_ATEN with the wrapper function
83+
WRAP_TO_ATEN(choose_qparams_tensor_out_no_context, 5)
84+
(input, quant_min, quant_max, eps, et_dtype, scale_out, zero_point_out);
85+
86+
return {scale_out, zero_point_out};
87+
}
88+
89+
// ATen wrapper for choose_qparams_per_token_asymmetric
90+
std::tuple<at::Tensor, at::Tensor> choose_qparams_per_token_asymmetric_aten(
91+
const at::Tensor& input,
92+
at::ScalarType dtype) {
93+
// Calculate output sizes for scale and zero_point tensors
94+
std::vector<int64_t> output_sizes;
95+
for (int64_t i = 0; i < input.dim() - 1; i++) {
96+
output_sizes.push_back(input.size(i));
97+
}
98+
output_sizes.push_back(1);
99+
100+
auto scale_out =
101+
at::empty(output_sizes, at::device(at::kCPU).dtype(at::kDouble));
102+
auto zero_point_out =
103+
at::empty(output_sizes, at::device(at::kCPU).dtype(at::kLong));
104+
105+
ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype);
106+
107+
// Use WRAP_TO_ATEN with the wrapper function
108+
WRAP_TO_ATEN(choose_qparams_per_token_asymmetric_out_no_context, 2)
109+
(input, et_dtype, scale_out, zero_point_out);
110+
111+
return {scale_out, zero_point_out};
112+
}
113+
114+
} // namespace native
115+
} // namespace executor
116+
} // namespace torch

backends/vulkan/test/op_tests/targets.bzl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,15 @@ def define_common_targets(is_fbcode = False):
195195
"//executorch/extension/aten_util:aten_bridge",
196196
]
197197
)
198+
define_test_targets(
199+
"choose_qparams_test",
200+
extra_deps = [
201+
":test_utils",
202+
"//executorch/kernels/quantized/cpu:op_choose_qparams",
203+
"//executorch/extension/tensor:tensor",
204+
"//executorch/extension/aten_util:aten_bridge",
205+
]
206+
)
198207
define_test_targets(
199208
"linear_weight_int4_test",
200209
extra_deps = [

0 commit comments

Comments
 (0)