Skip to content

metal : implement q5_0 and q5_1 kernels #3648

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
GGML_METAL_DECL_KERNEL(get_rows_q5_0);
GGML_METAL_DECL_KERNEL(get_rows_q5_1);
GGML_METAL_DECL_KERNEL(get_rows_q8_0);
GGML_METAL_DECL_KERNEL(get_rows_q2_K);
GGML_METAL_DECL_KERNEL(get_rows_q3_K);
Expand All @@ -87,6 +89,8 @@
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
Expand All @@ -97,6 +101,8 @@
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
Expand Down Expand Up @@ -254,6 +260,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
GGML_METAL_ADD_KERNEL(get_rows_q5_0);
GGML_METAL_ADD_KERNEL(get_rows_q5_1);
GGML_METAL_ADD_KERNEL(get_rows_q8_0);
GGML_METAL_ADD_KERNEL(get_rows_q2_K);
GGML_METAL_ADD_KERNEL(get_rows_q3_K);
Expand All @@ -268,6 +276,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
Expand All @@ -278,8 +288,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
Expand Down Expand Up @@ -346,6 +358,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(get_rows_f16);
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
GGML_METAL_DEL_KERNEL(get_rows_q5_0);
GGML_METAL_DEL_KERNEL(get_rows_q5_1);
GGML_METAL_DEL_KERNEL(get_rows_q8_0);
GGML_METAL_DEL_KERNEL(get_rows_q2_K);
GGML_METAL_DEL_KERNEL(get_rows_q3_K);
Expand All @@ -360,6 +374,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
Expand All @@ -370,8 +386,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
Expand Down Expand Up @@ -1052,6 +1070,8 @@ void ggml_metal_graph_compute(
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
Expand Down Expand Up @@ -1121,6 +1141,24 @@ void ggml_metal_graph_compute(
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
} break;
case GGML_TYPE_Q5_0:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);

nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
} break;
case GGML_TYPE_Q5_1:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);

nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
} break;
case GGML_TYPE_Q8_0:
{
GGML_ASSERT(ne02 == 1);
Expand Down Expand Up @@ -1201,7 +1239,8 @@ void ggml_metal_graph_compute(
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];

if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
Expand Down Expand Up @@ -1233,6 +1272,8 @@ void ggml_metal_graph_compute(
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
Expand Down
163 changes: 162 additions & 1 deletion ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,21 @@ typedef struct {
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;

#define QK5_0 32
typedef struct {
half d; // delta
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_0 / 2]; // nibbles / quants
} block_q5_0;

#define QK5_1 32
typedef struct {
half d; // delta
half m; // min
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_1 / 2]; // nibbles / quants
} block_q5_1;

#define QK8_0 32
typedef struct {
half d; // delta
Expand Down Expand Up @@ -399,8 +414,11 @@ kernel void kernel_rms_norm(
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;

float2 acc = 0.f;

device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);

for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
Expand All @@ -417,8 +435,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;
float m = qb_curr->m;
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);

float2 acc = 0.f;

device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);

for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
Expand All @@ -428,6 +449,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
return d * (acc[0] + acc[1]) + sumy * m;
}

// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
// il indicates where the q5 quants begin (0 or QK5_0/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;

float2 acc = 0.f;

device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);

for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
}
return d * (sumy * -16.f + acc[0] + acc[1]);
}

// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
// il indicates where the q5 quants begin (0 or QK5_1/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;
float m = qb_curr->m;

float2 acc = 0.f;

device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);

for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
}
return d * (acc[0] + acc[1]) + sumy * m;
}

// putting them in the kernel cause a significant performance penalty
#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
Expand Down Expand Up @@ -525,6 +589,43 @@ kernel void kernel_mul_mv_q4_1_f32(
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}

kernel void kernel_mul_mv_q5_0_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01[[buffer(4)]],
constant int64_t & ne02[[buffer(5)]],
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}

kernel void kernel_mul_mv_q5_1_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01[[buffer(4)]],
constant int64_t & ne02[[buffer(5)]],
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}


#define NB_Q8_0 8

kernel void kernel_mul_mv_q8_0_f32(
Expand Down Expand Up @@ -2149,6 +2250,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
}
}

template <typename type4x4>
void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
const float d = xb->d;
const float md = -16.h * xb->d;
const ushort mask = il ? 0x00F0 : 0x000F;

const uint32_t qh = *((device const uint32_t *)xb->qh);

const int x_mv = il ? 4 : 0;

const int gh_mv = il ? 12 : 0;
const int gh_bk = il ? 0 : 4;

for (int i = 0; i < 8; i++) {
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;

// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);

reg[i/2][2*(i%2)+0] = d * x0 + md;
reg[i/2][2*(i%2)+1] = d * x1 + md;
}
}

template <typename type4x4>
void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
const float d = xb->d;
const float m = xb->m;
const ushort mask = il ? 0x00F0 : 0x000F;

const uint32_t qh = *((device const uint32_t *)xb->qh);

const int x_mv = il ? 4 : 0;

const int gh_mv = il ? 12 : 0;
const int gh_bk = il ? 0 : 4;

for (int i = 0; i < 8; i++) {
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;

// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);

reg[i/2][2*(i%2)+0] = d * x0 + m;
reg[i/2][2*(i%2)+1] = d * x1 + m;
}
}

template <typename type4x4>
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
Expand Down Expand Up @@ -2490,6 +2647,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
Expand Down Expand Up @@ -2518,6 +2677,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
Expand Down