diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index a8fd93fdeadee..22ed2e27c2721 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -91,8 +91,14 @@ layout (constant_id = 10) const uint WARP = 32; #define SHMEM_STRIDE (BK + 1) #endif -shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE]; -shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; +#if defined(FLOAT16) && defined(COOPMAT) +#define BUF_TYPE float16_t +#else +#define BUF_TYPE FLOAT_TYPE +#endif + +shared BUF_TYPE buf_a[BM * SHMEM_STRIDE]; +shared BUF_TYPE buf_b[BN * SHMEM_STRIDE]; #ifdef MUL_MAT_ID shared u16vec2 row_ids[3072]; @@ -202,7 +208,7 @@ void main() { #endif #ifdef COOPMAT - coopmat cache_a; + coopmat cache_a[cms_per_row]; coopmat cache_b; coopmat sums[cms_per_row * cms_per_col]; @@ -226,26 +232,26 @@ void main() { #if LOAD_VEC_A == 8 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x); - buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); - buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); - buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w); - buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x); - buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y); - buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z); - buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); + buf_a[buf_idx ] = BUF_TYPE(data_a[idx][0].x); + buf_a[buf_idx + 1] = BUF_TYPE(data_a[idx][0].y); + buf_a[buf_idx + 2] = BUF_TYPE(data_a[idx][0].z); + buf_a[buf_idx + 3] = BUF_TYPE(data_a[idx][0].w); + buf_a[buf_idx + 4] = BUF_TYPE(data_a[idx][1].x); + buf_a[buf_idx + 5] = BUF_TYPE(data_a[idx][1].y); + buf_a[buf_idx + 6] = BUF_TYPE(data_a[idx][1].z); + buf_a[buf_idx + 7] = BUF_TYPE(data_a[idx][1].w); #elif LOAD_VEC_A == 4 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x); - buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); - buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); - buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); + buf_a[buf_idx ] = BUF_TYPE(data_a[idx].x); + buf_a[buf_idx + 1] = BUF_TYPE(data_a[idx].y); + buf_a[buf_idx + 2] = BUF_TYPE(data_a[idx].z); + buf_a[buf_idx + 3] = BUF_TYPE(data_a[idx].w); #else if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { - buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); + buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = BUF_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); } else { - buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f); + buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = BUF_TYPE(0.0f); } #endif #elif defined(DATA_A_Q4_0) @@ -260,14 +266,14 @@ void main() { const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d; - buf_a[buf_idx ] = FLOAT_TYPE(v0.x); - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); - buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); - buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); - buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); - buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); - buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); - buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); + buf_a[buf_idx ] = BUF_TYPE(v0.x); + buf_a[buf_idx + 1 ] = BUF_TYPE(v0.y); + buf_a[buf_idx + 2 ] = BUF_TYPE(v0.z); + buf_a[buf_idx + 3 ] = BUF_TYPE(v0.w); + buf_a[buf_idx + 16] = BUF_TYPE(v1.x); + buf_a[buf_idx + 17] = BUF_TYPE(v1.y); + buf_a[buf_idx + 18] = BUF_TYPE(v1.z); + buf_a[buf_idx + 19] = BUF_TYPE(v1.w); #elif defined(DATA_A_Q4_1) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a; @@ -281,14 +287,14 @@ void main() { const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m; const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m; - buf_a[buf_idx ] = FLOAT_TYPE(v0.x); - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); - buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); - buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); - buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); - buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); - buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); - buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); + buf_a[buf_idx ] = BUF_TYPE(v0.x); + buf_a[buf_idx + 1 ] = BUF_TYPE(v0.y); + buf_a[buf_idx + 2 ] = BUF_TYPE(v0.z); + buf_a[buf_idx + 3 ] = BUF_TYPE(v0.w); + buf_a[buf_idx + 16] = BUF_TYPE(v1.x); + buf_a[buf_idx + 17] = BUF_TYPE(v1.y); + buf_a[buf_idx + 18] = BUF_TYPE(v1.z); + buf_a[buf_idx + 19] = BUF_TYPE(v1.w); #elif defined(DATA_A_Q5_0) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; @@ -304,10 +310,10 @@ void main() { const uint vui = uint(data_a_packed16[ib].qs[iqs]); const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d; - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); - buf_a[buf_idx + 17] = FLOAT_TYPE(v.w); + buf_a[buf_idx ] = BUF_TYPE(v.x); + buf_a[buf_idx + 1 ] = BUF_TYPE(v.z); + buf_a[buf_idx + 16] = BUF_TYPE(v.y); + buf_a[buf_idx + 17] = BUF_TYPE(v.w); #elif defined(DATA_A_Q5_1) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; @@ -324,10 +330,10 @@ void main() { const uint vui = uint(data_a_packed16[ib].qs[iqs]); const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m; - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); - buf_a[buf_idx + 17] = FLOAT_TYPE(v.w); + buf_a[buf_idx ] = BUF_TYPE(v.x); + buf_a[buf_idx + 1 ] = BUF_TYPE(v.z); + buf_a[buf_idx + 16] = BUF_TYPE(v.y); + buf_a[buf_idx + 17] = BUF_TYPE(v.w); #elif defined(DATA_A_Q8_0) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; @@ -340,10 +346,10 @@ void main() { const i8vec2 v1 = unpack8(data_a_packed16[ib].qs[2*iqs + 1]); const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); - buf_a[buf_idx + 2] = FLOAT_TYPE(v.z); - buf_a[buf_idx + 3] = FLOAT_TYPE(v.w); + buf_a[buf_idx ] = BUF_TYPE(v.x); + buf_a[buf_idx + 1] = BUF_TYPE(v.y); + buf_a[buf_idx + 2] = BUF_TYPE(v.z); + buf_a[buf_idx + 3] = BUF_TYPE(v.w); #elif defined(DATA_A_Q2_K) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; @@ -361,8 +367,8 @@ void main() { const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = BUF_TYPE(v.x); + buf_a[buf_idx + 1] = BUF_TYPE(v.y); #elif defined(DATA_A_Q3_K) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; @@ -383,8 +389,8 @@ void main() { | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4)); const float dl = float(data_a[ib].d) * float(us - 32); - buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4))); - buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); + buf_a[buf_idx ] = BUF_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4))); + buf_a[buf_idx + 1] = BUF_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); #elif defined(DATA_A_Q4_K) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; @@ -416,8 +422,8 @@ void main() { const float d = loadd.x * sc; const float m = -loadd.y * mbyte; - buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m)); - buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); + buf_a[buf_idx ] = BUF_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m)); + buf_a[buf_idx + 1] = BUF_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); #elif defined(DATA_A_Q5_K) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; @@ -452,8 +458,8 @@ void main() { const float d = loadd.x * sc; const float m = -loadd.y * mbyte; - buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m)); - buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); + buf_a[buf_idx ] = BUF_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m)); + buf_a[buf_idx + 1] = BUF_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); #elif defined(DATA_A_Q6_K) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; @@ -471,16 +477,15 @@ void main() { const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]); - buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32)); - buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); + buf_a[buf_idx ] = BUF_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32)); + buf_a[buf_idx + 1] = BUF_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); #elif defined(DATA_A_IQ1_S) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx % 128) / 4; - const int i8 = 2 * int(idx % 4); + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 32; const float d = float(data_a[ib].d); const uint qh = data_a[ib].qh[ib32]; @@ -488,23 +493,16 @@ void main() { const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1); const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); - - const ivec2 gvec = ivec2( - bitfieldExtract(grid, 2 * (i8), 2), - bitfieldExtract(grid, 2 * (i8 + 1), 2) - ); - const vec2 v = dl * (vec2(gvec) + delta); - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + [[unroll]] for (int k = 0; k < 8; ++k) { + buf_a[buf_idx + k] = BUF_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta)); + } #elif defined(DATA_A_IQ1_M) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib8 = (idx % 128) / 4; + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; const uint ib16 = ib8 / 2; - const int i8 = 2 * int(idx % 4); const uint16_t[4] scales = data_a[ib].scales; const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; @@ -515,21 +513,16 @@ void main() { const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); - const ivec2 gvec = ivec2( - bitfieldExtract(grid, 2 * (i8), 2), - bitfieldExtract(grid, 2 * (i8 + 1), 2) - ); - const vec2 v = dl * (vec2(gvec) + delta); - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + [[unroll]] for (int k = 0; k < 8; ++k) { + buf_a[buf_idx + k] = BUF_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta)); + } #elif defined(DATA_A_IQ2_XXS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx / 4) % 4; + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; const float d = float(data_a[ib].d); const uint qs = data_a[ib].qs[8 * ib32 + ib8]; @@ -539,63 +532,81 @@ void main() { data_a[ib].qs[8*ib32 + 6], data_a[ib].qs[8*ib32 + 7] )); - const float db = d * 0.25 * (0.5 + (signs >> 28)); + const BUF_TYPE db = BUF_TYPE(d * 0.25 * (0.5 + (signs >> 28))); const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xxs_grid[qs]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * BUF_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * BUF_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * BUF_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * BUF_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * BUF_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * BUF_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * BUF_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * BUF_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_XS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx / 4) % 4; // 0..3 + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; // 0..3 const float d = float(data_a[ib].d); const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; - const float db = d * 0.25 * (0.5 + scale); + const BUF_TYPE db = BUF_TYPE(d * 0.25 * (0.5 + scale)); const uint qs = data_a[ib].qs[4 * ib32 + ib8]; const uint sign7 = qs >> 9; - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xs_grid[qs & 511]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * BUF_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * BUF_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * BUF_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * BUF_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * BUF_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * BUF_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * BUF_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * BUF_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_S) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib8 = (idx % 128) / 4; // 0..31 - const uint ib32 = ib8 / 4; // 0..7 + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; // 0..31 + const uint ib32 = ib8 / 4; // 0..7 const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; const uint qs = data_a[ib].qs[ib8]; const uint qh = data_a[ib].qh[ib32]; const uint qhshift = 2 * (ib8 % 4); - const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)); + const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8]; const float d = float(data_a[ib].d); - const float db = d * 0.25 * (0.5 + scale); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid)); - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + const BUF_TYPE db = BUF_TYPE(d * 0.25 * (0.5 + scale)); + const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * BUF_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * BUF_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * BUF_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * BUF_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * BUF_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * BUF_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * BUF_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * BUF_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ3_XXS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = (idx % 128) / 2; // 0..63 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values const float d = float(data_a[ib].d); @@ -608,33 +619,35 @@ void main() { )); const float db = d * 0.5 * (0.5 + (signs >> 28)); const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2)); + const uint grid = iq3xxs_grid[qs]; + const vec4 v = db * vec4(unpack8(grid)); + + buf_a[buf_idx ] = BUF_TYPE((sign & 1) != 0 ? -v.x : v.x); + buf_a[buf_idx + 1] = BUF_TYPE((sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 2] = BUF_TYPE((sign & 4) != 0 ? -v.z : v.z); + buf_a[buf_idx + 3] = BUF_TYPE((sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ3_S) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = (idx % 128) / 2; // 0..63 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 const uint iqh = iqs / 8; const float d = float(data_a[ib].d); const uint qs = data_a[ib].qs[iqs]; const uint qh = data_a[ib].qh[iqh]; - const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4))); + const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2))); const uint scale = data_a[ib].scales[iqs / 16]; - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); - const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; + const vec4 v = db * vec4(unpack8(grid)); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = BUF_TYPE((sign & 1) != 0 ? -v.x : v.x); + buf_a[buf_idx + 1] = BUF_TYPE((sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 2] = BUF_TYPE((sign & 4) != 0 ? -v.z : v.z); + buf_a[buf_idx + 3] = BUF_TYPE((sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ4_XS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; @@ -652,8 +665,8 @@ void main() { const float d = float(data_a[ib].d); const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = BUF_TYPE(v.x); + buf_a[buf_idx + 1] = BUF_TYPE(v.y); #elif defined(DATA_A_IQ4_NL) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; @@ -661,13 +674,13 @@ void main() { const uint ib = idx / 8; const uint iqs = idx & 0x07; - const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d); + const BUF_TYPE d = BUF_TYPE(data_a_packed16[ib].d); const uint vui = uint(data_a_packed16[ib].qs[iqs]); - buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d; - buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d; - buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d; - buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d; + buf_a[buf_idx ] = BUF_TYPE(kvalues_iq4nl[vui & 0xF]) * d; + buf_a[buf_idx + 1 ] = BUF_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d; + buf_a[buf_idx + 16] = BUF_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d; + buf_a[buf_idx + 17] = BUF_TYPE(kvalues_iq4nl[vui >> 12]) * d; #endif } [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { @@ -679,14 +692,14 @@ void main() { const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; #endif const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; - buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); - buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); - buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); - buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w); - buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x); - buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y); - buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z); - buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w); + buf_b[buf_idx + 0] = BUF_TYPE(data_b[idx][0].x); + buf_b[buf_idx + 1] = BUF_TYPE(data_b[idx][0].y); + buf_b[buf_idx + 2] = BUF_TYPE(data_b[idx][0].z); + buf_b[buf_idx + 3] = BUF_TYPE(data_b[idx][0].w); + buf_b[buf_idx + 4] = BUF_TYPE(data_b[idx][1].x); + buf_b[buf_idx + 5] = BUF_TYPE(data_b[idx][1].y); + buf_b[buf_idx + 6] = BUF_TYPE(data_b[idx][1].z); + buf_b[buf_idx + 7] = BUF_TYPE(data_b[idx][1].w); #elif LOAD_VEC_B == 4 #ifdef MUL_MAT_ID const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; @@ -695,23 +708,23 @@ void main() { const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; #endif const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; - buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x); - buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y); - buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z); - buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w); + buf_b[buf_idx + 0] = BUF_TYPE(data_b[idx].x); + buf_b[buf_idx + 1] = BUF_TYPE(data_b[idx].y); + buf_b[buf_idx + 2] = BUF_TYPE(data_b[idx].z); + buf_b[buf_idx + 3] = BUF_TYPE(data_b[idx].w); #elif !MUL_MAT_ID if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { - buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = BUF_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); } else { - buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = BUF_TYPE(0.0f); } #else const uint row_i = ic * BN + loadc_b + l; if (row_i < _ne1) { const u16vec2 row_idx = row_ids[row_i]; - buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = BUF_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); } else { - buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = BUF_TYPE(0.0f); } #endif } @@ -725,12 +738,12 @@ void main() { [[unroll]] for (uint i = 0; i < BK; i += TK) { [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { // Load from shared into cache - coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); - - [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { - coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); - - sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]); + coopMatLoad(cache_a[cm_row], buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); + } + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a[cm_row], cache_b, sums[cm_col * cms_per_row + cm_row]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index ee1fec4e114aa..9ba48680e1440 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -326,9 +326,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool for (const auto& tname : type_names) { std::string load_vec_quant = "2"; - if ((tname == "q4_0") || (tname == "q4_1")) + if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl")) + else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl")) load_vec_quant = "4"; std::string data_a_key = "DATA_A_" + to_uppercase(tname); @@ -343,7 +343,8 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } - if (tname != "f16" && tname != "f32") { + // don't generate f16 variants for coopmat + if (tname != "f16" && tname != "f32" && !coopmat) { string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); }