|
32 | 32 | layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
33 | 33 |
|
34 | 34 | layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
| 35 | +#if defined(A_TYPE_PACKED16) |
| 36 | +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; |
| 37 | +#endif |
| 38 | +#if defined(A_TYPE_PACKED32) |
| 39 | +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; |
| 40 | +#endif |
| 41 | + |
35 | 42 | layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
36 | 43 | layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
37 | 44 |
|
@@ -243,74 +250,100 @@ void main() {
|
243 | 250 | #endif
|
244 | 251 | #elif defined(DATA_A_Q4_0)
|
245 | 252 | const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
246 |
| - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; |
247 |
| - |
248 |
| - const uint ib = idx / 16; |
249 |
| - const uint iqs = idx & 0xF; |
250 |
| - |
251 |
| - const float d = float(data_a[ib].d); |
252 |
| - const uint vui = uint(data_a[ib].qs[iqs]); |
253 |
| - const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d; |
254 |
| - |
255 |
| - buf_a[buf_idx ] = FLOAT_TYPE(v.x); |
256 |
| - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); |
| 253 | + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a; |
| 254 | + |
| 255 | + const uint ib = idx / 4; |
| 256 | + const uint iqs = idx & 0x03; |
| 257 | + |
| 258 | + const float d = float(data_a_packed16[ib].d); |
| 259 | + const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); |
| 260 | + const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; |
| 261 | + const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d; |
| 262 | + |
| 263 | + buf_a[buf_idx ] = FLOAT_TYPE(v0.x); |
| 264 | + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); |
| 265 | + buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); |
| 266 | + buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); |
| 267 | + buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); |
| 268 | + buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); |
| 269 | + buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); |
| 270 | + buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); |
257 | 271 | #elif defined(DATA_A_Q4_1)
|
258 | 272 | const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
259 |
| - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; |
260 |
| - |
261 |
| - const uint ib = idx / 16; |
262 |
| - const uint iqs = idx & 0xF; |
263 |
| - |
264 |
| - const float d = float(data_a[ib].d); |
265 |
| - const float m = float(data_a[ib].m); |
266 |
| - const uint vui = uint(data_a[ib].qs[iqs]); |
267 |
| - const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m; |
268 |
| - |
269 |
| - buf_a[buf_idx ] = FLOAT_TYPE(v.x); |
270 |
| - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); |
| 273 | + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a; |
| 274 | + |
| 275 | + const uint ib = idx / 4; |
| 276 | + const uint iqs = idx & 0x03; |
| 277 | + |
| 278 | + const float d = float(data_a_packed16[ib].d); |
| 279 | + const float m = float(data_a_packed16[ib].m); |
| 280 | + const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); |
| 281 | + const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m; |
| 282 | + const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m; |
| 283 | + |
| 284 | + buf_a[buf_idx ] = FLOAT_TYPE(v0.x); |
| 285 | + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); |
| 286 | + buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); |
| 287 | + buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); |
| 288 | + buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); |
| 289 | + buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); |
| 290 | + buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); |
| 291 | + buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); |
271 | 292 | #elif defined(DATA_A_Q5_0)
|
272 | 293 | const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
273 |
| - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; |
| 294 | + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; |
274 | 295 |
|
275 |
| - const uint ib = idx / 16; |
276 |
| - const uint iqs = idx & 0xF; |
| 296 | + const uint ib = idx / 8; |
| 297 | + const uint iqs = idx & 0x07; |
277 | 298 |
|
278 |
| - const float d = float(data_a[ib].d); |
279 |
| - const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; |
280 |
| - const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); |
281 |
| - const uint vui = uint(data_a[ib].qs[iqs]); |
282 |
| - const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d; |
| 299 | + const float d = float(data_a_packed16[ib].d); |
| 300 | + const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]); |
| 301 | + const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); |
| 302 | + const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); |
| 303 | + |
| 304 | + const uint vui = uint(data_a_packed16[ib].qs[iqs]); |
| 305 | + const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d; |
283 | 306 |
|
284 | 307 | buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
| 308 | + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); |
285 | 309 | buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
|
| 310 | + buf_a[buf_idx + 17] = FLOAT_TYPE(v.w); |
286 | 311 | #elif defined(DATA_A_Q5_1)
|
287 | 312 | const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
288 |
| - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; |
| 313 | + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; |
289 | 314 |
|
290 |
| - const uint ib = idx / 16; |
291 |
| - const uint iqs = idx & 0xF; |
| 315 | + const uint ib = idx / 8; |
| 316 | + const uint iqs = idx & 0x07; |
292 | 317 |
|
293 |
| - const float d = float(data_a[ib].d); |
294 |
| - const float m = float(data_a[ib].m); |
295 |
| - const uint uint_qh = data_a[ib].qh; |
296 |
| - const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); |
297 |
| - const uint vui = uint(data_a[ib].qs[iqs]); |
298 |
| - const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m; |
| 318 | + const float d = float(data_a_packed16[ib].d); |
| 319 | + const float m = float(data_a_packed16[ib].m); |
| 320 | + const uint uint_qh = data_a_packed16[ib].qh; |
| 321 | + const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); |
| 322 | + const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); |
| 323 | + |
| 324 | + const uint vui = uint(data_a_packed16[ib].qs[iqs]); |
| 325 | + const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m; |
299 | 326 |
|
300 | 327 | buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
| 328 | + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); |
301 | 329 | buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
|
| 330 | + buf_a[buf_idx + 17] = FLOAT_TYPE(v.w); |
302 | 331 | #elif defined(DATA_A_Q8_0)
|
303 | 332 | const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
304 | 333 | const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
305 | 334 |
|
306 |
| - const uint ib = idx / 16; |
307 |
| - const uint iqs = (idx & 0xF) * 2; |
| 335 | + const uint ib = idx / 8; |
| 336 | + const uint iqs = idx & 0x07; |
308 | 337 |
|
309 |
| - const float d = float(data_a[ib].d); |
310 |
| - const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d; |
| 338 | + const float d = float(data_a_packed16[ib].d); |
| 339 | + const i8vec2 v0 = unpack8(data_a_packed16[ib].qs[2*iqs]); |
| 340 | + const i8vec2 v1 = unpack8(data_a_packed16[ib].qs[2*iqs + 1]); |
| 341 | + const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; |
311 | 342 |
|
312 | 343 | buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
313 | 344 | buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
| 345 | + buf_a[buf_idx + 2] = FLOAT_TYPE(v.z); |
| 346 | + buf_a[buf_idx + 3] = FLOAT_TYPE(v.w); |
314 | 347 | #elif defined(DATA_A_Q2_K)
|
315 | 348 | const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
316 | 349 | const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
@@ -623,17 +656,18 @@ void main() {
|
623 | 656 | buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
624 | 657 | #elif defined(DATA_A_IQ4_NL)
|
625 | 658 | const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
626 |
| - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; |
| 659 | + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; |
627 | 660 |
|
628 |
| - const uint ib = idx / 16; |
629 |
| - const uint iqs = idx & 0xF; |
| 661 | + const uint ib = idx / 8; |
| 662 | + const uint iqs = idx & 0x07; |
630 | 663 |
|
631 |
| - const float d = float(data_a[ib].d); |
632 |
| - const uint vui = uint(data_a[ib].qs[iqs]); |
633 |
| - const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d; |
| 664 | + const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d); |
| 665 | + const uint vui = uint(data_a_packed16[ib].qs[iqs]); |
634 | 666 |
|
635 |
| - buf_a[buf_idx ] = FLOAT_TYPE(v.x); |
636 |
| - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); |
| 667 | + buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d; |
| 668 | + buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d; |
| 669 | + buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d; |
| 670 | + buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d; |
637 | 671 | #endif
|
638 | 672 | }
|
639 | 673 | [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
|
|
0 commit comments