
Commit f5d1932

jeffbolznv authored and mglambda committed
vulkan: optimize mul_mat for small values of N (ggml-org#10991)
Make the mul_mat_vec shaders support N>1 (as a spec constant, NUM_COLS) where the batch_strides are overloaded to hold the row strides. Put the loads from the B matrix in the innermost loop because it should cache better. Share some code for reducing the result values to memory in mul_mat_vec_base.
1 parent 66ccd3e commit f5d1932
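
In effect, each workgroup now accumulates into a temp[NUM_COLS][NUM_ROWS] array instead of temp[NUM_ROWS], reusing each load from B across all NUM_ROWS rows. A plain C++ sketch of the computation and loop order (illustrative only, not the shader code; the names num_rows, N, K are chosen for exposition):

#include <cstddef>
#include <vector>

// A: num_rows x K (row-major), B: N x K (one matmul "column" per row here,
// matching the row-stride addressing the shaders use), D: N x num_rows.
void mul_mat_small_n(const std::vector<float>& A, const std::vector<float>& B,
                     std::vector<float>& D,
                     std::size_t num_rows, std::size_t N, std::size_t K) {
    std::vector<float> temp(N * num_rows, 0.0f); // temp[j][n] in the shader
    for (std::size_t k = 0; k < K; ++k) {
        for (std::size_t n = 0; n < num_rows; ++n) {
            for (std::size_t j = 0; j < N; ++j) {
                // Innermost loads from B: the same B[j*K + k] is re-read for
                // each n, so repeated accesses should hit in cache.
                temp[j * num_rows + n] += A[n * K + k] * B[j * K + k];
            }
        }
    }
    D = temp;
}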

File tree: 9 files changed, +290 −351 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 51 additions & 38 deletions
Large diffs are not rendered by default.
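
The host-side diff is not rendered here. For orientation only: BLOCK_SIZE, NUM_ROWS, and NUM_COLS are Vulkan specialization constants with IDs 0–2 (see the shader changes below), which a host supplies at pipeline-creation time roughly like this generic sketch (an assumption for illustration, not the actual ggml-vulkan.cpp code):

#include <vulkan/vulkan.h>
#include <cstdint>

// values[0] = BLOCK_SIZE, values[1] = NUM_ROWS, values[2] = NUM_COLS (= N).
// The caller must keep `values` alive until pipeline creation, since the
// returned struct points into it.
VkSpecializationInfo make_spec_info(const uint32_t (&values)[3]) {
    static const VkSpecializationMapEntry entries[3] = {
        {0, 0,                    sizeof(uint32_t)}, // constant_id = 0: BLOCK_SIZE
        {1, sizeof(uint32_t),     sizeof(uint32_t)}, // constant_id = 1: NUM_ROWS
        {2, 2 * sizeof(uint32_t), sizeof(uint32_t)}, // constant_id = 2: NUM_COLS
    };
    VkSpecializationInfo info{};
    info.mapEntryCount = 3;
    info.pMapEntries   = entries;
    info.dataSize      = sizeof(values);
    info.pData         = values;
    return info; // pass as pSpecializationInfo in VkPipelineShaderStageCreateInfo
}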

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp

Lines changed: 53 additions & 71 deletions
@@ -9,9 +9,6 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-layout (constant_id = 1) const uint NUM_ROWS = 1;
-
 #if !defined(DATA_A_F32) && !defined(DATA_A_F16)
 #define K_PER_ITER 8
 #else
@@ -21,70 +18,70 @@ layout (constant_id = 1) const uint NUM_ROWS = 1;
 
 uint a_offset, b_offset, d_offset, y_offset;
 
-shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
-
-void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
+void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
 {
-    const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
-    const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
-    const uint iybs = col - col%QUANT_K; // y block start index
+    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+        const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
+        const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
+        const uint iybs = col - col%QUANT_K; // y block start index
 
 #if K_PER_ITER == 8
 #if QUANT_R == 2
-    const B_TYPE_VEC4 bv02 = data_b_v4[(b_offset + iybs + iqs) / 4];
-    const B_TYPE_VEC4 bv13 = data_b_v4[(b_offset + iybs + iqs + y_offset) / 4];
-    const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
-    const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
+        const B_TYPE_VEC4 bv02 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4];
+        const B_TYPE_VEC4 bv13 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4];
+        const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
+        const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
 #else
-    const vec4 bv0 = vec4(data_b_v4[(b_offset + iybs + iqs) / 4]);
-    const vec4 bv1 = vec4(data_b_v4[(b_offset + iybs + iqs) / 4 + 1]);
+        const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
+        const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]);
 #endif
 #else
-    // Check if the second of the pair of elements is OOB, and don't fetch B or
-    // accumulate it. We still fetch a pair of elements for A, which is fine for
-    // quantized formats since they'll be within the same block. We should
-    // probably skip fetching the second element for F16/F32, but as of now we
-    // still do.
-    const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
-
-    FLOAT_TYPE b0 = 0, b1 = 0;
-    b0 = FLOAT_TYPE(data_b[b_offset + iybs + iqs]);
-    if (!OOB) {
-        b1 = FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
-    }
+        // Check if the second of the pair of elements is OOB, and don't fetch B or
+        // accumulate it. We still fetch a pair of elements for A, which is fine for
+        // quantized formats since they'll be within the same block. We should
+        // probably skip fetching the second element for F16/F32, but as of now we
+        // still do.
+        const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
+
+        FLOAT_TYPE b0 = 0, b1 = 0;
+        b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]);
+        if (!OOB) {
+            b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]);
+        }
 #endif
-    uint ibi = first_row*p.ncols;
-    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        const uint ib = (ibi + col)/QUANT_K; // block index
-        ibi += p.ncols;
+        uint ibi = first_row*p.ncols;
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            const uint ib = (ibi + col)/QUANT_K; // block index
+            ibi += p.ncols;
 
 #if K_PER_ITER == 8
-        vec4 v = dequantize4(ib, iqs, a_offset);
-        vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
+            vec4 v = dequantize4(ib, iqs, a_offset);
+            vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
 
-        const vec2 dm = get_dm(ib, a_offset);
-        if (dm.y != 0) { // quant has min component
-            v = v * dm.x + dm.y;
-            v2 = v2 * dm.x + dm.y;
-        }
+            const vec2 dm = get_dm(ib, a_offset);
+            if (dm.y != 0) { // quant has min component
+                v = v * dm.x + dm.y;
+                v2 = v2 * dm.x + dm.y;
+            }
 
-        // matrix multiplication
-        FLOAT_TYPE rowtmp = dot(bv0, v);
-        rowtmp += dot(bv1, v2);
+            // matrix multiplication
+            FLOAT_TYPE rowtmp = dot(bv0, v);
+            rowtmp += dot(bv1, v2);
 
-        if (dm.y == 0)
-            rowtmp *= dm.x;
+            if (dm.y == 0)
+                rowtmp *= dm.x;
 
-        temp[n] += rowtmp;
+            temp[j][n] += rowtmp;
 #else
-        const vec2 v = dequantize(ib, iqs, a_offset);
+            const vec2 v = dequantize(ib, iqs, a_offset);
 
-        // matrix multiplication
-        temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]);
-        if (!OOB) {
-            temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
-        }
+            // matrix multiplication
+            temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]);
+            if (!OOB) {
+                temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]);
+            }
 #endif
+        }
     }
 }
 
@@ -96,10 +93,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
 
     y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
 
-    FLOAT_TYPE temp[NUM_ROWS];
+    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
 
-    for (uint i = 0; i < NUM_ROWS; ++i) {
-        temp[i] = FLOAT_TYPE(0);
+    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+            temp[j][i] = FLOAT_TYPE(0);
+        }
     }
 
     uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
@@ -131,24 +130,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
         i++;
     }
 
-    // sum up partial sums and write back result
-    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        tmpsh[n][tid] = temp[n];
-    }
-    barrier();
-    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
-        if (tid < s) {
-            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-                tmpsh[n][tid] += tmpsh[n][tid + s];
-            }
-        }
-        barrier();
-    }
-    if (tid == 0) {
-        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
-        }
-    }
+    reduce_result(temp, d_offset, first_row, num_rows, tid);
 }
 
 void main() {
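
The j*p.batch_stride_b terms above implement the stride overloading from the commit message: for a batched N == 1 dispatch the push-constant stride steps between batches of B, while for an N > 1 dispatch the same field is repurposed as B's row stride. A hypothetical helper (illustrative C++, not repository code) makes the idea explicit:

#include <cstdint>

// Same arithmetic either way -- that is the point of the overload: the first
// argument is either a batch index (N == 1) or a column index j (N > 1).
uint32_t b_index(uint32_t j_or_batch, uint32_t stride_b, uint32_t b_offset, uint32_t elem) {
    return j_or_batch * stride_b + b_offset + elem;
}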

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp

Lines changed: 33 additions & 0 deletions
@@ -83,3 +83,36 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
                batch_idx * p.batch_stride_d;
 #endif
 }
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+layout (constant_id = 1) const uint NUM_ROWS = 1;
+layout (constant_id = 2) const uint NUM_COLS = 1;
+
+shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];
+
+void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
+    // sum up partial sums and write back result
+    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            tmpsh[j][n][tid] = temp[j][n];
+        }
+    }
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
+        if (tid < s) {
+            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+                [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                    tmpsh[j][n][tid] += tmpsh[j][n][tid + s];
+                }
+            }
+        }
+        barrier();
+    }
+    if (tid == 0) {
+        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
+            }
+        }
+    }
+}
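
For readers unfamiliar with this pattern, the loop over s halves the number of active threads on each pass until element 0 holds the workgroup total. A serial C++ model of the same arithmetic (illustrative only, assuming block_size is a power of two as the shader does):

#include <cstddef>

float tree_reduce(float* tmpsh /* [block_size] partial sums */, std::size_t block_size) {
    for (std::size_t s = block_size / 2; s > 0; s >>= 1) {
        // In the shader, threads with tid < s run in parallel between barriers.
        for (std::size_t tid = 0; tid < s; ++tid) {
            tmpsh[tid] += tmpsh[tid + s];
        }
    }
    return tmpsh[0]; // the value thread 0 writes to data_d
}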

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp

Lines changed: 37 additions & 55 deletions
@@ -5,11 +5,6 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-layout (constant_id = 1) const uint NUM_ROWS = 1;
-
-shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
-
 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
@@ -32,24 +27,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint s_offset = 8*v_im;
     const uint y_offset = 128*v_im + l0;
 
-    FLOAT_TYPE temp[NUM_ROWS];
+    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
 
-    [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
-        temp[i] = FLOAT_TYPE(0);
+    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+            temp[j][i] = FLOAT_TYPE(0);
+        }
     }
 
     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
         const uint y_idx = i * QUANT_K + y_offset;
 
-        B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
-        B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
-        B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
-        B_TYPE_VEC2 b48 = data_b_v2[(b_offset + y_idx) / 2 + 24];
-        B_TYPE_VEC2 b64 = data_b_v2[(b_offset + y_idx) / 2 + 32];
-        B_TYPE_VEC2 b80 = data_b_v2[(b_offset + y_idx) / 2 + 40];
-        B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
-        B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];
-
         [[unroll]] for (uint n = 0; n < num_rows; ++n) {
             const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
             f16vec2 d = data_a[ib0 + i].d;
@@ -74,48 +62,42 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
             uvec2 qs0 = uvec2(unpack8(qs0_u16));
             uvec2 qs16 = uvec2(unpack8(qs16_u16));
 
-            FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
-            FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
-            [[unroll]] for (int l = 0; l < 2; ++l) {
-                sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
-                       fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
-                       fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
-                       fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
-                       fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
-                       fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
-                       fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
-                       fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
-                sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
-                       fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
-                       fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
-                       fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
-                       fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
-                       fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
-                       fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
-                       fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
+            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+                B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
+                B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
+                B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
+                B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
+                B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
+                B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
+                B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
+                B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
+
+                FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
+                FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
+                [[unroll]] for (int l = 0; l < 2; ++l) {
+                    sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
+                           fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
+                           fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
+                           fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
+                           fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
+                           fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
+                           fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
+                           fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
+                    sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
+                           fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
+                           fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
+                           fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
+                           fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
+                           fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
+                           fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
+                           fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
+                }
+                temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
             }
-            temp[n] = fma(dall, sum1, fma(-dmin, sum2, temp[n]));
         }
     }
 
-    // sum up partial sums and write back result
-    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        tmpsh[n][tid] = temp[n];
-    }
-    barrier();
-    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
-        if (tid < s) {
-            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-                tmpsh[n][tid] += tmpsh[n][tid + s];
-            }
-        }
-        barrier();
-    }
-    if (tid == 0) {
-        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
-        }
-    }
+    reduce_result(temp, d_offset, first_row, num_rows, tid);
 }
 
 void main() {
