58
58
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
59
59
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
60
60
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
61
- GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
62
61
GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
62
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
63
63
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
64
64
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
65
65
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
84
84
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
85
85
GGML_METAL_KERNEL_TYPE_NORM,
86
86
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
87
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
88
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
89
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
90
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
87
91
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
88
92
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
89
93
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
132
136
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
133
137
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
134
138
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
139
+ GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
135
140
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
136
141
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
137
142
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
@@ -515,8 +520,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
515
520
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true );
516
521
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true );
517
522
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true );
518
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true );
519
523
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, true );
524
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true );
520
525
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true );
521
526
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true );
522
527
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true );
@@ -541,6 +546,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
541
546
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction );
542
547
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NORM, norm, true );
543
548
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction );
549
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, ctx->support_simdgroup_reduction );
550
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, ctx->support_simdgroup_reduction );
551
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, ctx->support_simdgroup_reduction );
552
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, ctx->support_simdgroup_reduction );
544
553
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction );
545
554
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction );
546
555
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction );
@@ -589,6 +598,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
589
598
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction );
590
599
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction );
591
600
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm );
601
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, ctx->support_simdgroup_mm );
592
602
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm );
593
603
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm );
594
604
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm );
@@ -739,7 +749,8 @@ static void ggml_metal_free(struct ggml_metal_context * ctx) {
739
749
static bool ggml_metal_supports_op (const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
740
750
for (size_t i = 0 , n = 3 ; i < n; ++i) {
741
751
if (op->src [i] != NULL && op->src [i]->type == GGML_TYPE_BF16 &&
742
- op->op != GGML_OP_GET_ROWS) {
752
+ op->op != GGML_OP_GET_ROWS &&
753
+ op->op != GGML_OP_MUL_MAT) {
743
754
printf (" op = %s , src[%zu ] = %s \n " , ggml_op_name (op->op ), i, ggml_type_name (op->src [i]->type ));
744
755
GGML_ASSERT (false );
745
756
}
@@ -1584,15 +1595,17 @@ static enum ggml_status ggml_metal_graph_compute(
1584
1595
// some Metal matrix data types require aligned pointers
1585
1596
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1586
1597
switch (src0->type ) {
1587
- case GGML_TYPE_F32: GGML_ASSERT (nb01 % 16 == 0 ); break ;
1588
- case GGML_TYPE_F16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1598
+ case GGML_TYPE_F32: GGML_ASSERT (nb01 % 16 == 0 ); break ;
1599
+ case GGML_TYPE_BF16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1600
+ case GGML_TYPE_F16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1589
1601
default : break ;
1590
1602
}
1591
1603
1592
1604
id <MTLComputePipelineState > pipeline = nil ;
1593
1605
1594
1606
switch (src0->type ) {
1595
1607
case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline ; break ;
1608
+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline ; break ;
1596
1609
case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline ; break ;
1597
1610
case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline ; break ;
1598
1611
case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline ; break ;
@@ -1669,6 +1682,25 @@ static enum ggml_status ggml_metal_graph_compute(
1669
1682
nrows = 4 ;
1670
1683
}
1671
1684
} break ;
1685
+ case GGML_TYPE_BF16:
1686
+ {
1687
+ nth0 = 32 ;
1688
+ nth1 = 1 ;
1689
+ if (src1t == GGML_TYPE_F32) {
1690
+ if (ne11 * ne12 < 4 ) {
1691
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline ;
1692
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0 ) {
1693
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline ;
1694
+ nrows = ne11;
1695
+ } else {
1696
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline ;
1697
+ nrows = 4 ;
1698
+ }
1699
+ } else {
1700
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline ;
1701
+ nrows = 4 ;
1702
+ }
1703
+ } break ;
1672
1704
case GGML_TYPE_Q4_0:
1673
1705
{
1674
1706
nth0 = 8 ;
@@ -2165,8 +2197,8 @@ static enum ggml_status ggml_metal_graph_compute(
2165
2197
2166
2198
switch (src0->type ) {
2167
2199
case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline ; break ;
2168
- case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline ; break ;
2169
2200
case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline ; break ;
2201
+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline ; break ;
2170
2202
case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline ; break ;
2171
2203
case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline ; break ;
2172
2204
case GGML_TYPE_Q5_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline ; break ;
0 commit comments