Skip to content

Commit db00b95

Browse files
authored
[ET-VK][Ops] dequantization op shaders and impl
Differential Revision: D76267107 Pull Request resolved: #11483
1 parent aa0ee22 commit db00b95

File tree

7 files changed

+936
-8
lines changed

7 files changed

+936
-8
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#ifndef DEQUANTIZE_GLSLH
10+
#define DEQUANTIZE_GLSLH
11+
12+
OUT_T dequantize_val(IN_T qvalue, float scale_val, int zero_point_val) {
13+
return OUT_T(float(int(qvalue) - zero_point_val) * scale_val);
14+
}
15+
16+
#endif // DEQUANTIZE_GLSLH
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define IN_T ${buffer_scalar_type(IN_DTYPE)}
14+
#define OUT_T ${buffer_scalar_type(OUT_DTYPE)}
15+
16+
#define ${MODE}
17+
18+
${define_active_storage_type("buffer")}
19+
${define_required_extensions(IN_DTYPE)}
20+
${define_required_extensions(OUT_DTYPE)}
21+
22+
layout(std430) buffer;
23+
24+
#include "indexing_utils.h"
25+
26+
${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")}
27+
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}
28+
29+
$if MODE == "per_tensor":
30+
layout(push_constant) uniform restrict Block {
31+
float scale;
32+
int zero_point;
33+
int quant_min;
34+
int quant_max;
35+
};
36+
$if MODE == "per_token":
37+
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
38+
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
39+
40+
layout(push_constant) uniform restrict Block {
41+
int num_tokens;
42+
int quant_min;
43+
int quant_max;
44+
};
45+
46+
${layout_declare_ubo(B, "int", "out_numel")}
47+
${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
48+
${layout_declare_ubo(B, "ivec4", "t_in_strides")}
49+
${layout_declare_ubo(B, "ivec4", "t_out_sizes")}
50+
${layout_declare_ubo(B, "ivec4", "t_out_strides")}
51+
52+
${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
53+
${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
54+
55+
#include "dequantize.glslh"
56+
57+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
58+
59+
const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);
60+
const lowp ivec4 in_dim_order = unhash_dim_order(in_layout);
61+
62+
/*
63+
* DEQUANTIZATION SHADER (BUFFER STORAGE)
64+
*
65+
* This shader converts n-bit integer tensor values back to floating-point representations
66+
* using pre-computed quantization parameters (scale and zero_point). The dequantization
67+
* reconstructs the original floating-point values from their discrete integer representations
68+
* with minimal precision loss.
69+
*
70+
* ALGORITHM:
71+
* 1. Load quantized integer value from buffer
72+
* 2. Apply dequantization formula: value = (qvalue - zero_point) * scale
73+
* 3. Store reconstructed floating-point value to output buffer
74+
*
75+
* WORKGROUP CONFIGURATION:
76+
* - Per-Tensor Mode:
77+
* - Global WG Size: {num_elements, 1, 1} (one thread per tensor element)
78+
* - Local WG Size: Default (typically {64, 1, 1} or based on global WG size)
79+
* - Per-Token Mode:
80+
* - Global WG Size: {num_elements, 1, 1} (one thread per tensor element)
81+
* - Local WG Size: Default (typically {64, 1, 1} or based on global WG size)
82+
*
83+
* SUPPORTED CONFIGURATIONS:
84+
* - Buffer Storage: Uses linear buffer indexing with stride-based tensor access
85+
* - Per-Tensor: Supports any tensor layout through stride calculations and dimension ordering
86+
* - Per-Token: Supports only width packed tensors (packed_dim = 0) and standard axis mapping
87+
* - Scale/zero_point tensors: Must use buffer storage with width packing (packed_dim = 0)
88+
*
89+
* DEQUANTIZATION FORMULA VISUALIZATION:
90+
* For integer range [quant_min, quant_max] mapped back to [min_val, max_val]:
91+
*
92+
* Integer Domain: Floating Point Domain:
93+
* quant_min ──────────────► min_val
94+
* │ │
95+
* │ scale = (max_val - min_val) / (quant_max - quant_min)
96+
* │ zero_point = quant_min - round(min_val / scale)
97+
* │ │
98+
* quant_max ──────────────► max_val
99+
*
100+
* Dequantization Process:
101+
* Input: -103 (int8)
102+
* Step 1: qvalue - zero_point = -103 - (-128) = 25
103+
* Step 2: result * scale = 25 * 0.1 = 2.5
104+
* Output: 2.5 (float)
105+
*
106+
* PER-TENSOR DEQUANTIZATION:
107+
* - Single scale and zero_point values for entire tensor
108+
* - All elements use same dequantization parameters
109+
* - Parameters passed as push constants for efficiency
110+
* - Formula: value = (qvalue - zero_point) * scale
111+
*
112+
* PER-TOKEN DEQUANTIZATION:
113+
* - Separate scale and zero_point for each token
114+
* - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements)
115+
* - Parameters stored in buffer arrays indexed by token_id
116+
* - Each thread calculates its token_id from tensor coordinates
117+
* - Formula: value = (qvalue - zero_point[token_id]) * scale[token_id]
118+
*
119+
* Token ID calculation for element at tensor index (w, z, y, x):
120+
* - 4D tensor: token_id = w * (sizes.z * sizes.y) + z * sizes.y + y
121+
* - 3D tensor: token_id = z * sizes.y + y
122+
* - 2D tensor: token_id = y
123+
* - 1D tensor: token_id = 0
124+
*/
125+
126+
#ifdef per_tensor
127+
128+
void dequantize_per_tensor() {
129+
const int out_bufi = int(gl_GlobalInvocationID.x);
130+
131+
if (out_bufi >= out_numel) {
132+
return;
133+
}
134+
135+
const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order);
136+
const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);
137+
138+
IN_T qvalue = t_in[in_bufi];
139+
OUT_T value = dequantize_val(qvalue, scale, zero_point);
140+
141+
t_out[out_bufi] = value;
142+
}
143+
144+
#else
145+
146+
void dequantize_per_token() {
147+
const int out_bufi = int(gl_GlobalInvocationID.x);
148+
149+
if (out_bufi >= out_numel) {
150+
return;
151+
}
152+
153+
const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order);
154+
const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);
155+
156+
IN_T qvalue = t_in[in_bufi];
157+
158+
int token_idx = 0;
159+
160+
if (t_out_sizes.w > 1) {
161+
// 4D tensor
162+
token_idx = out_tidx.w * (t_out_sizes.z * t_out_sizes.y) + out_tidx.z * t_out_sizes.y + out_tidx.y;
163+
} else if (t_out_sizes.z > 1) {
164+
// 3D tensor
165+
token_idx = out_tidx.z * t_out_sizes.y + out_tidx.y;
166+
} else if (t_out_sizes.y > 1) {
167+
// 2D tensor
168+
token_idx = out_tidx.y;
169+
}
170+
// For 1D tensor, token_idx remains 0
171+
172+
token_idx = min(token_idx, num_tokens - 1);
173+
174+
OUT_T value = dequantize_val(qvalue, t_scale[token_idx], t_zero_point[token_idx]);
175+
176+
t_out[out_bufi] = value;
177+
}
178+
179+
#endif
180+
181+
void main() {
182+
dequantize_${MODE}();
183+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
dequantize_buffer:
2+
parameter_names_with_default_values:
3+
IN_DTYPE: int32
4+
OUT_DTYPE: float
5+
MODE: per_tensor
6+
generate_variant_forall:
7+
IN_DTYPE:
8+
- VALUE: uint8
9+
- VALUE: int8
10+
- VALUE: int32
11+
OUT_DTYPE:
12+
- VALUE: half
13+
- VALUE: float
14+
shader_variants:
15+
- NAME: dequantize_per_tensor_buffer
16+
MODE: per_tensor
17+
- NAME: dequantize_per_token_buffer
18+
MODE: per_token

0 commit comments

Comments
 (0)