@@ -29,6 +29,7 @@ namespace {
  */
 void check_dequantize_args(
     const Tensor& input,
+    int64_t zero_point,
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
@@ -39,6 +40,18 @@ void check_dequantize_args(
       "input.scalar_type() %" PRId8 " is not char type",
       static_cast<int8_t>(input.scalar_type()));
 
+  // Check zp range
+  ET_CHECK_MSG(
+      zero_point >= quant_min,
+      "zero_point %" PRId64 " must be >= quant_min %" PRId64,
+      zero_point,
+      quant_min);
+  ET_CHECK_MSG(
+      zero_point <= quant_max,
+      "zero_point %" PRId64 " must be <= quant_max %" PRId64,
+      zero_point,
+      quant_max);
+
   // Check output dtype is float
   ET_CHECK_MSG(
       out.scalar_type() == ScalarType::Float,
@@ -73,18 +86,10 @@ void check_dequantize_args(
 /**
  * Scalar implementation of quantization for a single value.
  */
-template <typename K, typename T>
-T dequantize_val(
-    float scale,
-    int32_t zero_point,
-    K value,
-    int64_t quant_min,
-    int64_t quant_max) {
-  (void)quant_min;
-  (void)quant_max;
-  return static_cast<T>((static_cast<int32_t>(value) - zero_point) * scale);
+template <typename Q, typename F>
+F dequantize_val(float scale, int32_t zero_point, Q qvalue) {
+  return static_cast<F>((static_cast<int32_t>(qvalue) - zero_point) * scale);
 }
-
 } // namespace
 
 Tensor& dequantize_per_tensor_out(
@@ -106,29 +111,71 @@ Tensor& dequantize_per_tensor_out(
       "Failed to resize out Tensor in dequantize_per_tensor_out");
 
   // Validate input parameters
-  check_dequantize_args(input, quant_min, quant_max, dtype, out);
+  check_dequantize_args(input, zero_point, quant_min, quant_max, dtype, out);
 
-  // Pre-compute inverse scale for better performance
   int32_t zp = static_cast<int32_t>(zero_point);
-  int32_t qmin = static_cast<int32_t>(quant_min);
-  int32_t qmax = static_cast<int32_t>(quant_max);
 
   // Get pointers to input and output data
   const int8_t* input_data = input.const_data_ptr<int8_t>();
   float* out_data = out.mutable_data_ptr<float>();
   const size_t numel = input.numel();
 
+  size_t i = 0;
 #if defined(HAS_HELIUM_SIMD)
-  // Helium MVE implementation for float32 to int8 quantization
-#Error "Implement MVE version!"
-#else
-  // Scalar implementation for float32 to int8 quantization
-  for (size_t i = 0; i < numel; i++) {
-    out_data[i] =
-        dequantize_val<int8_t, float>(scale, zp, input_data[i], qmin, qmax);
+  // Helium MVE implementation for int8 to float dequantization
+  static uint8x16_t voffset{
+      0x0,
+      0x8,
+      0x4,
+      0xC,
+      0x1,
+      0x9,
+      0x5,
+      0xD,
+      0x2,
+      0xA,
+      0x6,
+      0xE,
+      0x3,
+      0xB,
+      0x7,
+      0xF};
+
+  int16x8_t vzp = vdupq_n_s16(static_cast<int16_t>(zp));
+  float32x4_t vscale = vdupq_n_f32(static_cast<float>(scale));
+
+  for (; i + 15 < numel; i += 16) {
+    int8x16_t in_084C195D2A6E3B7F =
+        vldrbq_gather_offset_s8(input_data, voffset);
+
+    int16x8_t in_04152637 = vsubq_s16(vmovlbq_s8(in_084C195D2A6E3B7F), vzp);
+    int16x8_t in_8C9DAEBF = vsubq_s16(vmovltq_s8(in_084C195D2A6E3B7F), vzp);
+
+    float32x4_t inf_0123 = vcvtq_f32_s32(vmovlbq_s16(in_04152637));
+    float32x4_t inf_4567 = vcvtq_f32_s32(vmovltq_s16(in_04152637));
+    float32x4_t inf_89AB = vcvtq_f32_s32(vmovlbq_s16(in_8C9DAEBF));
+    float32x4_t inf_CDEF = vcvtq_f32_s32(vmovltq_s16(in_8C9DAEBF));
+
+    float32x4_t out_0123 = vmulq_f32(inf_0123, vscale);
+    float32x4_t out_4567 = vmulq_f32(inf_4567, vscale);
+    float32x4_t out_89AB = vmulq_f32(inf_89AB, vscale);
+    float32x4_t out_CDEF = vmulq_f32(inf_CDEF, vscale);
+
+    vstrwq_f32(out_data + 0, out_0123);
+    vstrwq_f32(out_data + 4, out_4567);
+    vstrwq_f32(out_data + 8, out_89AB);
+    vstrwq_f32(out_data + 12, out_CDEF);
+
+    input_data += 16;
+    out_data += 16;
   }
-#endif
+#endif // defined(HAS_HELIUM_SIMD)
 
+  for (; i < numel; i++) {
+    *out_data = dequantize_val<int8_t, float>(scale, zp, *input_data);
+    input_data++;
+    out_data++;
+  }
   return out;
 }
 
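For reference, the simplified dequantize_val above implements the standard affine dequantization real = (q - zero_point) * scale; the quant_min/quant_max parameters could be dropped because they were unused in the helper, and zero_point is now range-checked once in check_dequantize_args. A minimal standalone sketch of the same formula (plain C++ without ExecuTorch headers; the main harness and the sample numbers are illustrative only, not part of the patch):

#include <cstdint>
#include <cstdio>

// Standalone copy of the simplified helper, for illustration only.
template <typename Q, typename F>
F dequantize_val(float scale, int32_t zero_point, Q qvalue) {
  return static_cast<F>((static_cast<int32_t>(qvalue) - zero_point) * scale);
}

int main() {
  // Example: qvalue = -13, zero_point = 3, scale = 0.05
  // -> (-13 - 3) * 0.05 = -0.8
  float real = dequantize_val<int8_t, float>(0.05f, 3, static_cast<int8_t>(-13));
  std::printf("%f\n", real); // prints -0.800000
  return 0;
}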
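The scrambled byte order in voffset (mirrored in names like in_084C195D2A6E3B7F) is what lets the widened results come out in natural order: on MVE, vmovlbq_s8/vmovltq_s8 widen the even- and odd-numbered byte lanes respectively, and vmovlbq_s16/vmovltq_s16 repeat that even/odd split on halfword lanes, so gathering the bytes as 0,8,4,C,... leaves inf_0123, inf_4567, inf_89AB, inf_CDEF holding elements 0-15 in order. A small host-side model of that lane shuffle (plain C++, no MVE intrinsics; it only mimics the documented even/odd lane selection and is a sketch, not a test of the kernel):

#include <array>
#include <cstdio>

int main() {
  // Gather offsets from the patch: byte lane k of the loaded vector holds
  // input_data[voffset[k]]. Pretend input_data[j] == j so the lane values
  // are simply the offsets themselves.
  const std::array<int, 16> gathered = {
      0x0, 0x8, 0x4, 0xC, 0x1, 0x9, 0x5, 0xD,
      0x2, 0xA, 0x6, 0xE, 0x3, 0xB, 0x7, 0xF};

  // vmovlbq_s8 widens the even byte lanes, vmovltq_s8 the odd byte lanes.
  std::array<int, 8> bottom{}, top{};
  for (int k = 0; k < 8; ++k) {
    bottom[k] = gathered[2 * k];      // models in_04152637
    top[k] = gathered[2 * k + 1];     // models in_8C9DAEBF
  }

  // vmovlbq_s16 / vmovltq_s16 repeat the even/odd split on halfword lanes.
  int out[16];
  for (int k = 0; k < 4; ++k) {
    out[0 + k] = bottom[2 * k];       // inf_0123
    out[4 + k] = bottom[2 * k + 1];   // inf_4567
    out[8 + k] = top[2 * k];          // inf_89AB
    out[12 + k] = top[2 * k + 1];     // inf_CDEF
  }

  for (int k = 0; k < 16; ++k) {
    std::printf("%d ", out[k]);       // prints 0 1 2 ... 15: natural order restored
  }
  std::printf("\n");
  return 0;
}

Because the four float vectors end up contiguous and in order, the kernel can store them with plain vstrwq_f32 at offsets 0, 4, 8 and 12 instead of needing a scatter store on the output side.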