
Commit 0734b2f

netrunnereve authored and mglambda committed
vulkan: multi-row k quants (ggml-org#10846)
* multi row k quant shaders!
* better row selection
* more row choices
* readjust row selection
* rm_kq=2 by default
1 parent e981e66 commit 0734b2f

File tree

6 files changed: +477 additions, -372 deletions

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 43 additions & 38 deletions
Large diffs are not rendered by default.
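The host-side diff is collapsed, so for orientation: the shaders below gain a second specialization constant (NUM_ROWS), and the host must shrink the dispatch to match. The following is not the actual ggml-vulkan.cpp change, just a minimal C++ sketch of the ceil-division grid sizing it implies; rm_kq and workgroup_count are hypothetical names (the commit message only says "rm_kq=2 by default").

// Hypothetical sketch (not the real ggml-vulkan.cpp code): how the host
// would size the dispatch once each workgroup handles NUM_ROWS output rows.
#include <cstdint>
#include <cassert>

// rm_kq: rows per workgroup for k-quant mat-vec shaders, passed to the
// shader via specialization constant_id = 1 (NUM_ROWS). Default is 2.
constexpr uint32_t rm_kq = 2;

// Workgroups needed to cover nrows output rows when each workgroup
// produces rm_kq rows (the last group may produce fewer).
uint32_t workgroup_count(uint32_t nrows) {
    return (nrows + rm_kq - 1) / rm_kq;  // ceil division
}

int main() {
    // e.g. a 4096-row weight matrix needs 2048 workgroups instead of 4096
    assert(workgroup_count(4096) == 2048);
    // an odd row count leaves one partial workgroup handling a single row
    assert(workgroup_count(4097) == 2049);
    return 0;
}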

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp

Lines changed: 77 additions & 57 deletions
@@ -6,21 +6,15 @@
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

 layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+layout (constant_id = 1) const uint NUM_ROWS = 1;

-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
-    const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
-    if (row >= p.stride_d) {
-        return;
-    }
+shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];

+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);

     const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;

     // 16 threads are used to process each block
     const uint it_size = gl_WorkGroupSize.x/16;
@@ -38,15 +32,15 @@ void main() {
     const uint s_offset = 8*v_im;
     const uint y_offset = 128*v_im + l0;

-    FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
+    FLOAT_TYPE temp[NUM_ROWS];
+
+    [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+        temp[i] = FLOAT_TYPE(0);
+    }

     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
         const uint y_idx = i * QUANT_K + y_offset;

-        f16vec2 d = data_a[ib0 + i].d;
-        const FLOAT_TYPE dall = d.x;
-        const FLOAT_TYPE dmin = d.y;
-
         B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
         B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
         B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
@@ -56,58 +50,84 @@ void main() {
         B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
         B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];

-        uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
-        uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
-
-        uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F;
-        uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F;
-        uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F;
-        uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F;
-
-        uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32));
-        uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32));
-        uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32));
-        uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32));
-
-        uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0];
-        uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
-        uvec2 qs0 = uvec2(unpack8(qs0_u16));
-        uvec2 qs16 = uvec2(unpack8(qs16_u16));
-
-        FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
-        FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
-        [[unroll]] for (int l = 0; l < 2; ++l) {
-            sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
-                   fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
-                   fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
-                   fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
-                   fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
-                   fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
-                   fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
-                   fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
-            sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
-                   fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
-                   fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
-                   fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
-                   fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
-                   fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
-                   fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
-                   fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+            f16vec2 d = data_a[ib0 + i].d;
+            const FLOAT_TYPE dall = d.x;
+            const FLOAT_TYPE dmin = d.y;
+
+            uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
+            uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
+
+            uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F;
+            uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F;
+            uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F;
+            uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F;
+
+            uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32));
+            uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32));
+            uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32));
+            uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32));
+
+            uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0];
+            uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
+            uvec2 qs0 = uvec2(unpack8(qs0_u16));
+            uvec2 qs16 = uvec2(unpack8(qs16_u16));
+
+            FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
+            FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
+            [[unroll]] for (int l = 0; l < 2; ++l) {
+                sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
+                       fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
+                       fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
+                       fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
+                       fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
+                       fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
+                       fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
+                       fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
+                sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
+                       fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
+                       fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
+                       fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
+                       fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
+                       fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
+                       fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
+                       fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
+            }
+            temp[n] = fma(dall, sum1, fma(-dmin, sum2, temp[n]));
         }
-        temp = fma(dall, sum1, fma(-dmin, sum2, temp));
     }

-    tmp[gl_LocalInvocationID.x] = temp;
-
     // sum up partial sums and write back result
+    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+        tmpsh[n][tid] = temp[n];
+    }
     barrier();
-    [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
         if (tid < s) {
-            tmp[tid] += tmp[tid + s];
+            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                tmpsh[n][tid] += tmpsh[n][tid + s];
+            }
         }
         barrier();
     }
     if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
+        }
+    }
+}
+
+void main() {
+    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+    // do NUM_ROWS at a time, unless there aren't enough remaining rows
+    if (first_row + NUM_ROWS <= p.stride_d) {
+        compute_outputs(first_row, NUM_ROWS);
+    } else {
+        if (first_row >= p.stride_d) {
+            return;
+        }
+        compute_outputs(first_row, p.stride_d - first_row);
     }
 }
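For readers skimming the diff: the shader now keeps one partial sum per row in tmpsh[NUM_ROWS][BLOCK_SIZE] and folds all rows inside the same halving loop, so the barrier count per workgroup is unchanged. Below is a minimal CPU model of just that reduction, with made-up partial sums; it is illustrative C++, not shader code, and the same pattern repeats in the mul_mat_vec_q3_k.comp diff that follows.

// CPU model of the shader's tree reduction over tmpsh[NUM_ROWS][BLOCK_SIZE].
// In the real shader each tid is a separate invocation with barrier()
// between steps; here the inner tid loop plays that role.
#include <cstdio>
#include <vector>

constexpr unsigned BLOCK_SIZE = 32;  // constant_id = 0
constexpr unsigned NUM_ROWS   = 2;   // constant_id = 1 (rm_kq)

int main() {
    // tmpsh[n][tid]: each invocation's partial dot product for row n
    std::vector<std::vector<float>> tmpsh(NUM_ROWS, std::vector<float>(BLOCK_SIZE));
    for (unsigned n = 0; n < NUM_ROWS; ++n)
        for (unsigned tid = 0; tid < BLOCK_SIZE; ++tid)
            tmpsh[n][tid] = 1.0f;  // stand-in partial sums

    // mirrors: for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { ... }
    // every halving step folds the upper half onto the lower half for all
    // NUM_ROWS rows at once, so one barrier per step still suffices
    for (unsigned s = BLOCK_SIZE / 2; s > 0; s >>= 1)
        for (unsigned tid = 0; tid < s; ++tid)      // threads with tid < s
            for (unsigned n = 0; n < NUM_ROWS; ++n)
                tmpsh[n][tid] += tmpsh[n][tid + s];

    // tid == 0 writes one result per row, like data_d[d_offset+first_row+n]
    for (unsigned n = 0; n < NUM_ROWS; ++n)
        printf("row %u: %.1f\n", n, tmpsh[n][0]);   // expect 32.0 each
    return 0;
}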

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp

Lines changed: 62 additions & 42 deletions
@@ -6,21 +6,15 @@
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

 layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+layout (constant_id = 1) const uint NUM_ROWS = 1;

-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
-    const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
-    if (row >= p.stride_d) {
-        return;
-    }
+shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];

+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);

     const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;

     // 16 threads are used to process each block
     const uint it_size = gl_WorkGroupSize.x/16;
@@ -35,19 +29,21 @@ void main() {

     const uint8_t m = uint8_t(1 << (4 * v_im));

-    const uint l0 = 2*v_in;                 // 0...15
+    const uint l0 = 2*v_in; // 0...15
     const uint q_offset = 32*v_im + l0;
     const uint y_offset = 128*v_im + l0;

-    FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
+    FLOAT_TYPE temp[NUM_ROWS];
+
+    [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+        temp[i] = FLOAT_TYPE(0);
+    }

     const uint s_shift = 4 * v_im;

     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
         const uint y_idx = i * QUANT_K + y_offset;

-        const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
-
         B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
         B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
         B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
@@ -57,44 +53,68 @@ void main() {
         B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
         B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];

-        uint16_t s0_16 = data_a_packed16[ib0 + i].scales[0];
-        uint16_t s2_16 = data_a_packed16[ib0 + i].scales[1];
-        uint16_t s4_16 = data_a_packed16[ib0 + i].scales[2];
-        uint16_t s6_16 = data_a_packed16[ib0 + i].scales[3];
-        uint16_t s8_16 = data_a_packed16[ib0 + i].scales[4];
-        uint16_t s10_16 = data_a_packed16[ib0 + i].scales[5];
-        u8vec2 s0 = unpack8(s0_16);
-        u8vec2 s2 = unpack8(s2_16);
-        u8vec2 s4 = unpack8(s4_16);
-        u8vec2 s6 = unpack8(s6_16);
-        u8vec2 s8 = unpack8(s8_16);
-        u8vec2 s10 = unpack8(s10_16);
-
-        FLOAT_TYPE sum = FLOAT_TYPE(0.0);
-        [[unroll]] for (int l = 0; l < 2; ++l) {
-            sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+            const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
+
+            uint16_t s0_16 = data_a_packed16[ib0 + i].scales[0];
+            uint16_t s2_16 = data_a_packed16[ib0 + i].scales[1];
+            uint16_t s4_16 = data_a_packed16[ib0 + i].scales[2];
+            uint16_t s6_16 = data_a_packed16[ib0 + i].scales[3];
+            uint16_t s8_16 = data_a_packed16[ib0 + i].scales[4];
+            uint16_t s10_16 = data_a_packed16[ib0 + i].scales[5];
+            u8vec2 s0 = unpack8(s0_16);
+            u8vec2 s2 = unpack8(s2_16);
+            u8vec2 s4 = unpack8(s4_16);
+            u8vec2 s6 = unpack8(s6_16);
+            u8vec2 s8 = unpack8(s8_16);
+            u8vec2 s10 = unpack8(s10_16);
+
+            FLOAT_TYPE sum = FLOAT_TYPE(0.0);
+            [[unroll]] for (int l = 0; l < 2; ++l) {
+                sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
+            }
+            temp[n] = fma(d, sum, temp[n]);
         }
-        temp = fma(d, sum, temp);
     }

-    tmp[gl_LocalInvocationID.x] = temp;
-
     // sum up partial sums and write back result
+    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+        tmpsh[n][tid] = temp[n];
+    }
     barrier();
-    [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
         if (tid < s) {
-            tmp[tid] += tmp[tid + s];
+            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                tmpsh[n][tid] += tmpsh[n][tid + s];
+            }
         }
         barrier();
     }
     if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
+        }
+    }
+}
+
+void main() {
+    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+    // do NUM_ROWS at a time, unless there aren't enough remaining rows
+    if (first_row + NUM_ROWS <= p.stride_d) {
+        compute_outputs(first_row, NUM_ROWS);
+    } else {
+        if (first_row >= p.stride_d) {
+            return;
+        }
+        compute_outputs(first_row, p.stride_d - first_row);
     }
 }
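The new main() in both shaders guards the tail the same way: a full group of NUM_ROWS rows when it fits, a shorter compute_outputs call when only part of the last group remains, and an early return past the end. Below is a worked C++ example of that mapping; the values are invented for illustration and p_stride_d stands in for the shader's p.stride_d.

// CPU walkthrough of the shader's main(): which rows each workgroup
// computes, including the partial group at the end.
#include <cstdio>

constexpr unsigned NUM_ROWS = 2;  // rm_kq=2, the new default

int main() {
    const unsigned p_stride_d = 7;  // total output rows (odd on purpose)
    const unsigned num_groups = (p_stride_d + NUM_ROWS - 1) / NUM_ROWS;

    for (unsigned wg = 0; wg < num_groups; ++wg) {
        const unsigned first_row = NUM_ROWS * wg;
        if (first_row + NUM_ROWS <= p_stride_d) {
            // compute_outputs(first_row, NUM_ROWS)
            printf("wg %u: rows [%u, %u)\n", wg, first_row, first_row + NUM_ROWS);
        } else if (first_row < p_stride_d) {
            // tail: compute_outputs(first_row, p.stride_d - first_row)
            printf("wg %u: rows [%u, %u) (partial)\n", wg, first_row, p_stride_d);
        }
    }
    // prints: wg 0: rows [0, 2) ... wg 3: rows [6, 7) (partial)
    return 0;
}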
