Skip to content

Commit 92d9b14

Browse files
author
morelos
committed
[ET] enabling half dtype input for quantization
Pull Request resolved: #11479 # Context Currently the CPU implementation of the quantization operators (which include `quantize_per_token`, `quantize_per_tensor`, and `quantize_per_channel`) does not inherently support half (fp16) input scalar types. In order to align with the PyTorch implementation, which accepts fp16 and bf16 inputs, this diff aims to enable half input dtype support for the quantization operators. We will be comparing this implementation against the Vulkan operators. # Changes As defined in the ExecuTorch [scalar_type_util.h](https://github.com/pytorch/executorch/blob/053686242c1687f0d51b3bb8befd14b047d7b025/runtime/core/exec_aten/util/scalar_type_util.h#L190) file, support can be enabled by simply changing which preprocessor macro is invoked to ET_FORALL_FLOATH_TYPES. This enables support for Half (fp16), Float (fp32), and Double (fp64). I have also included more comprehensive testing of the input dtypes, including adding a double test since one didn't already exist. Instead of just confirming that all the output dtypes are supported, we now also check that all input dtypes are supported. ghstack-source-id: 290376481 @exported-using-ghexport Differential Revision: [D76053764](https://our.internmc.facebook.com/intern/diff/D76053764/)
1 parent 8cfa858 commit 92d9b14

File tree

2 files changed

+67
-2
lines changed

2 files changed

+67
-2
lines changed

kernels/quantized/cpu/op_quantize.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ Tensor& quantize_per_tensor_out(
150150
break;
151151

152152
switch (input.scalar_type()) {
153-
ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE);
153+
ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE);
154154
default:
155155
ET_CHECK_MSG(
156156
false,
@@ -346,7 +346,7 @@ Tensor& quantize_per_channel_out(
346346
break;
347347

348348
switch (input.scalar_type()) {
349-
ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE);
349+
ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE);
350350
default:
351351
ET_CHECK_MSG(
352352
false,

kernels/quantized/test/op_quantize_test.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,32 @@ void test_dtype() {
4949
EXPECT_TENSOR_EQ(out, expected);
5050
}
5151

52+
template <ScalarType INPUT_DTYPE>
53+
void test_input_dtype() {
54+
TensorFactory<INPUT_DTYPE> tf_input;
55+
56+
Tensor input = tf_input.full({3, 5}, 4);
57+
double scale = 0.5;
58+
int64_t zero_point = 108;
59+
int64_t quant_min = 0;
60+
int64_t quant_max = 127;
61+
62+
TensorFactory<ScalarType::Char> tfo;
63+
Tensor out = tfo.zeros({3, 5});
64+
// 4 / 0.5 + 108 = 116
65+
Tensor expected = tfo.full({3, 5}, 116);
66+
quantize_per_tensor_out(
67+
input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out);
68+
69+
EXPECT_TENSOR_EQ(out, expected);
70+
}
71+
72+
TEST(OpQuantizeOutTest, AllInputDtypesSupported) {
73+
test_input_dtype<ScalarType::Float>();
74+
test_input_dtype<ScalarType::Half>();
75+
test_input_dtype<ScalarType::Double>();
76+
}
77+
5278
TEST(OpQuantizeOutTest, AllDtypesSupported) {
5379
test_dtype<ScalarType::Byte>();
5480
test_dtype<ScalarType::Char>();
@@ -58,6 +84,45 @@ TEST(OpQuantizeOutTest, AllDtypesSupported) {
5884
test_dtype<ScalarType::Int>();
5985
}
6086

87+
TEST(OpQuantizeOutTest, DoubleInputTest) {
88+
TensorFactory<ScalarType::Double> tf_double;
89+
90+
// Test with a more complex value that might have precision differences
91+
Tensor input = tf_double.full({2, 3}, 3.14159265359);
92+
double scale = 0.01;
93+
int64_t zero_point = -100;
94+
int64_t quant_min = 0;
95+
int64_t quant_max = 255;
96+
97+
TensorFactory<ScalarType::Byte> tfo;
98+
Tensor out = tfo.zeros({2, 3});
99+
// 3.14159265359 / 0.01 - 100 = 214.159265359
100+
Tensor expected = tfo.full({2, 3}, 214);
101+
quantize_per_tensor_out(
102+
input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out);
103+
104+
EXPECT_TENSOR_EQ(out, expected);
105+
}
106+
107+
TEST(OpQuantizeOutTest, HalfInputTest) {
108+
TensorFactory<ScalarType::Half> tf_half;
109+
110+
Tensor input = tf_half.full({2, 3}, 2.5);
111+
double scale = 0.5;
112+
int64_t zero_point = 10;
113+
int64_t quant_min = -128;
114+
int64_t quant_max = 127;
115+
116+
TensorFactory<ScalarType::Char> tfo;
117+
Tensor out = tfo.zeros({2, 3});
118+
// 2.5 / 0.5 + 10 = 15
119+
Tensor expected = tfo.full({2, 3}, 15);
120+
quantize_per_tensor_out(
121+
input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out);
122+
123+
EXPECT_TENSOR_EQ(out, expected);
124+
}
125+
61126
TEST(OpQuantizeOutTest, TensorArgOverload) {
62127
TensorFactory<ScalarType::Float> tf_float;
63128
TensorFactory<ScalarType::Double> tf_double;

0 commit comments

Comments
 (0)