193
193
// GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
194
194
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
195
195
// GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
196
- GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
197
196
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
197
+ GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
198
+ GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
199
+ GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
198
200
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
199
201
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
200
202
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
201
203
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
202
204
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
203
205
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
204
- GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
205
- GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
206
206
GGML_METAL_KERNEL_TYPE_CONCAT,
207
207
GGML_METAL_KERNEL_TYPE_SQR,
208
208
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -651,14 +651,14 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
651
651
// GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
652
652
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true );
653
653
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true );
654
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true );
655
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true );
654
656
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true );
655
657
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true );
656
658
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true );
657
659
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true );
658
660
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true );
659
661
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true );
660
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true );
661
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true );
662
662
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CONCAT, concat, true );
663
663
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SQR, sqr, true );
664
664
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true );
@@ -810,8 +810,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
810
810
switch (op->src [0 ]->type ) {
811
811
case GGML_TYPE_F32:
812
812
switch (op->type ) {
813
- case GGML_TYPE_F16:
814
813
case GGML_TYPE_F32:
814
+ case GGML_TYPE_F16:
815
815
case GGML_TYPE_Q8_0:
816
816
case GGML_TYPE_Q4_0:
817
817
case GGML_TYPE_Q4_1:
@@ -824,8 +824,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
824
824
}
825
825
case GGML_TYPE_F16:
826
826
switch (op->type ) {
827
- case GGML_TYPE_F16:
828
827
case GGML_TYPE_F32:
828
+ case GGML_TYPE_F16:
829
829
return true ;
830
830
default :
831
831
return false ;
@@ -837,7 +837,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
837
837
case GGML_OP_DIAG_MASK_INF:
838
838
case GGML_OP_GET_ROWS:
839
839
{
840
- return op->src [ 0 ]-> type != GGML_TYPE_BF16 && op-> ne [3 ] == 1 ;
840
+ return op->ne [3 ] == 1 ;
841
841
}
842
842
default :
843
843
return false ;
@@ -1580,8 +1580,8 @@ static enum ggml_status ggml_metal_graph_compute(
1580
1580
// some Metal matrix data types require aligned pointers
1581
1581
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1582
1582
switch (src0->type ) {
1583
- case GGML_TYPE_F32: GGML_ASSERT (nb01 % 16 == 0 ); break ;
1584
- case GGML_TYPE_F16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1583
+ case GGML_TYPE_F32: GGML_ASSERT (nb01 % 16 == 0 ); break ;
1584
+ case GGML_TYPE_F16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1585
1585
default : break ;
1586
1586
}
1587
1587
@@ -2775,8 +2775,8 @@ static enum ggml_status ggml_metal_graph_compute(
2775
2775
GGML_ASSERT (ne0 % ggml_blck_size (dst->type ) == 0 );
2776
2776
2777
2777
switch (dstt) {
2778
- case GGML_TYPE_F16 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_F16 ].pipeline ; break ;
2779
- case GGML_TYPE_F32 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_F32 ].pipeline ; break ;
2778
+ case GGML_TYPE_F32 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_F32 ].pipeline ; break ;
2779
+ case GGML_TYPE_F16 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_F16 ].pipeline ; break ;
2780
2780
case GGML_TYPE_Q8_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline ; break ;
2781
2781
case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline ; break ;
2782
2782
case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline ; break ;
@@ -2789,8 +2789,8 @@ static enum ggml_status ggml_metal_graph_compute(
2789
2789
case GGML_TYPE_F16:
2790
2790
{
2791
2791
switch (dstt) {
2792
- case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F16_F16 ].pipeline ; break ;
2793
- case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F16_F32 ].pipeline ; break ;
2792
+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F16_F32 ].pipeline ; break ;
2793
+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F16_F16 ].pipeline ; break ;
2794
2794
default : GGML_ASSERT (false && " not implemented" );
2795
2795
};
2796
2796
} break ;
0 commit comments