@@ -10,6 +10,7 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32;
10
10
layout (constant_id = 1) const uint NUM_ROWS = 1;
11
11
12
12
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
13
+ shared block_q6_K_packed16 blkcache[BLOCK_SIZE/16];
13
14
14
15
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
15
16
uint a_offset, b_offset, d_offset;
@@ -20,13 +21,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
20
21
// 16 threads are used to process each block
21
22
const uint it_size = gl_WorkGroupSize.x/16;
22
23
const uint tid = gl_LocalInvocationID.x;
23
- const uint itid = tid%16; // 0...16
24
+ const uint itid = tid%16; // 0...15
24
25
const uint ix = tid/16;
25
26
26
- const uint step = 8;
27
-
28
- const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
29
- const uint v_in = itid - step*v_im; // 0...15 or 0...7
27
+ const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
28
+ const uint v_in = itid - 8*v_im; // 0...7
30
29
31
30
const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28
32
31
const uint is = v_in / 4;
@@ -50,28 +49,33 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
50
49
B_TYPE_VEC4 by64 = data_b_v4[(b_offset + y_idx) / 4 + 16];
51
50
B_TYPE_VEC4 by96 = data_b_v4[(b_offset + y_idx) / 4 + 24];
52
51
52
+ uint ibi = first_row*num_blocks_per_row;
53
53
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
54
- const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row ;
55
- const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d) ;
54
+ const uint ib0 = a_offset / QUANT_K + ibi ;
55
+ ibi += num_blocks_per_row ;
56
56
57
- FLOAT_TYPE scales[4];
58
- scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]);
59
- scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]);
60
- scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]);
61
- scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]);
57
+ // cache full superblock into shared memory with coalesced reads
58
+ [[unroll]] for (int l = 0; l < 4; ++l)
59
+ blkcache[ix].ql[itid + 16*l] = data_a_packed16[ib0 + i].ql[itid + 16*l];
60
+ [[unroll]] for (int l = 0; l < 2; ++l)
61
+ blkcache[ix].qh[itid + 16*l] = data_a_packed16[ib0 + i].qh[itid + 16*l];
62
+ blkcache[ix].scales[itid] = data_a_packed16[ib0 + i].scales[itid];
63
+ barrier();
62
64
63
- uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
64
- uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
65
+ const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
66
+
67
+ uint32_t ql0_u32 = uint32_t(blkcache[ix].ql[ql_offset / 2]) | (uint32_t(blkcache[ix].ql[ql_offset / 2 + 1]) << 16);
68
+ uint32_t ql32_u32 = uint32_t(blkcache[ix].ql[ql_offset / 2 + 16]) | (uint32_t(blkcache[ix].ql[ql_offset / 2 + 17]) << 16);
65
69
66
70
uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
67
71
uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
68
72
uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
69
73
uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
70
74
71
- uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i ].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i ].qh[qh_offset / 2 + 1]) << 16);
75
+ uint32_t qh_u32 = uint32_t(blkcache[ix ].qh[qh_offset / 2]) | (uint32_t(blkcache[ix ].qh[qh_offset / 2 + 1]) << 16);
72
76
uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
73
77
uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
74
- uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0 ;
78
+ uint32_t qh4_u32 = (qh_u32 & 0x30303030);
75
79
uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
76
80
77
81
uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32;
@@ -84,14 +88,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
84
88
uvec4 q2 = uvec4(unpack8(q2_u32));
85
89
uvec4 q3 = uvec4(unpack8(q3_u32));
86
90
87
- FLOAT_TYPE sum = FLOAT_TYPE(0.0) ;
91
+ FLOAT_TYPE sum[4] = {0, 0, 0, 0} ;
88
92
[[unroll]] for (int l = 0; l < 4; ++l) {
89
- sum = fma(FLOAT_TYPE(by0[l]) * scales[0] , FLOAT_TYPE(int8_t(q0[l]) - 32),
90
- fma(FLOAT_TYPE(by32[l]) * scales[1] , FLOAT_TYPE(int8_t(q1[l]) - 32),
91
- fma(FLOAT_TYPE(by64[l]) * scales[2] , FLOAT_TYPE(int8_t(q2[l]) - 32),
92
- fma(FLOAT_TYPE(by96[l]) * scales[3] , FLOAT_TYPE(int8_t(q3[l]) - 32), sum))) );
93
+ sum[0] = fma(FLOAT_TYPE(by0[l]), FLOAT_TYPE(int8_t(q0[l]) - 32), sum[0]);
94
+ sum[1] = fma(FLOAT_TYPE(by32[l]), FLOAT_TYPE(int8_t(q1[l]) - 32), sum[1]);
95
+ sum[2] = fma(FLOAT_TYPE(by64[l]), FLOAT_TYPE(int8_t(q2[l]) - 32), sum[2]);
96
+ sum[3] = fma(FLOAT_TYPE(by96[l]), FLOAT_TYPE(int8_t(q3[l]) - 32), sum[3] );
93
97
}
94
- temp[n] += sum * d;
98
+
99
+ [[unroll]] for (int l = 0; l < 4; ++l)
100
+ sum[l] *= FLOAT_TYPE(blkcache[ix].scales[s_offset + l*2]);
101
+ temp[n] += (sum[0] + sum[1] + sum[2] + sum[3]) * d;
95
102
}
96
103
}
97
104
0 commit comments