Skip to content

Commit ce21cc8

Browse files
committed
ggml : add cpy
ggml-ci
1 parent 18ddfbc commit ce21cc8

File tree

3 files changed: +37 additions, -139 deletions

common/common.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2160,6 +2160,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
21602160
if (s == "f32") {
21612161
return GGML_TYPE_F32;
21622162
}
2163+
if (s == "bf16") {
2164+
return GGML_TYPE_BF16;
2165+
}
21632166
if (s == "f16") {
21642167
return GGML_TYPE_F16;
21652168
}

ggml/src/ggml-metal.m

Lines changed: 20 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -199,16 +199,18 @@
199199
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
200200
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
201201
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
202-
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
203202
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
203+
GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
204+
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
205+
GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
206+
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
207+
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
204208
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
205209
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
206210
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
207211
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
208212
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
209213
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
210-
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
211-
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
212214
GGML_METAL_KERNEL_TYPE_CONCAT,
213215
GGML_METAL_KERNEL_TYPE_SQR,
214216
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -661,16 +663,18 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
661663
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
662664
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
663665
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
666+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, true);
664667
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
665668
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
669+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, true);
670+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
671+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
666672
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
667673
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
668674
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
669675
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
670676
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
671677
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
672-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
673-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
674678
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
675679
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
676680
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
@@ -750,7 +754,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
750754
for (size_t i = 0, n = 3; i < n; ++i) {
751755
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16 &&
752756
op->op != GGML_OP_GET_ROWS &&
753-
op->op != GGML_OP_MUL_MAT) {
757+
op->op != GGML_OP_MUL_MAT &&
758+
op->op != GGML_OP_VIEW &&
759+
op->op != GGML_OP_CPY) {
754760
printf("op = %s, src[%zu] = %s\n", ggml_op_name(op->op), i, ggml_type_name(op->src[i]->type));
755761
GGML_ASSERT(false);
756762
}
@@ -826,6 +832,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
826832
case GGML_TYPE_F32:
827833
switch (op->type) {
828834
case GGML_TYPE_F16:
835+
case GGML_TYPE_BF16:
829836
case GGML_TYPE_F32:
830837
case GGML_TYPE_Q8_0:
831838
case GGML_TYPE_Q4_0:
@@ -840,6 +847,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
840847
case GGML_TYPE_F16:
841848
switch (op->type) {
842849
case GGML_TYPE_F16:
850+
case GGML_TYPE_BF16:
843851
case GGML_TYPE_F32:
844852
return true;
845853
default:
@@ -2812,8 +2820,9 @@ static enum ggml_status ggml_metal_graph_compute(
28122820
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
28132821

28142822
switch (dstt) {
2815-
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
2816-
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
2823+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
2824+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
2825+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
28172826
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
28182827
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
28192828
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
@@ -2826,8 +2835,9 @@ static enum ggml_status ggml_metal_graph_compute(
28262835
case GGML_TYPE_F16:
28272836
{
28282837
switch (dstt) {
2829-
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
2830-
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
2838+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
2839+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
2840+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
28312841
default: GGML_ASSERT(false && "not implemented");
28322842
};
28332843
} break;

ggml/src/ggml-metal.metal

Lines changed: 14 additions & 129 deletions
Original file line number | Diff line number | Diff line change
@@ -2597,91 +2597,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
25972597
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
25982598
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
25992599

2600-
kernel void kernel_cpy_f16_f16(
2601-
device const half * src0,
2602-
device half * dst,
2603-
constant int64_t & ne00,
2604-
constant int64_t & ne01,
2605-
constant int64_t & ne02,
2606-
constant int64_t & ne03,
2607-
constant uint64_t & nb00,
2608-
constant uint64_t & nb01,
2609-
constant uint64_t & nb02,
2610-
constant uint64_t & nb03,
2611-
constant int64_t & ne0,
2612-
constant int64_t & ne1,
2613-
constant int64_t & ne2,
2614-
constant int64_t & ne3,
2615-
constant uint64_t & nb0,
2616-
constant uint64_t & nb1,
2617-
constant uint64_t & nb2,
2618-
constant uint64_t & nb3,
2619-
uint3 tgpig[[threadgroup_position_in_grid]],
2620-
uint3 tpitg[[thread_position_in_threadgroup]],
2621-
uint3 ntg[[threads_per_threadgroup]]) {
2622-
const int64_t i03 = tgpig[2];
2623-
const int64_t i02 = tgpig[1];
2624-
const int64_t i01 = tgpig[0];
2625-
2626-
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2627-
2628-
const int64_t i3 = n / (ne2*ne1*ne0);
2629-
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2630-
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2631-
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2632-
2633-
device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2634-
2635-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2636-
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2637-
dst_data[i00] = src[0];
2638-
}
2639-
}
2640-
2641-
kernel void kernel_cpy_f16_f32(
2642-
device const half * src0,
2643-
device float * dst,
2644-
constant int64_t & ne00,
2645-
constant int64_t & ne01,
2646-
constant int64_t & ne02,
2647-
constant int64_t & ne03,
2648-
constant uint64_t & nb00,
2649-
constant uint64_t & nb01,
2650-
constant uint64_t & nb02,
2651-
constant uint64_t & nb03,
2652-
constant int64_t & ne0,
2653-
constant int64_t & ne1,
2654-
constant int64_t & ne2,
2655-
constant int64_t & ne3,
2656-
constant uint64_t & nb0,
2657-
constant uint64_t & nb1,
2658-
constant uint64_t & nb2,
2659-
constant uint64_t & nb3,
2660-
uint3 tgpig[[threadgroup_position_in_grid]],
2661-
uint3 tpitg[[thread_position_in_threadgroup]],
2662-
uint3 ntg[[threads_per_threadgroup]]) {
2663-
const int64_t i03 = tgpig[2];
2664-
const int64_t i02 = tgpig[1];
2665-
const int64_t i01 = tgpig[0];
2666-
2667-
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2668-
2669-
const int64_t i3 = n / (ne2*ne1*ne0);
2670-
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2671-
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2672-
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2673-
2674-
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2675-
2676-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2677-
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2678-
dst_data[i00] = src[0];
2679-
}
2680-
}
2681-
2682-
kernel void kernel_cpy_f32_f16(
2683-
device const float * src0,
2684-
device half * dst,
2600+
template<typename T0, typename T1>
2601+
kernel void kernel_cpy(
2602+
device const void * src0,
2603+
device void * dst,
26852604
constant int64_t & ne00,
26862605
constant int64_t & ne01,
26872606
constant int64_t & ne02,
@@ -2712,56 +2631,22 @@ kernel void kernel_cpy_f32_f16(
27122631
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
27132632
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
27142633

2715-
device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2634+
device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
27162635

27172636
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2718-
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2719-
2720-
dst_data[i00] = src[0];
2637+
device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2638+
dst_data[i00] = (T1) src[0];
27212639
}
27222640
}
27232641

2724-
kernel void kernel_cpy_f32_f32(
2725-
device const float * src0,
2726-
device float * dst,
2727-
constant int64_t & ne00,
2728-
constant int64_t & ne01,
2729-
constant int64_t & ne02,
2730-
constant int64_t & ne03,
2731-
constant uint64_t & nb00,
2732-
constant uint64_t & nb01,
2733-
constant uint64_t & nb02,
2734-
constant uint64_t & nb03,
2735-
constant int64_t & ne0,
2736-
constant int64_t & ne1,
2737-
constant int64_t & ne2,
2738-
constant int64_t & ne3,
2739-
constant uint64_t & nb0,
2740-
constant uint64_t & nb1,
2741-
constant uint64_t & nb2,
2742-
constant uint64_t & nb3,
2743-
uint3 tgpig[[threadgroup_position_in_grid]],
2744-
uint3 tpitg[[thread_position_in_threadgroup]],
2745-
uint3 ntg[[threads_per_threadgroup]]) {
2746-
const int64_t i03 = tgpig[2];
2747-
const int64_t i02 = tgpig[1];
2748-
const int64_t i01 = tgpig[0];
2749-
2750-
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2751-
2752-
const int64_t i3 = n / (ne2*ne1*ne0);
2753-
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2754-
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2755-
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2756-
2757-
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2642+
typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
27582643

2759-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2760-
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2761-
2762-
dst_data[i00] = src[0];
2763-
}
2764-
}
2644+
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
2645+
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
2646+
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
2647+
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
2648+
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
2649+
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
27652650

27662651
kernel void kernel_cpy_f32_q8_0(
27672652
device const float * src0,

0 commit comments

Comments
 (0)