Skip to content

Commit c917b67

Browse files
authored
metal : template-ify some of the kernels (#8447)
ggml-ci
1 parent 4e24cff commit c917b67

File tree

2 files changed

+193
-564
lines changed

2 files changed

+193
-564
lines changed

ggml/src/ggml-metal.m

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -193,16 +193,16 @@
193193
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
194194
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
195195
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
196-
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
197196
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
197+
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
198+
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
199+
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
198200
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
199201
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
200202
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
201203
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
202204
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
203205
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
204-
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
205-
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
206206
GGML_METAL_KERNEL_TYPE_CONCAT,
207207
GGML_METAL_KERNEL_TYPE_SQR,
208208
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -651,14 +651,14 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
651651
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
652652
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
653653
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
654+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
655+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
654656
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
655657
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
656658
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
657659
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
658660
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
659661
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
660-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
661-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
662662
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
663663
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
664664
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
@@ -810,8 +810,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
810810
switch (op->src[0]->type) {
811811
case GGML_TYPE_F32:
812812
switch (op->type) {
813-
case GGML_TYPE_F16:
814813
case GGML_TYPE_F32:
814+
case GGML_TYPE_F16:
815815
case GGML_TYPE_Q8_0:
816816
case GGML_TYPE_Q4_0:
817817
case GGML_TYPE_Q4_1:
@@ -824,8 +824,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
824824
}
825825
case GGML_TYPE_F16:
826826
switch (op->type) {
827-
case GGML_TYPE_F16:
828827
case GGML_TYPE_F32:
828+
case GGML_TYPE_F16:
829829
return true;
830830
default:
831831
return false;
@@ -837,7 +837,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
837837
case GGML_OP_DIAG_MASK_INF:
838838
case GGML_OP_GET_ROWS:
839839
{
840-
return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1;
840+
return op->ne[3] == 1;
841841
}
842842
default:
843843
return false;
@@ -1580,8 +1580,8 @@ static enum ggml_status ggml_metal_graph_compute(
15801580
// some Metal matrix data types require aligned pointers
15811581
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
15821582
switch (src0->type) {
1583-
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1584-
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1583+
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1584+
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
15851585
default: break;
15861586
}
15871587

@@ -2775,8 +2775,8 @@ static enum ggml_status ggml_metal_graph_compute(
27752775
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
27762776

27772777
switch (dstt) {
2778-
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
2779-
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
2778+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
2779+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
27802780
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
27812781
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
27822782
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
@@ -2789,8 +2789,8 @@ static enum ggml_status ggml_metal_graph_compute(
27892789
case GGML_TYPE_F16:
27902790
{
27912791
switch (dstt) {
2792-
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
2793-
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
2792+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
2793+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
27942794
default: GGML_ASSERT(false && "not implemented");
27952795
};
27962796
} break;

0 commit comments

Comments
 (0)