Skip to content

Commit 086594e

Browse files
author
morelos
committed
[ET-VK][Ops] dequantization op shaders and impl
Pull Request resolved: #11483 # Operator Description The dequantization operator converts lower-precision integer tensors (uint8/int8/int32) back to floating-point formats (fp16/fp32) using affine dequantization. This operator supports two dequantization modes: - **Per-tensor dequantization**: Uses a single scale and zero_point for the entire tensor - **Per-token dequantization**: Uses different scale and zero_point values for each "token" (typically rows or channels) The dequantization formula is: `dequantized_value = (quantized_value - zero_point) * scale` **Example**: For a quantized uint8 value `153` with `scale=0.1`, `zero_point=128`: - `(153 - 128) * 0.1 = 25 * 0.1 = 2.5` (float output) The dequantization parameters serve these purposes: - **scale**: Controls the granularity of reconstruction (same scale used during quantization) - **zero_point**: Maps the integer zero representation back to floating-point zero - **quant_min/quant_max**: Define the valid range that was used during original quantization (for validation) # Shader Algorithm Overview ## Texture Storage Implementation (`dequantize_texture.glsl`) The texture-based implementation operates on 3D textures where data is stored in RGBA texel format (4 components per texel): **Per-tensor Mode**: Each compute thread processes one texel position. It loads a 4-component integer texel from the input texture, and applies dequantization to each of the 4 components using shared scale/zero_point parameters. It then writes the dequantized 4-component floating-point result to the output texture. This method processes all components uniformly with the same dequantization parameters. **Per-token Mode**: We need to calculate the token index based on the spatial position, it'll differ between various cases like 3D and 2D. For instance we might define the token_idx as `z * dims.y + y` for 3D, or just `y` for 2D cases. We then retrieve the per-token scale/zero_point from the texture storage according to the token_idx. We need to do component indexing based on the texel_idx and token_idx: `texel_idx = token_idx / 4`, along with the component id `comp_idx = token_idx % 4` to get the necessary scale/zero_point values. We then apply dequantization with the corresponding token-specific parameters to the 4 components of the current texel, converting each integer component to its floating-point representation. ## Buffer Storage Implementation (`dequantize_buffer.glsl`) The buffer-based implementation operates on linear memory buffers with stride-based indexing: **Per-tensor Mode**: In this case, each compute thread will process one element at its global position. It converts the 3D position to linear buffer indices using stride calculations `tidx_to_bufi(pos, strides)`. It then loads single quantized integer values from the input buffer and applies dequantization using shared scale/zero_point parameters. We then store the dequantized floating-point result to the output buffer at the corresponding index. **Per-token Mode**: We first calculate the logical tensor position from the linear buffer index through dimension unwrapping. We then determine the token index based on the tensor dimensionality: - 4D: `token_idx = w * (z * y) + z * y + y` - 3D: `token_idx = z * y + y` - 2D: `token_idx = y` We then directly index into scale/zero_point buffers using token_idx and apply dequantization with the token-specific parameters, converting the quantized integer value back to its original floating-point representation. # Performance Considerations / Future Improvements Current implementation uses default workgroup sizing. Buffer implementation processes one element per thread. Could be optimized to process multiple elements per thread for better throughput. NOTE: Currently the only input types supported are **byte** (uint8), **char** (int8), **int** (int32). The only output types supported are **half** (fp16) and **float** (fp32). A future diff plans to implement **double** (fp64) output dtype support. ghstack-source-id: 290294978 @exported-using-ghexport Differential Revision: [D76267107](https://our.internmc.facebook.com/intern/diff/D76267107/)
1 parent 4246ad6 commit 086594e

File tree

7 files changed

+756
-4
lines changed

7 files changed

+756
-4
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#ifndef DEQUANTIZE_GLSLH
10+
#define DEQUANTIZE_GLSLH
11+
12+
OUT_T dequantize_val(IN_T qvalue, float scale_val, int zero_point_val) {
13+
return OUT_T(float(int(qvalue) - zero_point_val) * scale_val);
14+
}
15+
16+
#endif // DEQUANTIZE_GLSLH
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define IN_T ${buffer_scalar_type(IN_DTYPE)}
14+
#define OUT_T ${buffer_scalar_type(OUT_DTYPE)}
15+
16+
${define_active_storage_type("buffer")}
17+
${define_required_extensions(IN_DTYPE)}
18+
${define_required_extensions(OUT_DTYPE)}
19+
20+
layout(std430) buffer;
21+
22+
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}
23+
${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")}
24+
25+
$if MODE == "per_tensor":
26+
layout(push_constant) uniform restrict Block {
27+
float scale;
28+
int zero_point;
29+
int quant_min;
30+
int quant_max;
31+
};
32+
$else:
33+
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
34+
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
35+
36+
layout(push_constant) uniform restrict Block {
37+
int num_tokens;
38+
int quant_min;
39+
int quant_max;
40+
};
41+
42+
${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
43+
${layout_declare_ubo(B, "ivec4", "t_in_strides")}
44+
${layout_declare_ubo(B, "ivec4", "t_out_sizes")}
45+
${layout_declare_ubo(B, "ivec4", "t_out_strides")}
46+
47+
#include "indexing_utils.h"
48+
#include "dequantize.glslh"
49+
50+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
51+
52+
void main() {
53+
$if MODE == "per_tensor":
54+
const ivec4 pos = ivec4(
55+
gl_GlobalInvocationID.x,
56+
gl_GlobalInvocationID.y,
57+
gl_GlobalInvocationID.z,
58+
0);
59+
60+
const int t_in_idx = tidx_to_bufi(pos, t_in_strides);
61+
const int t_out_idx = tidx_to_bufi(pos, t_out_strides);
62+
63+
IN_T qvalue = t_in[t_in_idx];
64+
OUT_T value;
65+
66+
value = dequantize_val(qvalue, scale, zero_point);
67+
68+
t_out[t_out_idx] = value;
69+
70+
$if MODE == "per_token":
71+
const ivec4 pos = ivec4(
72+
gl_GlobalInvocationID.x,
73+
gl_GlobalInvocationID.y,
74+
gl_GlobalInvocationID.z,
75+
0);
76+
77+
const int t_in_idx = tidx_to_bufi(pos, t_in_strides);
78+
const int t_out_idx = tidx_to_bufi(pos, t_out_strides);
79+
80+
// Skip if out of bounds
81+
if (t_in_idx >= t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w) {
82+
return;
83+
}
84+
85+
IN_T qvalue = t_in[t_in_idx];
86+
OUT_T value;
87+
88+
// Calculate logical position from linear index and strides
89+
ivec4 logical_pos;
90+
int remaining = t_in_idx;
91+
92+
logical_pos.x = remaining % t_in_sizes.x;
93+
remaining /= t_in_sizes.x;
94+
95+
logical_pos.y = remaining % t_in_sizes.y;
96+
remaining /= t_in_sizes.y;
97+
98+
logical_pos.z = remaining % t_in_sizes.z;
99+
remaining /= t_in_sizes.z;
100+
101+
logical_pos.w = remaining;
102+
103+
// Calculate token index based on logical position
104+
int token_idx = 0;
105+
106+
// Check dimensions to determine how to calculate token_idx
107+
if (t_in_sizes.w > 1) {
108+
// 4D tensor
109+
token_idx = logical_pos.w * (t_in_sizes.z * t_in_sizes.y) + logical_pos.z * t_in_sizes.y + logical_pos.y;
110+
} else if (t_in_sizes.z > 1) {
111+
// 3D tensor
112+
token_idx = logical_pos.z * t_in_sizes.y + logical_pos.y;
113+
} else if (t_in_sizes.y > 1) {
114+
// 2D tensor
115+
token_idx = logical_pos.y;
116+
}
117+
// For 1D tensor, token_idx remains 0
118+
119+
// Make sure token_idx is within bounds
120+
token_idx = min(token_idx, num_tokens - 1);
121+
122+
value = dequantize_val(qvalue, t_scale[token_idx], t_zero_point[token_idx]);
123+
124+
t_out[t_out_idx] = value;
125+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
dequantize_buffer:
2+
parameter_names_with_default_values:
3+
IN_DTYPE: int32
4+
OUT_DTYPE: float
5+
MODE: per_tensor
6+
generate_variant_forall:
7+
IN_DTYPE:
8+
- VALUE: uint8
9+
- VALUE: int8
10+
- VALUE: int32
11+
OUT_DTYPE:
12+
- VALUE: half
13+
- VALUE: float
14+
shader_variants:
15+
- NAME: dequantize_per_tensor_buffer
16+
MODE: per_tensor
17+
- NAME: dequantize_per_token_buffer
18+
MODE: per_token
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define IN_T ${buffer_scalar_type(IN_DTYPE)}
14+
#define IVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")}
15+
16+
#define OUT_T ${buffer_scalar_type(OUT_DTYPE)}
17+
#define FVEC4_T ${texel_load_type(OUT_DTYPE, "texture3d")}
18+
19+
${define_active_storage_type("texture3d")}
20+
${define_required_extensions(IN_DTYPE)}
21+
${define_required_extensions(OUT_DTYPE)}
22+
23+
#extension GL_EXT_control_flow_attributes : require
24+
25+
layout(std430) buffer;
26+
27+
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}
28+
${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")}
29+
30+
$if MODE == "per_tensor":
31+
layout(push_constant) uniform restrict Block {
32+
float scale;
33+
int zero_point;
34+
int quant_min;
35+
int quant_max;
36+
};
37+
$else:
38+
${layout_declare_tensor(B, "r", "t_scale", "float", "texture3d")}
39+
${layout_declare_tensor(B, "r", "t_zero_point", "int", "texture3d")}
40+
41+
layout(push_constant) uniform restrict Block {
42+
int num_tokens;
43+
int quant_min;
44+
int quant_max;
45+
};
46+
47+
${layout_declare_ubo(B, "ivec3", "t_in_limits")}
48+
${layout_declare_ubo(B, "ivec3", "t_out_limits")}
49+
50+
#include "indexing_utils.h"
51+
#include "dequantize.glslh"
52+
53+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
54+
55+
void main() {
56+
$if MODE == "per_tensor":
57+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
58+
59+
// Skip if out of bounds
60+
if (any(greaterThanEqual(pos, t_in_limits))) {
61+
return;
62+
}
63+
64+
IVEC4_T intex = load_texel(t_in, pos);
65+
FVEC4_T outtex;
66+
67+
[[unroll]] for (int i = 0; i < 4; ++i) {
68+
IN_T qvalue = IN_T(intex[i]);
69+
OUT_T value = dequantize_val(qvalue, scale, zero_point);
70+
outtex[i] = value;
71+
}
72+
write_texel(t_out, pos, outtex);
73+
74+
$if MODE == "per_token":
75+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
76+
77+
// Skip if out of bounds
78+
if (any(greaterThanEqual(pos, t_in_limits))) {
79+
return;
80+
}
81+
82+
IVEC4_T intex = load_texel(t_in, pos);
83+
84+
int token_idx = 0;
85+
ivec3 dims = t_in_limits;
86+
87+
if (dims.z > 1) {
88+
// 3D tensor
89+
token_idx = pos.z * dims.y + pos.y;
90+
} else if (dims.y > 1) {
91+
// 2D tensor
92+
token_idx = pos.y;
93+
}
94+
// For 1D tensor, token_idx remains 0
95+
96+
// Make sure token_idx is within bounds
97+
token_idx = min(token_idx, num_tokens - 1);
98+
99+
// For texture storage, we need to calculate the texel position and component index
100+
int texel_idx = token_idx / 4;
101+
int comp_idx = token_idx % 4;
102+
103+
vec4 scale_vals = load_texel(t_scale, ivec3(texel_idx, 0, 0));
104+
ivec4 zp_vals = load_texel(t_zero_point, ivec3(texel_idx, 0, 0));
105+
106+
float scale_val = scale_vals[comp_idx];
107+
int zero_point_val = zp_vals[comp_idx];
108+
109+
FVEC4_T outtex;
110+
[[unroll]] for (int i = 0; i < 4; ++i) {
111+
IN_T qvalue = IN_T(intex[i]);
112+
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
113+
outtex[i] = value;
114+
}
115+
116+
write_texel(t_out, pos, outtex);
117+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
dequantize_texture:
2+
parameter_names_with_default_values:
3+
IN_DTYPE: int32
4+
OUT_DTYPE: float
5+
MODE: per_tensor
6+
generate_variant_forall:
7+
IN_DTYPE:
8+
- VALUE: uint8
9+
- VALUE: int8
10+
- VALUE: int32
11+
OUT_DTYPE:
12+
- VALUE: half
13+
- VALUE: float
14+
shader_variants:
15+
- NAME: dequantize_per_tensor_texture3d
16+
MODE: per_tensor
17+
- NAME: dequantize_per_token_texture3d
18+
MODE: per_token

0 commit comments

Comments
 (0)