Skip to content

Commit 972c988

Browse files
authored
[ET-VK] Enable int8 tiled compute shader to be used with buffer tensors (#10415)
## Context

As the title says: allow the optimized int8 tiled compute shader to be used with buffer-backed tensors as well.

## Changes

* Generate buffer variants for the int8 linear tiled shader.
* Force the scales tensor to always be a buffer to reduce the number of shader variants that need to be generated.
* Generate an additional variant that computes only 1 output row.
* Do not require the number of output rows to be an exact multiple of 4 or 6 to use the tiled implementation.

Differential Revision: [D73276277](https://our.internmc.facebook.com/intern/diff/D73276277/)
1 parent 2553d99 commit 972c988

File tree

4 files changed

+75
-38
lines changed

4 files changed

+75
-38
lines changed

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl

Lines changed: 18 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -17,17 +17,17 @@
1717

1818
${define_required_extensions(DTYPE)}
1919

20-
$if STORAGE == "buffer":
20+
$if WEIGHT_STORAGE == "buffer":
2121
${define_required_extensions("int8")}
2222

2323
#extension GL_EXT_control_flow_attributes : require
2424

2525
layout(std430) buffer;
2626

27-
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)}
28-
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)}
29-
${layout_declare_tensor(B, "r", "t_weight", "int8", STORAGE, is_scalar_array=False)}
30-
${layout_declare_tensor(B, "r", "t_scales", DTYPE, STORAGE, is_scalar_array=False)}
27+
${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
28+
${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)}
29+
${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)}
30+
${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)}
3131

3232

3333
layout(push_constant) uniform restrict Block {
@@ -50,10 +50,10 @@ void main() {
5050
VEC4_T b[4];
5151
VEC4_T c[TILE_ROWS];
5252

53-
$if STORAGE == "buffer":
53+
$if SCALES_STORAGE == "buffer":
5454
const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
5555
$else:
56-
const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec3(out_col >> 2, 0, 0), 0));
56+
const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2(out_col >> 2, 0), 0));
5757

5858
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
5959
c[i] = VEC4_T(0.0);
@@ -62,30 +62,32 @@ void main() {
6262
for (int pos = 0; pos < in_sizes.x; pos += 4) {
6363
// Preload weight tensor
6464
[[unroll]] for (int i = 0; i < 4; i++) {
65-
$if STORAGE == "buffer":
66-
b[i] = t_weight[((pos + i) * B_sizes.x + out_col) >> 2];
65+
$if WEIGHT_STORAGE == "buffer":
66+
b[i] = t_weight[((pos + i) * out_sizes.x + out_col) >> 2];
6767
$else:
68-
b[i] = VEC4_T(texelFetch(t_weight, ivec3(out_col >> 2, pos + i, 0), 0));
68+
b[i] = VEC4_T(texelFetch(t_weight, ivec2(out_col >> 2, pos + i), 0));
6969
}
7070

7171
// Preload input tensor
7272
[[unroll]] for (int i = 0; i < TILE_ROWS; i++) {
73-
$if STORAGE == "buffer":
74-
a[i] = t_in[((out_row + i) * in_sizes.x + (pos)) >> 2];
73+
$if IN_STORAGE == "buffer":
74+
a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2];
7575
$else:
7676
a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0));
7777
}
7878

79-
// Compute partial output
79+
// Accumulate output
8080
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
8181
c[i] += a[i].x * b[0] + a[i].y * b[1] + a[i].z * b[2] + a[i].w * b[3];
8282
}
8383
}
8484

85-
// Store output tensor
85+
// Store to output tensor
8686
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
87-
$if STORAGE == "buffer":
88-
t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
87+
$if OUT_STORAGE == "buffer":
88+
if (out_row + i < out_sizes.y) {
89+
t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
90+
}
8991
$else:
9092
imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales);
9193
}

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml

Lines changed: 21 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -7,12 +7,26 @@
77
q_8w_linear_tiled:
88
parameter_names_with_default_values:
99
DTYPE: float
10-
STORAGE: texture3d
10+
IN_STORAGE: texture3d
11+
OUT_STORAGE: texture3d
12+
WEIGHT_STORAGE: texture2d
13+
SCALES_STORAGE: texture2d
1114
TILE_ROWS: 4
15+
generate_variant_forall:
16+
TILE_ROWS:
17+
- VALUE: 1
18+
SUFFIX: o4x1
19+
- VALUE: 4
20+
SUFFIX: o4x4
21+
- VALUE: 6
22+
SUFFIX: o4x6
1223
shader_variants:
13-
- NAME: q_8w_linear_tiled_o4x4_texture3d_float
14-
STORAGE: texture3d
15-
TILE_ROWS: 4
16-
- NAME: q_8w_linear_tiled_o4x6_texture3d_float
17-
STORAGE: texture3d
18-
TILE_ROWS: 6
24+
- NAME: q_8w_linear_tiled_texture3d_texture3d_texture2d_texture2d_float
25+
- NAME: q_8w_linear_tiled_buffer_buffer_texture2d_texture2d_float
26+
IN_STORAGE: buffer
27+
OUT_STORAGE: buffer
28+
- NAME: q_8w_linear_tiled_buffer_buffer_buffer_buffer_float
29+
IN_STORAGE: buffer
30+
OUT_STORAGE: buffer
31+
WEIGHT_STORAGE: buffer
32+
SCALES_STORAGE: buffer

backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp

Lines changed: 35 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -146,28 +146,53 @@ void add_q_8w_linear_tiled_node(
146146
const ValueRef q_mat2_data,
147147
const ValueRef scales_data,
148148
const ValueRef out) {
149-
utils::StorageType stype = graph.storage_type_of(out);
149+
utils::StorageType q_mat2_storage = utils::kTexture2D;
150+
151+
uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim();
152+
std::vector<int64_t> qmat2_orig_sizes = graph.sizes_of(q_mat2_data);
153+
const int64_t ndim = graph.dim_of(q_mat2_data);
154+
const int64_t K = qmat2_orig_sizes.at(ndim - 1);
155+
const int64_t N = qmat2_orig_sizes.at(ndim - 2);
156+
157+
if (N > max_extent * 4 || K > max_extent) {
158+
q_mat2_storage = utils::kBuffer;
159+
}
160+
150161
ValueRef q_mat2 = prepack_standard_hw_transposed(
151-
graph, q_mat2_data, stype, utils::kWidthPacked);
162+
graph, q_mat2_data, q_mat2_storage, utils::kWidthPacked);
163+
164+
utils::StorageType scales_storage = utils::kTexture2D;
165+
if (N > max_extent) {
166+
scales_storage = utils::kBuffer;
167+
}
152168
ValueRef scales =
153-
prepack_standard(graph, scales_data, stype, utils::kWidthPacked);
169+
prepack_standard(graph, scales_data, scales_storage, utils::kWidthPacked);
154170

155171
std::string kernel_name = "q_8w_linear_tiled";
156172
kernel_name.reserve(kShaderNameReserve);
173+
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
174+
add_storage_type_suffix(kernel_name, graph.storage_type_of(mat1));
175+
add_storage_type_suffix(kernel_name, graph.storage_type_of(q_mat2));
176+
add_storage_type_suffix(kernel_name, graph.storage_type_of(scales));
177+
add_dtype_suffix(kernel_name, graph.dtype_of(out));
178+
157179
std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1);
158180
const int64_t M = utils::val_at(-2, mat1_sizes);
159181
int out_tile_nrows = 4;
160182
if (M % 6 == 0) {
161183
kernel_name += "_o4x6";
162184
out_tile_nrows = 6;
185+
} else if (M % 4 == 0) {
186+
kernel_name += "_o4x4";
187+
out_tile_nrows = 4;
188+
} else if (M % 1 == 0) {
189+
kernel_name += "_o4x1";
190+
out_tile_nrows = 1;
163191
} else {
164192
kernel_name += "_o4x4";
165193
out_tile_nrows = 4;
166194
}
167195

168-
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
169-
add_dtype_suffix(kernel_name, graph.dtype_of(out));
170-
171196
utils::uvec3 global_wg_size = graph.logical_limits_of(out);
172197
global_wg_size[1] = global_wg_size[1] / out_tile_nrows;
173198

@@ -209,18 +234,13 @@ bool can_use_tiled_impl(
209234
if (graph.size_at<int>(-1, mat1) % 4 != 0) {
210235
return false;
211236
}
212-
// Check that M is a multiple of 4 or 6
213-
if (graph.size_at<int>(-2, mat1) % 4 != 0 &&
214-
graph.size_at<int>(-2, mat1) % 6 != 0) {
215-
return false;
216-
}
217-
// Check that the storage type is texture
218-
// TODO(ssjia): Add support for buffer storage in the tiled impl
219-
if (graph.storage_type_of(out) != utils::kTexture3D) {
237+
// Check that N is a multiple of 4
238+
if (graph.size_at<int>(-1, out) % 4 != 0) {
220239
return false;
221240
}
222241
// Check that the packed dim is the width dim
223-
if (graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
242+
if (graph.packed_dim_of(mat1) != WHCN::kWidthDim &&
243+
graph.packed_dim_of(out) != WHCN::kWidthDim) {
224244
return false;
225245
}
226246
// Check that no special axis mapping is used for the input

backends/vulkan/test/op_tests/cases.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -152,6 +152,7 @@ def get_linear_inputs():
152152
@register_test_suite("aten._weight_int8pack_mm.default")
153153
def get_weight_int8pack_mm_inputs():
154154
MKN_list = [
155+
[3, 480, 256],
155156
[6, 480, 256],
156157
[6, 256, 1024],
157158
[6, 1024, 256],

0 commit comments

Comments
 (0)