6
6
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
7
7
8
8
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
9
+ layout (constant_id = 1) const uint NUM_ROWS = 1;
9
10
10
- shared FLOAT_TYPE tmp[BLOCK_SIZE];
11
-
12
- void main() {
13
- const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
14
-
15
- if (row >= p.stride_d) {
16
- return;
17
- }
11
+ shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
18
12
13
+ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
19
14
uint a_offset, b_offset, d_offset;
20
15
get_offsets(a_offset, b_offset, d_offset);
21
16
22
17
const uint num_blocks_per_row = p.ncols / QUANT_K;
23
- const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
24
18
25
19
// 16 threads are used to process each block
26
20
const uint it_size = gl_WorkGroupSize.x/16;
@@ -38,15 +32,15 @@ void main() {
38
32
const uint s_offset = 8*v_im;
39
33
const uint y_offset = 128*v_im + l0;
40
34
41
- FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
35
+ FLOAT_TYPE temp[NUM_ROWS];
36
+
37
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
38
+ temp[i] = FLOAT_TYPE(0);
39
+ }
42
40
43
41
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
44
42
const uint y_idx = i * QUANT_K + y_offset;
45
43
46
- f16vec2 d = data_a[ib0 + i].d;
47
- const FLOAT_TYPE dall = d.x;
48
- const FLOAT_TYPE dmin = d.y;
49
-
50
44
B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
51
45
B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
52
46
B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
@@ -56,58 +50,84 @@ void main() {
56
50
B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
57
51
B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];
58
52
59
- uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
60
- uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
61
-
62
- uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F;
63
- uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F;
64
- uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F;
65
- uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F;
66
-
67
- uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32));
68
- uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32));
69
- uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32));
70
- uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32));
71
-
72
- uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0];
73
- uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
74
- uvec2 qs0 = uvec2(unpack8(qs0_u16));
75
- uvec2 qs16 = uvec2(unpack8(qs16_u16));
76
-
77
- FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
78
- FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
79
- [[unroll]] for (int l = 0; l < 2; ++l) {
80
- sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
81
- fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
82
- fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
83
- fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
84
- fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
85
- fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
86
- fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
87
- fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
88
- sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
89
- fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
90
- fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
91
- fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
92
- fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
93
- fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
94
- fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
95
- fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
53
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
54
+ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
55
+ f16vec2 d = data_a[ib0 + i].d;
56
+ const FLOAT_TYPE dall = d.x;
57
+ const FLOAT_TYPE dmin = d.y;
58
+
59
+ uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
60
+ uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
61
+
62
+ uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F;
63
+ uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F;
64
+ uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F;
65
+ uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F;
66
+
67
+ uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32));
68
+ uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32));
69
+ uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32));
70
+ uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32));
71
+
72
+ uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0];
73
+ uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
74
+ uvec2 qs0 = uvec2(unpack8(qs0_u16));
75
+ uvec2 qs16 = uvec2(unpack8(qs16_u16));
76
+
77
+ FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
78
+ FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
79
+ [[unroll]] for (int l = 0; l < 2; ++l) {
80
+ sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
81
+ fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
82
+ fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
83
+ fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
84
+ fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
85
+ fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
86
+ fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
87
+ fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
88
+ sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
89
+ fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
90
+ fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
91
+ fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
92
+ fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
93
+ fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
94
+ fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
95
+ fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
96
+ }
97
+ temp[n] = fma(dall, sum1, fma(-dmin, sum2, temp[n]));
96
98
}
97
- temp = fma(dall, sum1, fma(-dmin, sum2, temp));
98
99
}
99
100
100
- tmp[gl_LocalInvocationID.x] = temp;
101
-
102
101
// sum up partial sums and write back result
102
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
103
+ tmpsh[n][tid] = temp[n];
104
+ }
103
105
barrier();
104
- [[unroll]] for (uint s = gl_WorkGroupSize.x /2; s > 0; s >>= 1) {
106
+ [[unroll]] for (uint s = BLOCK_SIZE /2; s > 0; s >>= 1) {
105
107
if (tid < s) {
106
- tmp[tid] += tmp[tid + s];
108
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
109
+ tmpsh[n][tid] += tmpsh[n][tid + s];
110
+ }
107
111
}
108
112
barrier();
109
113
}
110
114
if (tid == 0) {
111
- data_d[d_offset + row] = D_TYPE(tmp[0]);
115
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
116
+ data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
117
+ }
118
+ }
119
+ }
120
+
121
+ void main() {
122
+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
123
+
124
+ // do NUM_ROWS at a time, unless there aren't enough remaining rows
125
+ if (first_row + NUM_ROWS <= p.stride_d) {
126
+ compute_outputs(first_row, NUM_ROWS);
127
+ } else {
128
+ if (first_row >= p.stride_d) {
129
+ return;
130
+ }
131
+ compute_outputs(first_row, p.stride_d - first_row);
112
132
}
113
133
}
0 commit comments