Skip to content

Commit 158ab15

Browse files
committed
q6_k extract scale
small stuff scale cache there we go float type
1 parent d79d8f3 commit 158ab15

File tree

1 file changed

+29
-22
lines changed

1 file changed

+29
-22
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32;
1010
layout (constant_id = 1) const uint NUM_ROWS = 1;
1111

1212
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
13+
shared block_q6_K_packed16 blkcache[BLOCK_SIZE/16];
1314

1415
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
1516
uint a_offset, b_offset, d_offset;
@@ -20,13 +21,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
2021
// 16 threads are used to process each block
2122
const uint it_size = gl_WorkGroupSize.x/16;
2223
const uint tid = gl_LocalInvocationID.x;
23-
const uint itid = tid%16; // 0...16
24+
const uint itid = tid%16; // 0...15
2425
const uint ix = tid/16;
2526

26-
const uint step = 8;
27-
28-
const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
29-
const uint v_in = itid - step*v_im; // 0...15 or 0...7
27+
const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
28+
const uint v_in = itid - 8*v_im; // 0...15 or 0...7
3029

3130
const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28
3231
const uint is = v_in / 4;
@@ -50,28 +49,33 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
5049
B_TYPE_VEC4 by64 = data_b_v4[(b_offset + y_idx) / 4 + 16];
5150
B_TYPE_VEC4 by96 = data_b_v4[(b_offset + y_idx) / 4 + 24];
5251

52+
uint ibi = first_row*num_blocks_per_row;
5353
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
54-
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
55-
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
54+
const uint ib0 = a_offset / QUANT_K + ibi;
55+
ibi += num_blocks_per_row;
5656

57-
FLOAT_TYPE scales[4];
58-
scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]);
59-
scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]);
60-
scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]);
61-
scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]);
57+
// cache full superblock into shared memory with coalesced reads
58+
[[unroll]] for (int l = 0; l < 4; ++l)
59+
blkcache[ix].ql[itid + 16*l] = data_a_packed16[ib0 + i].ql[itid + 16*l];
60+
[[unroll]] for (int l = 0; l < 2; ++l)
61+
blkcache[ix].qh[itid + 16*l] = data_a_packed16[ib0 + i].qh[itid + 16*l];
62+
blkcache[ix].scales[itid] = data_a_packed16[ib0 + i].scales[itid];
63+
barrier();
6264

63-
uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
64-
uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
65+
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
66+
67+
uint32_t ql0_u32 = uint32_t(blkcache[ix].ql[ql_offset / 2]) | (uint32_t(blkcache[ix].ql[ql_offset / 2 + 1]) << 16);
68+
uint32_t ql32_u32 = uint32_t(blkcache[ix].ql[ql_offset / 2 + 16]) | (uint32_t(blkcache[ix].ql[ql_offset / 2 + 17]) << 16);
6569

6670
uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
6771
uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
6872
uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
6973
uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
7074

71-
uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
75+
uint32_t qh_u32 = uint32_t(blkcache[ix].qh[qh_offset / 2]) | (uint32_t(blkcache[ix].qh[qh_offset / 2 + 1]) << 16);
7276
uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
7377
uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
74-
uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0;
78+
uint32_t qh4_u32 = (qh_u32 & 0x30303030);
7579
uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
7680

7781
uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32;
@@ -84,14 +88,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
8488
uvec4 q2 = uvec4(unpack8(q2_u32));
8589
uvec4 q3 = uvec4(unpack8(q3_u32));
8690

87-
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
91+
FLOAT_TYPE sum[4] = {0, 0, 0, 0};
8892
[[unroll]] for (int l = 0; l < 4; ++l) {
89-
sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32),
90-
fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32),
91-
fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32),
92-
fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum))));
93+
sum[0] = fma(FLOAT_TYPE(by0[l]), FLOAT_TYPE(int8_t(q0[l]) - 32), sum[0]);
94+
sum[1] = fma(FLOAT_TYPE(by32[l]), FLOAT_TYPE(int8_t(q1[l]) - 32), sum[1]);
95+
sum[2] = fma(FLOAT_TYPE(by64[l]), FLOAT_TYPE(int8_t(q2[l]) - 32), sum[2]);
96+
sum[3] = fma(FLOAT_TYPE(by96[l]), FLOAT_TYPE(int8_t(q3[l]) - 32), sum[3]);
9397
}
94-
temp[n] += sum * d;
98+
99+
[[unroll]] for (int l = 0; l < 4; ++l)
100+
sum[l] *= FLOAT_TYPE(blkcache[ix].scales[s_offset + l*2]);
101+
temp[n] += (sum[0] + sum[1] + sum[2] + sum[3]) * d;
95102
}
96103
}
97104

0 commit comments

Comments
 (0)