Skip to content

Commit fbeda90

Browse files
authored
vulkan: matmul dequantization improvements (#12015)
* faster dequant for old quants * dont use unpack for iq4_nl * vec2 unpack for q8
1 parent 581650b commit fbeda90

File tree

5 files changed

+99
-59
lines changed

5 files changed

+99
-59
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
8282
return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1]));
8383
}
8484
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
85-
uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2];
86-
uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1];
87-
return vec4(int8_t(v0 & 0xFF), int8_t(v0 >> 8), int8_t(v1 & 0xFF), int8_t(v1 >> 8));
85+
const i8vec2 v0 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2]);
86+
const i8vec2 v1 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2 + 1]);
87+
return vec4(v0.x, v0.y, v1.x, v1.y);
8888
}
8989
#endif
9090

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2
9292
const uint iqs = idx;
9393

9494
// Load 16b and select the byte for this element
95-
int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1];
95+
int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1];
9696
float16_t ret = float16_t(qs) * d;
9797
return ret;
9898
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 86 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@
3232
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
3333

3434
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
35+
#if defined(A_TYPE_PACKED16)
36+
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
37+
#endif
38+
#if defined(A_TYPE_PACKED32)
39+
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
40+
#endif
41+
3542
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
3643
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
3744

@@ -243,74 +250,100 @@ void main() {
243250
#endif
244251
#elif defined(DATA_A_Q4_0)
245252
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
246-
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
247-
248-
const uint ib = idx / 16;
249-
const uint iqs = idx & 0xF;
250-
251-
const float d = float(data_a[ib].d);
252-
const uint vui = uint(data_a[ib].qs[iqs]);
253-
const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
254-
255-
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
256-
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
253+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
254+
255+
const uint ib = idx / 4;
256+
const uint iqs = idx & 0x03;
257+
258+
const float d = float(data_a_packed16[ib].d);
259+
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
260+
const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
261+
const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
262+
263+
buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
264+
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
265+
buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
266+
buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
267+
buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
268+
buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
269+
buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
270+
buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
257271
#elif defined(DATA_A_Q4_1)
258272
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
259-
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
260-
261-
const uint ib = idx / 16;
262-
const uint iqs = idx & 0xF;
263-
264-
const float d = float(data_a[ib].d);
265-
const float m = float(data_a[ib].m);
266-
const uint vui = uint(data_a[ib].qs[iqs]);
267-
const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m;
268-
269-
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
270-
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
273+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
274+
275+
const uint ib = idx / 4;
276+
const uint iqs = idx & 0x03;
277+
278+
const float d = float(data_a_packed16[ib].d);
279+
const float m = float(data_a_packed16[ib].m);
280+
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
281+
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
282+
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
283+
284+
buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
285+
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
286+
buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
287+
buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
288+
buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
289+
buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
290+
buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
291+
buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
271292
#elif defined(DATA_A_Q5_0)
272293
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
273-
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
294+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
274295

275-
const uint ib = idx / 16;
276-
const uint iqs = idx & 0xF;
296+
const uint ib = idx / 8;
297+
const uint iqs = idx & 0x07;
277298

278-
const float d = float(data_a[ib].d);
279-
const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
280-
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
281-
const uint vui = uint(data_a[ib].qs[iqs]);
282-
const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
299+
const float d = float(data_a_packed16[ib].d);
300+
const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]);
301+
const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
302+
const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
303+
304+
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
305+
const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;
283306

284307
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
308+
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
285309
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
310+
buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
286311
#elif defined(DATA_A_Q5_1)
287312
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
288-
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
313+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
289314

290-
const uint ib = idx / 16;
291-
const uint iqs = idx & 0xF;
315+
const uint ib = idx / 8;
316+
const uint iqs = idx & 0x07;
292317

293-
const float d = float(data_a[ib].d);
294-
const float m = float(data_a[ib].m);
295-
const uint uint_qh = data_a[ib].qh;
296-
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
297-
const uint vui = uint(data_a[ib].qs[iqs]);
298-
const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
318+
const float d = float(data_a_packed16[ib].d);
319+
const float m = float(data_a_packed16[ib].m);
320+
const uint uint_qh = data_a_packed16[ib].qh;
321+
const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
322+
const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
323+
324+
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
325+
const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
299326

300327
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
328+
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
301329
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
330+
buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
302331
#elif defined(DATA_A_Q8_0)
303332
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
304333
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
305334

306-
const uint ib = idx / 16;
307-
const uint iqs = (idx & 0xF) * 2;
335+
const uint ib = idx / 8;
336+
const uint iqs = idx & 0x07;
308337

309-
const float d = float(data_a[ib].d);
310-
const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
338+
const float d = float(data_a_packed16[ib].d);
339+
const i8vec2 v0 = unpack8(data_a_packed16[ib].qs[2*iqs]);
340+
const i8vec2 v1 = unpack8(data_a_packed16[ib].qs[2*iqs + 1]);
341+
const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
311342

312343
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
313344
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
345+
buf_a[buf_idx + 2] = FLOAT_TYPE(v.z);
346+
buf_a[buf_idx + 3] = FLOAT_TYPE(v.w);
314347
#elif defined(DATA_A_Q2_K)
315348
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
316349
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
@@ -623,17 +656,18 @@ void main() {
623656
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
624657
#elif defined(DATA_A_IQ4_NL)
625658
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
626-
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
659+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
627660

628-
const uint ib = idx / 16;
629-
const uint iqs = idx & 0xF;
661+
const uint ib = idx / 8;
662+
const uint iqs = idx & 0x07;
630663

631-
const float d = float(data_a[ib].d);
632-
const uint vui = uint(data_a[ib].qs[iqs]);
633-
const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
664+
const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
665+
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
634666

635-
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
636-
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
667+
buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d;
668+
buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
669+
buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
670+
buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
637671
#endif
638672
}
639673
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {

ggml/src/ggml-vulkan/vulkan-shaders/types.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ struct block_q8_0
139139
struct block_q8_0_packed16
140140
{
141141
float16_t d;
142-
uint16_t qs[32/2];
142+
int16_t qs[32/2];
143143
};
144144

145145
#if defined(DATA_A_Q8_0)

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,11 +325,17 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
325325
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
326326

327327
for (const auto& tname : type_names) {
328+
std::string load_vec_quant = "2";
329+
if ((tname == "q4_0") || (tname == "q4_1"))
330+
load_vec_quant = "8";
331+
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
332+
load_vec_quant = "4";
333+
328334
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
329335
// For unaligned, load one at a time for f32/f16 (and coopmat2), or load_vec_quant (2, 4 or 8, depending on the type) at a time for quants
330-
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2";
336+
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : load_vec_quant;
331337
// For aligned matmul loads
332-
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
338+
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : load_vec_quant;
333339

334340
// don't generate f32 variants for coopmat2
335341
if (!coopmat2) {

0 commit comments

Comments
 (0)