Skip to content

Commit 18ddfbc

Browse files
committed
ggml : add mul_mat
1 parent 42b5324 commit 18ddfbc

File tree

2 files changed: +170 additions, -337 deletions

ggml/src/ggml-metal.m

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@
5858
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
5959
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
6060
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
61-
GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
6261
GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
62+
GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
6363
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
6464
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
6565
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
@@ -84,6 +84,10 @@
8484
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
8585
GGML_METAL_KERNEL_TYPE_NORM,
8686
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
87+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
88+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
89+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
90+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
8791
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
8892
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
8993
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
@@ -132,6 +136,7 @@
132136
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
133137
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
134138
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
139+
GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
135140
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
136141
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
137142
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
@@ -515,8 +520,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
515520
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
516521
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
517522
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
518-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
519523
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, true);
524+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
520525
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
521526
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
522527
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
@@ -541,6 +546,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
541546
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
542547
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
543548
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
549+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, ctx->support_simdgroup_reduction);
550+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, ctx->support_simdgroup_reduction);
551+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, ctx->support_simdgroup_reduction);
552+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, ctx->support_simdgroup_reduction);
544553
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
545554
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
546555
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
@@ -589,6 +598,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
589598
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
590599
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
591600
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
601+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, ctx->support_simdgroup_mm);
592602
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
593603
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
594604
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
@@ -739,7 +749,8 @@ static void ggml_metal_free(struct ggml_metal_context * ctx) {
739749
static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
740750
for (size_t i = 0, n = 3; i < n; ++i) {
741751
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16 &&
742-
op->op != GGML_OP_GET_ROWS) {
752+
op->op != GGML_OP_GET_ROWS &&
753+
op->op != GGML_OP_MUL_MAT) {
743754
printf("op = %s, src[%zu] = %s\n", ggml_op_name(op->op), i, ggml_type_name(op->src[i]->type));
744755
GGML_ASSERT(false);
745756
}
@@ -1584,15 +1595,17 @@ static enum ggml_status ggml_metal_graph_compute(
15841595
// some Metal matrix data types require aligned pointers
15851596
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
15861597
switch (src0->type) {
1587-
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1588-
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1598+
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1599+
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
1600+
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
15891601
default: break;
15901602
}
15911603

15921604
id<MTLComputePipelineState> pipeline = nil;
15931605

15941606
switch (src0->type) {
15951607
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
1608+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
15961609
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
15971610
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
15981611
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
@@ -1669,6 +1682,25 @@ static enum ggml_status ggml_metal_graph_compute(
16691682
nrows = 4;
16701683
}
16711684
} break;
1685+
case GGML_TYPE_BF16:
1686+
{
1687+
nth0 = 32;
1688+
nth1 = 1;
1689+
if (src1t == GGML_TYPE_F32) {
1690+
if (ne11 * ne12 < 4) {
1691+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
1692+
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1693+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
1694+
nrows = ne11;
1695+
} else {
1696+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
1697+
nrows = 4;
1698+
}
1699+
} else {
1700+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
1701+
nrows = 4;
1702+
}
1703+
} break;
16721704
case GGML_TYPE_Q4_0:
16731705
{
16741706
nth0 = 8;
@@ -2165,8 +2197,8 @@ static enum ggml_status ggml_metal_graph_compute(
21652197

21662198
switch (src0->type) {
21672199
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
2168-
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
21692200
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline; break;
2201+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
21702202
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
21712203
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
21722204
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;

0 commit comments

Comments
 (0)