Skip to content

Commit bc09711

Browse files
committed
[ET-VK] Modify quantized linear tiling shader to linearly dispatch work to improve thread occupancy and performance.
Pull Request resolved: #10508. This diff changes the tiled 8-bit quantized linear matmul op to dispatch work linearly (the output row and column are both derived from the x invocation index instead of using separate x and y dispatch dimensions), which increases thread occupancy and improves performance. Differential Revision: [D73751979](https://our.internmc.facebook.com/intern/diff/D73751979/) ghstack-source-id: 280651869
1 parent df75088 commit bc09711

File tree

3 files changed

+14
-7
lines changed

3 files changed

+14
-7
lines changed

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,16 @@ layout(push_constant) uniform restrict Block {
3636
ivec4 weight_sizes;
3737
};
3838

39+
#include "indexing_utils.h"
40+
3941
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4042

4143
void main() {
42-
const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
43-
const uint out_col = gl_GlobalInvocationID.x << 2;
44+
const uint out_size_x_div_4 = divup4(out_sizes.x);
45+
const uint out_col = (gl_GlobalInvocationID.x % out_size_x_div_4) << 2;
46+
const uint out_row = (gl_GlobalInvocationID.x / out_size_x_div_4) * TILE_ROWS;
4447

45-
if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
48+
if (out_row >= out_sizes.y) {
4649
return;
4750
}
4851

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ q_8w_linear_tiled:
1616
TILE_ROWS:
1717
- VALUE: 1
1818
SUFFIX: o4x1
19+
- VALUE: 2
20+
SUFFIX: o4x2
1921
- VALUE: 4
2022
SUFFIX: o4x4
21-
- VALUE: 6
22-
SUFFIX: o4x6
2323
shader_variants:
2424
- NAME: q_8w_linear_tiled_texture3d_texture3d_texture2d_texture2d_float
2525
- NAME: q_8w_linear_tiled_buffer_buffer_texture2d_texture2d_float

backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ void add_q_8w_linear_tiled_node(
182182
const int64_t M = utils::val_at(-2, mat1_sizes);
183183
int out_tile_nrows = 4;
184184
if (M % 6 == 0) {
185-
kernel_name += "_o4x6";
186-
out_tile_nrows = 6;
185+
kernel_name += "_o4x2";
186+
out_tile_nrows = 2;
187187
} else if (M % 4 == 0) {
188188
kernel_name += "_o4x4";
189189
out_tile_nrows = 4;
@@ -197,6 +197,10 @@ void add_q_8w_linear_tiled_node(
197197

198198
utils::uvec3 global_wg_size = graph.logical_limits_of(out);
199199
global_wg_size[1] = global_wg_size[1] / out_tile_nrows;
200+
if (!use_coop_algorithm) {
201+
global_wg_size[0] *= global_wg_size[1];
202+
global_wg_size[1] = 1;
203+
}
200204

201205
utils::uvec3 local_wg_size{64, 1, 1};
202206
if (use_coop_algorithm) {

0 commit comments

Comments
 (0)