
Commit 1bacb9f

vulkan: further optimize mul_mat_vec using larger loads (#10387)
* vulkan: Use pipeline_robustness to disable robustness in mul_mat_vec.

  Add some early returns for nonexistent rows in mul_mat_vec shaders. These
  can only be hit when dispatching a 2D grid of workgroups. Fix the logic
  for the 2D grid of workgroups to round up.

  Enable the pipeline robustness extension if it's available, and use it to
  disable robustness for these pipelines. The instructions to do the bounds
  checking contend for the same ALU resources as the bit twiddling dequant
  instructions.

* vulkan: Add GLSL structure aliases for quant types to allow larger loads

  In Vulkan it's not possible to cast pointer types, so instead you have to
  declare an aliased binding for the memory with a different type. This
  commit adds aliases for the quant formats using 16b ints, and in a few
  places where the struct size is a multiple of 4 also using 32b ints.

  Currently only q4_k's aliases are used, but others will be used in
  subsequent commits.

* vulkan: use larger loads in q5_k and q6_k shaders.

  Similar to the optimization I did in q4_k recently, this vectorizes some
  loads and reduces the number of bit twiddling instructions.

* vulkan: use larger K step per iteration in mul_mat_vec.

  Add vec4 dequantization functions, and use them to do K=8 per iteration
  in mul_mat_vec. This uses 16b loads for the quant values and 128b loads
  for B which helps reduce the load on the memory system.

  The K_PER_ITER==2 logic is still there, just for F16/F32, and really only
  because they support unaligned sizes.

  Tweak the num_iters/unrolling logic to be simpler and catch a couple
  missed unrolling opportunities.
1 parent ad21c9e commit 1bacb9f

11 files changed: +459 −149 lines

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 63 additions & 38 deletions
Large diffs are not rendered by default.
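The host-side diff is not rendered here, but the commit message describes enabling VK_EXT_pipeline_robustness and using it to opt these pipelines out of bounds checking. Below is a minimal sketch of what that chaining looks like with the plain Vulkan C API, assuming the extension and its feature were enabled at device creation; the helper name and surrounding setup are illustrative, not the actual ggml-vulkan.cpp code.

#include <vulkan/vulkan.h>

// Hypothetical sketch: chain VkPipelineRobustnessCreateInfoEXT into compute
// pipeline creation so this one pipeline skips storage/uniform buffer bounds
// checking, without touching device-wide robustness settings.
VkPipeline create_pipeline_no_robustness(VkDevice device,
                                         VkPipelineShaderStageCreateInfo stage,
                                         VkPipelineLayout layout) {
    VkPipelineRobustnessCreateInfoEXT robustness = {};
    robustness.sType          = VK_STRUCTURE_TYPE_PIPELINE_ROBUSTNESS_CREATE_INFO_EXT;
    robustness.storageBuffers = VK_PIPELINE_ROBUSTNESS_BUFFER_BEHAVIOR_DISABLED_EXT;
    robustness.uniformBuffers = VK_PIPELINE_ROBUSTNESS_BUFFER_BEHAVIOR_DISABLED_EXT;
    robustness.vertexInputs   = VK_PIPELINE_ROBUSTNESS_BUFFER_BEHAVIOR_DEVICE_DEFAULT_EXT;
    robustness.images         = VK_PIPELINE_ROBUSTNESS_IMAGE_BEHAVIOR_DEVICE_DEFAULT_EXT;

    VkComputePipelineCreateInfo info = {};
    info.sType  = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
    info.pNext  = &robustness;  // per-pipeline override of robustness behavior
    info.stage  = stage;
    info.layout = layout;

    VkPipeline pipeline = VK_NULL_HANDLE;
    vkCreateComputePipelines(device, VK_NULL_HANDLE, 1, &info, nullptr, &pipeline);
    return pipeline;
}

Per-pipeline control matters here because, as the message notes, the bounds-checking instructions contend for the same ALU resources as the bit twiddling dequant instructions.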

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp

Lines changed: 48 additions & 0 deletions
@@ -2,6 +2,15 @@
 #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
 #endif
 
+#include "types.comp"
+
+#if defined(A_TYPE_PACKED16)
+layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
+#endif
+#if defined(A_TYPE_PACKED32)
+layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
+#endif
+
 #if defined(DATA_A_F32)
 vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
@@ -20,6 +29,11 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
     return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
 }
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+    const float d = float(data_a_packed16[a_offset + ib].d);
+    const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
+    return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) - 8.0f) * d;
+}
 #endif
 
 #if defined(DATA_A_Q4_1)
@@ -29,6 +43,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
     return vec2(vui & 0xF, vui >> 4) * d + m;
 }
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+    const float d = float(data_a_packed16[a_offset + ib].d);
+    const float m = float(data_a_packed16[a_offset + ib].m);
+    const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
+    return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) * d + m;
+}
 #endif
 
 #if defined(DATA_A_Q5_0)
@@ -39,6 +59,14 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
     return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
 }
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+    const float d = float(data_a_packed16[a_offset + ib].d);
+    const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0];
+    const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
+    const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
+    const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
+    return (vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) - 16.0f) * d;
+}
 #endif
 
 #if defined(DATA_A_Q5_1)
@@ -50,13 +78,28 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
     return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
 }
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+    const float d = float(data_a_packed16[a_offset + ib].d);
+    const float m = float(data_a_packed16[a_offset + ib].m);
+    const uint uint_qh = data_a_packed16[a_offset + ib].qh;
+    const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
+    const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
+    const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
+    return vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * d + m;
+}
 #endif
 
 #if defined(DATA_A_Q8_0)
 vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const float d = float(data_a[a_offset + ib].d);
     return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d;
 }
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+    const float d = float(data_a_packed16[a_offset + ib].d);
+    uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2];
+    uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1];
+    return vec4(int8_t(v0 & 0xFF), int8_t((v0 >> 8) & 0xFF), int8_t(v1 & 0xFF), int8_t((v1 >> 8) & 0xFF)) * d;
+}
 #endif
 
 #if defined(DATA_A_IQ4_NL)
@@ -65,4 +108,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
     return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
 }
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+    const float d = float(data_a_packed16[a_offset + ib].d);
+    const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
+    return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[(vui >> 12) & 0xF]) * d;
+}
 #endif
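As a sanity check on the 4-bit dequantize4 paths above, here is a small CPU-side sketch (illustrative only, not part of the commit) showing that one little-endian 16-bit load unpacks the same four q4_0 quant values as two byte-at-a-time loads in the scalar dequantize path:

#include <array>
#include <cstdint>
#include <cstdio>

int main() {
    const std::array<uint8_t, 2> qs = { 0xA3, 0x5C };  // arbitrary test bytes

    // Scalar path: low/high nibble of each byte, minus the 8 bias.
    const int s0 = (qs[0] & 0xF) - 8, s1 = (qs[0] >> 4) - 8;
    const int s2 = (qs[1] & 0xF) - 8, s3 = (qs[1] >> 4) - 8;

    // Packed path: one little-endian 16-bit word, four nibble extractions.
    const uint16_t v = uint16_t(qs[0] | (qs[1] << 8));
    const int p0 = (v & 0xF) - 8,        p1 = ((v >> 4) & 0xF) - 8;
    const int p2 = ((v >> 8) & 0xF) - 8, p3 = ((v >> 12) & 0xF) - 8;

    printf("scalar: %d %d %d %d\n", s0, s1, s2, s3);  // -5 2 4 -3
    printf("packed: %d %d %d %d\n", p0, p1, p2, p3);  // -5 2 4 -3
    return 0;
}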

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp

Lines changed: 72 additions & 10 deletions
@@ -3,7 +3,7 @@
 #ifdef FLOAT16
 #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
 #endif
-#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+#extension GL_EXT_shader_explicit_arithmetic_types : require
 
 #include "mul_mat_vec_base.comp"
 
@@ -12,16 +12,48 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 layout (constant_id = 0) const uint BLOCK_SIZE = 32;
 layout (constant_id = 1) const uint NUM_ROWS = 1;
 
+#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
+#define K_PER_ITER 8
+#else
+#define K_PER_ITER 2
+#endif
+
+
 uint a_offset, b_offset, d_offset, y_offset;
 
 shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
 
 void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
 {
-    const uint col = i*BLOCK_SIZE + 2*tid;
+    const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
     const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
     const uint iybs = col - col%QUANT_K; // y block start index
 
+#if K_PER_ITER == 8
+#if QUANT_R == 2
+    B_TYPE_VEC4 bv02 = data_b_v4[(b_offset + iybs + iqs) / 4];
+    B_TYPE_VEC4 bv13 = data_b_v4[(b_offset + iybs + iqs + y_offset) / 4];
+    FLOAT_TYPE b0 = FLOAT_TYPE(bv02.x);
+    FLOAT_TYPE b1 = FLOAT_TYPE(bv13.x);
+    FLOAT_TYPE b2 = FLOAT_TYPE(bv02.y);
+    FLOAT_TYPE b3 = FLOAT_TYPE(bv13.y);
+    FLOAT_TYPE b4 = FLOAT_TYPE(bv02.z);
+    FLOAT_TYPE b5 = FLOAT_TYPE(bv13.z);
+    FLOAT_TYPE b6 = FLOAT_TYPE(bv02.w);
+    FLOAT_TYPE b7 = FLOAT_TYPE(bv13.w);
+#else
+    B_TYPE_VEC4 bv0 = data_b_v4[(b_offset + iybs + iqs) / 4];
+    B_TYPE_VEC4 bv1 = data_b_v4[(b_offset + iybs + iqs) / 4 + 1];
+    FLOAT_TYPE b0 = FLOAT_TYPE(bv0.x);
+    FLOAT_TYPE b1 = FLOAT_TYPE(bv0.y);
+    FLOAT_TYPE b2 = FLOAT_TYPE(bv0.z);
+    FLOAT_TYPE b3 = FLOAT_TYPE(bv0.w);
+    FLOAT_TYPE b4 = FLOAT_TYPE(bv1.x);
+    FLOAT_TYPE b5 = FLOAT_TYPE(bv1.y);
+    FLOAT_TYPE b6 = FLOAT_TYPE(bv1.z);
+    FLOAT_TYPE b7 = FLOAT_TYPE(bv1.w);
+#endif
+#else
     // Check if the second of the pair of elements is OOB, and don't fetch B or
     // accumulate it. We still fetch a pair of elements for A, which is fine for
     // quantized formats since they'll be within the same block. We should
@@ -34,16 +66,32 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_
     if (!OOB) {
         b1 = FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
     }
+#endif
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const uint ib = ((first_row + n)*p.ncols + col)/QUANT_K; // block index
 
+#if K_PER_ITER == 8
+        const vec4 v = dequantize4(ib, iqs, a_offset);
+        const vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
+
+        // matrix multiplication
+        temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]);
+        temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
+        temp[n] = fma(FLOAT_TYPE(v.z), b2, temp[n]);
+        temp[n] = fma(FLOAT_TYPE(v.w), b3, temp[n]);
+        temp[n] = fma(FLOAT_TYPE(v2.x), b4, temp[n]);
+        temp[n] = fma(FLOAT_TYPE(v2.y), b5, temp[n]);
+        temp[n] = fma(FLOAT_TYPE(v2.z), b6, temp[n]);
+        temp[n] = fma(FLOAT_TYPE(v2.w), b7, temp[n]);
+#else
        const vec2 v = dequantize(ib, iqs, a_offset);
 
         // matrix multiplication
         temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]);
         if (!OOB) {
             temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
         }
+#endif
     }
 }
 
@@ -61,22 +109,33 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
         temp[i] = FLOAT_TYPE(0);
     }
 
-    const int unroll_count = 8;
-
-    const uint num_iters = (p.ncols >= 2*tid) ? ((p.ncols - 2*tid + BLOCK_SIZE - 1) / BLOCK_SIZE) : 0;
-    const uint unrolled_iters = num_iters & ~(2*unroll_count - 1);
+    uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
+    if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
+        num_iters++;
+    }
+    int unroll_count = 4;
+    uint unrolled_iters = num_iters & ~(unroll_count - 1);
 
     uint i = 0;
     while (i < unrolled_iters) {
         // Manually partially unroll the loop
         [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
-            iter(temp, first_row, num_rows, tid, i, false);
-            i += 2;
+            iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
+            i++;
+        }
+    }
+    unroll_count = 2;
+    unrolled_iters = num_iters & ~(unroll_count - 1);
+    while (i < unrolled_iters) {
+        // Manually partially unroll the loop
+        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
+            iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
+            i++;
         }
     }
     while (i < num_iters) {
-        iter(temp, first_row, num_rows, tid, i, true);
-        i += 2;
+        iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);
+        i++;
     }
 
     // sum up partial sums and write back result
@@ -106,6 +165,9 @@ void main() {
     if (first_row + NUM_ROWS <= p.stride_d) {
         compute_outputs(first_row, NUM_ROWS);
     } else {
+        if (first_row >= p.stride_d) {
+            return;
+        }
         compute_outputs(first_row, p.stride_d - first_row);
     }
 }
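The reworked iteration count rounds up per thread rather than padding every thread: a thread takes an extra pass only if its first column of that pass, K_PER_ITER*tid past the last full pass, still falls inside the row. A small CPU model of that arithmetic (illustrative numbers; K_PER_ITER == 2 is the F16/F32 case, the one that supports unaligned sizes):

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t K_PER_ITER = 2, BLOCK_SIZE = 32;
    const uint32_t ncols = 4200;  // deliberately not a multiple of 64
    for (uint32_t tid : {0u, 19u, 20u, 31u}) {
        uint32_t num_iters = ncols / (K_PER_ITER * BLOCK_SIZE);
        if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER * tid < ncols) {
            num_iters++;  // round up: this thread still has columns to process
        }
        printf("tid %2u -> %u iterations\n", tid, num_iters);  // 66, 66, 65, 65
    }
    return 0;
}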

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,9 @@
 
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
+layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
+layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
+
 layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 #ifdef MUL_MAT_ID
 layout (binding = 3) readonly buffer IDS {int data_ids[];};
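The two new bindings alias the same buffer as data_b with wider element types: element i of data_b_v4 covers scalars 4*i through 4*i+3, which is why mul_mat_vec.comp divides its scalar index by 4. A CPU sketch of that index scaling (illustrative, not from the commit):

#include <cstdio>
#include <cstring>

struct vec4f { float x, y, z, w; };

int main() {
    const float data_b[8] = {0, 1, 2, 3, 4, 5, 6, 7};  // scalar view of binding 1

    const int idx = 4;  // the shader arranges 4-alignment on the K_PER_ITER == 8 path
    vec4f bv;           // one 128-bit load replaces four scalar loads
    std::memcpy(&bv, &data_b[(idx / 4) * 4], sizeof bv);

    printf("%.0f %.0f %.0f %.0f\n", bv.x, bv.y, bv.z, bv.w);  // 4 5 6 7
    return 0;
}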

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,10 @@ shared FLOAT_TYPE tmp[32];
 void main() {
     const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
 
+    if (row >= p.stride_d) {
+        return;
+    }
+
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
 
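These early returns pair with the host-side fix that rounds the 2D workgroup grid up. A rough model of that round-up (the variable names and the 65535 per-dimension limit are assumptions for illustration; the real logic lives in ggml-vulkan.cpp, which isn't rendered above):

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t max_dim = 65535;                  // typical maxComputeWorkGroupCount
    const uint32_t rows = 200001;                    // too many for one dimension
    uint32_t wg_z = (rows + max_dim - 1) / max_dim;  // round up
    uint32_t wg_x = (rows + wg_z - 1) / wg_z;        // round up so wg_x * wg_z >= rows
    printf("%u x %u grid = %u workgroups for %u rows (%u discarded by early return)\n",
           wg_x, wg_z, wg_x * wg_z, rows, wg_x * wg_z - rows);
    // Each shader computes row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z
    // and returns immediately when row >= p.stride_d, discarding the overshoot.
    return 0;
}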

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,10 @@ shared FLOAT_TYPE tmp[32];
 void main() {
     const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
 
+    if (row >= p.stride_d) {
+        return;
+    }
+
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
 

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp

Lines changed: 9 additions & 25 deletions
@@ -8,30 +8,14 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
 
 shared FLOAT_TYPE tmp[32];
 
-// Declare aliased versions of A and B bindings that can use 16b/32b loads for
-// the quantized values, and vec4 loads for B.
-struct block_q4_K_u32
-{
-    f16vec2 d;
-    uint32_t scales[3*QUANT_K/64/4];
-    uint32_t qs[QUANT_K/2/4];
-};
-
-struct block_q4_K_u16
-{
-    f16vec2 d;
-    uint16_t scales[3*QUANT_K/64/2];
-    uint16_t qs[QUANT_K/2/2];
-};
-
-layout (binding = 0) readonly buffer A_u32 {block_q4_K_u32 data_a_u32[];};
-layout (binding = 0) readonly buffer A_u16 {block_q4_K_u16 data_a_u16[];};
-layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
-
 // This shader assumes K_QUANTS_PER_ITERATION == 2 for alignment of loads
 void main() {
     const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
 
+    if (row >= p.stride_d) {
+        return;
+    }
+
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
 
@@ -64,9 +48,9 @@ void main() {
     const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
     const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
 
-    uint32_t scale0_u32 = data_a_u16[ib0 + i].scales[v_im    ];
-    uint32_t scale4_u32 = data_a_u16[ib0 + i].scales[v_im + 2];
-    uint32_t scale8_u32 = data_a_u16[ib0 + i].scales[v_im + 4];
+    uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
+    uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
+    uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
     uvec4 scale0 = uvec4(unpack8(scale0_u32));
     uvec4 scale4 = uvec4(unpack8(scale4_u32));
     uvec4 scale8 = uvec4(unpack8(scale8_u32));
@@ -80,8 +64,8 @@ void main() {
     const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
     const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
 
-    uint32_t qs0_u32 = data_a_u32[ib0 + i].qs[q_offset / 4];
-    uint32_t qs64_u32 = data_a_u32[ib0 + i].qs[q_offset / 4 + 16];
+    uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
+    uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
 
     uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
     uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
