File tree Expand file tree Collapse file tree 4 files changed +22
-13
lines changed
backends/vulkan/runtime/graph/ops Expand file tree Collapse file tree 4 files changed +22
-13
lines changed Original file line number Diff line number Diff line change @@ -38,18 +38,21 @@ layout(push_constant) uniform restrict Block {
38
38
ivec4 weight_sizes;
39
39
};
40
40
41
+ #include "indexing_utils.h"
42
+
41
43
layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
42
44
43
45
shared VEC4_T partial_c[NGROUPS][NWORKERS][TILE_ROWS];
44
46
45
47
void main() {
46
- const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
47
- const uint out_col = gl_GlobalInvocationID.x << 2 ;
48
+ const uint out_width_ntexels = divup4(out_sizes.x);
49
+ const uint out_col = (gl_GlobalInvocationID.x % out_width_ntexels) << 2 ;
50
+ const uint out_row = (gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS;
48
51
49
52
const int gid = int (gl_LocalInvocationID.x); // group id
50
53
const int wid = int (gl_LocalInvocationID.z); // worker id
51
54
52
- if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
55
+ if (out_row >= out_sizes.y) {
53
56
return ;
54
57
}
55
58
Original file line number Diff line number Diff line change @@ -36,13 +36,16 @@ layout(push_constant) uniform restrict Block {
36
36
ivec4 weight_sizes;
37
37
};
38
38
39
+ #include "indexing_utils.h"
40
+
39
41
layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
40
42
41
43
void main() {
42
- const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
43
- const uint out_col = gl_GlobalInvocationID.x << 2 ;
44
+ const uint out_width_ntexels = divup4(out_sizes.x);
45
+ const uint out_col = (gl_GlobalInvocationID.x % out_width_ntexels) << 2 ;
46
+ const uint out_row = (gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS;
44
47
45
- if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
48
+ if (out_row >= out_sizes.y) {
46
49
return ;
47
50
}
48
51
Original file line number Diff line number Diff line change @@ -16,10 +16,10 @@ q_8w_linear_tiled:
16
16
TILE_ROWS :
17
17
- VALUE : 1
18
18
SUFFIX : o4x1
19
+ - VALUE : 2
20
+ SUFFIX : o4x2
19
21
- VALUE : 4
20
22
SUFFIX : o4x4
21
- - VALUE : 6
22
- SUFFIX : o4x6
23
23
shader_variants :
24
24
- NAME : q_8w_linear_tiled_texture3d_texture3d_texture2d_texture2d_float
25
25
- NAME : q_8w_linear_tiled_buffer_buffer_texture2d_texture2d_float
Original file line number Diff line number Diff line change @@ -180,10 +180,10 @@ void add_q_8w_linear_tiled_node(
180
180
181
181
std::vector<int64_t > mat1_sizes = graph.sizes_of (mat1);
182
182
const int64_t M = utils::val_at (-2 , mat1_sizes);
183
- int out_tile_nrows = 4 ;
183
+ uint32_t out_tile_nrows = 4 ;
184
184
if (M % 6 == 0 ) {
185
- kernel_name += " _o4x6 " ;
186
- out_tile_nrows = 6 ;
185
+ kernel_name += " _o4x2 " ;
186
+ out_tile_nrows = 2 ;
187
187
} else if (M % 4 == 0 ) {
188
188
kernel_name += " _o4x4" ;
189
189
out_tile_nrows = 4 ;
@@ -195,8 +195,11 @@ void add_q_8w_linear_tiled_node(
195
195
out_tile_nrows = 4 ;
196
196
}
197
197
198
- utils::uvec3 global_wg_size = graph.logical_limits_of (out);
199
- global_wg_size[1 ] = global_wg_size[1 ] / out_tile_nrows;
198
+ utils::uvec3 out_limits = graph.logical_limits_of (out);
199
+ utils::uvec3 global_wg_size = {
200
+ out_limits[0 ] * (utils::div_up (out_limits[1 ], out_tile_nrows)),
201
+ 1 ,
202
+ out_limits[2 ]};
200
203
201
204
utils::uvec3 local_wg_size{64 , 1 , 1 };
202
205
if (use_coop_algorithm) {
You can’t perform that action at this time.
0 commit comments