@@ -29,16 +29,20 @@ ${layout_declare_tensor(2, "r", "t_qmat2", "int8", STORAGE)}
2929${layout_declare_tensor(3 , "r", "t_scales", DTYPE, STORAGE)}
3030
3131$if STORAGE == "buffer ":
32- ${layout_declare_ubo(4 , "ivec4 ", "out_sizes")}
33- ${layout_declare_ubo(5 , "ivec4 ", "out_strides")}
34- ${layout_declare_ubo(6 , "int ", "out_numel")}
35- ${layout_declare_ubo(7 , "ivec4 ", "mat1_sizes")}
36- ${layout_declare_ubo(8 , "ivec4 ", "mat1_strides")}
37- ${layout_declare_ubo(9 , "ivec4 ", "qmat2_strides")}
38- ${layout_declare_ubo(10 , "ivec4 ", "scales_strides")}
32+ layout (push_constant) uniform restrict Block {
33+ ivec4 out_sizes;
34+ ivec4 out_strides;
35+ ivec4 mat1_sizes;
36+ ivec4 mat1_strides;
37+ ivec4 qmat2_strides;
38+ ivec4 scales_strides;
39+ int out_numel;
40+ };
3941$else :
40- ${layout_declare_ubo(4 , "ivec3 ", "out_limits")}
41- ${layout_declare_ubo(5 , "ivec4 ", "mat1_sizes")}
42+ layout (push_constant) uniform restrict Block {
43+ ivec3 out_limits;
44+ ivec4 mat1_sizes;
45+ };
4246
4347layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
4448
@@ -83,42 +87,40 @@ void main() {
8387
8488#else // USING_TEXTURE
8589
86- #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
87-
8890void main() {
89- const u16vec2 out_pos = u16vec2 (
90- gl_GlobalInvocationID.x,
91- gl_GlobalInvocationID.y );
91+ const ivec2 out_pos = ivec2 (
92+ gl_GlobalInvocationID.x % out_limits.x ,
93+ gl_GlobalInvocationID.x / out_limits.x );
9294
93- if (out_pos.x >= out_limits.x || out_pos. y >= out_limits.y) {
95+ if (out_pos.y >= out_limits.y) {
9496 return ;
9597 }
9698
97- const uint16_t qmat2_pos_x = out_pos.x;
99+ const int qmat2_pos_x = out_pos.x;
98100
99101 VEC4_T outtex = VEC4_T(0 );
100102
101- const VEC4_T scales = load_texel(t_scales, u16vec3 (out_pos.x, 0 , 0 ));
103+ const VEC4_T scales = load_texel(t_scales, ivec3 (out_pos.x, 0 , 0 ));
102104
103105 VEC4_T mat1_tex;
104106 VEC4_T mat2_tex[4 ];
105107 for (
106- uint16_t i = uint16_t( 0 ) , x = uint16_t( 0 ) ;
107- i < uint16_t( mat1_sizes.x) ;
108- i += uint16_t( 4 ) , x++ )
108+ int i = 0 , x = 0 ;
109+ i < mat1_sizes.x;
110+ i += 4 , x++ )
109111 {
110- mat1_tex = load_texel(t_mat1, u16vec3 (x, out_pos.y, 0 ));
112+ mat1_tex = load_texel(t_mat1, ivec3 (x, out_pos.y, 0 ));
111113
112- mat2_tex[0 ] = load_texel(t_qmat2, u16vec3 (out_pos.x, i, 0 ));
113- mat2_tex[1 ] = load_texel(t_qmat2, u16vec3 (out_pos.x, i + uint16_t( 1 ) , 0 ));
114- mat2_tex[2 ] = load_texel(t_qmat2, u16vec3 (out_pos.x, i + uint16_t( 2 ) , 0 ));
115- mat2_tex[3 ] = load_texel(t_qmat2, u16vec3 (out_pos.x, i + uint16_t( 3 ) , 0 ));
114+ mat2_tex[0 ] = load_texel(t_qmat2, ivec3 (out_pos.x, i, 0 ));
115+ mat2_tex[1 ] = load_texel(t_qmat2, ivec3 (out_pos.x, i + 1 , 0 ));
116+ mat2_tex[2 ] = load_texel(t_qmat2, ivec3 (out_pos.x, i + 2 , 0 ));
117+ mat2_tex[3 ] = load_texel(t_qmat2, ivec3 (out_pos.x, i + 3 , 0 ));
116118
117119 outtex += mat1_tex.x * mat2_tex[0 ] + mat1_tex.y * mat2_tex[1 ] + mat1_tex.z * mat2_tex[2 ] + mat1_tex.w * mat2_tex[3 ];
118120 }
119121
120122 outtex *= scales;
121- write_texel(t_out, u16vec3 (out_pos, 0 ), outtex);
123+ write_texel(t_out, ivec3 (out_pos, 0 ), outtex);
122124}
123125
124126#endif
0 commit comments