Skip to content

Commit 6109cf1

Browse files
committed
ggml : add initial BF16 support
ggml-ci
1 parent 1dc04b2 commit 6109cf1

File tree

3 files changed

+80
-8
lines changed

3 files changed

+80
-8
lines changed

common/common.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
10031003
if (s == "f16") {
10041004
return GGML_TYPE_F16;
10051005
}
1006+
if (s == "bf16") {
1007+
return GGML_TYPE_BF16;
1008+
}
10061009
if (s == "q8_0") {
10071010
return GGML_TYPE_Q8_0;
10081011
}

ggml/src/ggml-metal.m

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
120120
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
121121
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
122122
GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
123+
GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
123124
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
124125
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
125126
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
@@ -150,6 +151,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
150151
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
151152
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
152153
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
154+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
155+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
156+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
157+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
153158
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
154159
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
155160
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
@@ -195,6 +200,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
195200
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
196201
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
197202
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
203+
GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
198204
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
199205
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
200206
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
@@ -300,8 +306,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
300306
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
301307
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
302308
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
309+
GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
303310
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
304311
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
312+
GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
305313
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
306314
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
307315
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
@@ -615,6 +623,7 @@ @implementation GGMLMetalClass
615623
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
616624
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
617625
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
626+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, true);
618627
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
619628
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
620629
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
@@ -641,6 +650,10 @@ @implementation GGMLMetalClass
641650
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
642651
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
643652
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, support_simdgroup_reduction);
653+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, support_simdgroup_reduction);
654+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, support_simdgroup_reduction);
655+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, support_simdgroup_reduction);
656+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, support_simdgroup_reduction);
644657
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction);
645658
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction);
646659
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction);
@@ -690,6 +703,7 @@ @implementation GGMLMetalClass
690703
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, support_simdgroup_reduction);
691704
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, support_simdgroup_mm);
692705
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, support_simdgroup_mm);
706+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, support_simdgroup_mm);
693707
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, support_simdgroup_mm);
694708
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, support_simdgroup_mm);
695709
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, support_simdgroup_mm);
@@ -793,10 +807,12 @@ @implementation GGMLMetalClass
793807
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, support_simdgroup_reduction);
794808
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, support_simdgroup_reduction);
795809
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, support_simdgroup_reduction);
796-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
797810
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
798-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
811+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
812+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, true);
799813
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
814+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
815+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, true);
800816
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
801817
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
802818
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
@@ -887,8 +903,13 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
887903

888904
static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
889905
for (size_t i = 0, n = 3; i < n; ++i) {
890-
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
891-
return false;
906+
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16 &&
907+
op->op != GGML_OP_GET_ROWS &&
908+
op->op != GGML_OP_MUL_MAT &&
909+
op->op != GGML_OP_VIEW &&
910+
op->op != GGML_OP_CPY) {
911+
GGML_LOG_ERROR("unsupported BF16 op = %s, src[%zu] = %s\n", ggml_op_name(op->op), i, ggml_type_name(op->src[i]->type));
912+
GGML_ASSERT(false);
892913
}
893914
}
894915

@@ -969,6 +990,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
969990
switch (op->type) {
970991
case GGML_TYPE_F32:
971992
case GGML_TYPE_F16:
993+
case GGML_TYPE_BF16:
972994
case GGML_TYPE_Q8_0:
973995
case GGML_TYPE_Q4_0:
974996
case GGML_TYPE_Q4_1:
@@ -980,11 +1002,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
9801002
return false;
9811003
}
9821004
case GGML_TYPE_F16:
1005+
case GGML_TYPE_BF16:
9831006
switch (op->type) {
984-
case GGML_TYPE_F32:
985-
case GGML_TYPE_F16:
1007+
case GGML_TYPE_F32:
1008+
case GGML_TYPE_F16:
1009+
case GGML_TYPE_BF16:
9861010
return true;
987-
default:
1011+
default:
9881012
return false;
9891013
}
9901014
default:
@@ -1855,6 +1879,7 @@ static void ggml_metal_encode_node(
18551879
switch (src0->type) {
18561880
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
18571881
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1882+
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
18581883
default: break;
18591884
}
18601885

@@ -1863,6 +1888,7 @@ static void ggml_metal_encode_node(
18631888
switch (src0->type) {
18641889
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
18651890
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
1891+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
18661892
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
18671893
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
18681894
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
@@ -1940,6 +1966,25 @@ static void ggml_metal_encode_node(
19401966
nrows = 4;
19411967
}
19421968
} break;
1969+
case GGML_TYPE_BF16:
1970+
{
1971+
nth0 = 32;
1972+
nth1 = 1;
1973+
if (src1t == GGML_TYPE_F32) {
1974+
if (ne11 * ne12 < 4) {
1975+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
1976+
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1977+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
1978+
nrows = ne11;
1979+
} else {
1980+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
1981+
nrows = 4;
1982+
}
1983+
} else {
1984+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
1985+
nrows = 4;
1986+
}
1987+
} break;
19431988
case GGML_TYPE_Q4_0:
19441989
{
19451990
nth0 = 8;
@@ -2438,6 +2483,7 @@ static void ggml_metal_encode_node(
24382483
switch (src0->type) {
24392484
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
24402485
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
2486+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline; break;
24412487
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
24422488
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
24432489
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
@@ -3237,6 +3283,7 @@ static void ggml_metal_encode_node(
32373283
switch (dstt) {
32383284
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
32393285
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
3286+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
32403287
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
32413288
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
32423289
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
@@ -3254,6 +3301,13 @@ static void ggml_metal_encode_node(
32543301
default: GGML_ABORT("not implemented");
32553302
};
32563303
} break;
3304+
case GGML_TYPE_BF16:
3305+
{
3306+
switch (dstt) {
3307+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
3308+
default: GGML_ASSERT(false && "not implemented");
3309+
};
3310+
} break;
32573311
default: GGML_ABORT("not implemented");
32583312
}
32593313

ggml/src/ggml-metal.metal

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
1616
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
1717
};
1818

19+
typedef matrix<bfloat, 4, 4> bfloat4x4;
20+
1921
// NOTE: this is not dequantizing - we are simply fitting the template
2022
template <typename type4x4>
2123
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -27,6 +29,11 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
2729
reg = (type4x4)(*src);
2830
}
2931

32+
template <typename type4x4>
33+
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
34+
reg = (type4x4)(*src);
35+
}
36+
3037
template <typename type4x4>
3138
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
3239
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
@@ -2041,6 +2048,8 @@ typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
20412048
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
20422049
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
20432050
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
2051+
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float, float4>;
2052+
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
20442053

20452054
template<typename T, typename T4>
20462055
kernel void kernel_mul_mv_1row(
@@ -2110,6 +2119,7 @@ kernel void kernel_mul_mv_1row(
21102119
typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
21112120

21122121
template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
2122+
template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>;
21132123

21142124
// Assumes row size (ne00) is a multiple of 4
21152125
template<typename T, typename T4>
@@ -2169,6 +2179,7 @@ kernel void kernel_mul_mv_l4(
21692179
typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
21702180

21712181
template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
2182+
template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat, bfloat4>;
21722183

21732184
static float rope_yarn_ramp(const float low, const float high, const int i0) {
21742185
const float y = (i0 / 2 - low) / max(0.001f, high - low);
@@ -3567,8 +3578,10 @@ typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
35673578

35683579
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
35693580
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
3570-
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
3581+
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
35713582
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
3583+
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
3584+
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
35723585

35733586
kernel void kernel_cpy_f32_q8_0(
35743587
device const float * src0,
@@ -6473,6 +6486,7 @@ typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
64736486

64746487
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
64756488
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
6489+
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
64766490

64776491
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
64786492

@@ -6504,6 +6518,7 @@ typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, de
65046518

65056519
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
65066520
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
6521+
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
65076522
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
65086523
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
65096524
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;

0 commit comments

Comments
 (0)