|
62 | 62 | GGML_METAL_DECL_KERNEL(mul);
|
63 | 63 | GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
|
64 | 64 | GGML_METAL_DECL_KERNEL(scale);
|
| 65 | + GGML_METAL_DECL_KERNEL(scale_4); |
65 | 66 | GGML_METAL_DECL_KERNEL(silu);
|
66 | 67 | GGML_METAL_DECL_KERNEL(relu);
|
67 | 68 | GGML_METAL_DECL_KERNEL(gelu);
|
@@ -249,6 +250,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
|
249 | 250 | GGML_METAL_ADD_KERNEL(mul);
|
250 | 251 | GGML_METAL_ADD_KERNEL(mul_row);
|
251 | 252 | GGML_METAL_ADD_KERNEL(scale);
|
| 253 | + GGML_METAL_ADD_KERNEL(scale_4); |
252 | 254 | GGML_METAL_ADD_KERNEL(silu);
|
253 | 255 | GGML_METAL_ADD_KERNEL(relu);
|
254 | 256 | GGML_METAL_ADD_KERNEL(gelu);
|
@@ -347,6 +349,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
347 | 349 | GGML_METAL_DEL_KERNEL(mul);
|
348 | 350 | GGML_METAL_DEL_KERNEL(mul_row);
|
349 | 351 | GGML_METAL_DEL_KERNEL(scale);
|
| 352 | + GGML_METAL_DEL_KERNEL(scale_4); |
350 | 353 | GGML_METAL_DEL_KERNEL(silu);
|
351 | 354 | GGML_METAL_DEL_KERNEL(relu);
|
352 | 355 | GGML_METAL_DEL_KERNEL(gelu);
|
@@ -923,15 +926,20 @@ void ggml_metal_graph_compute(
|
923 | 926 |
|
924 | 927 | const float scale = *(const float *) src1->data;
|
925 | 928 |
|
926 |
| - [encoder setComputePipelineState:ctx->pipeline_scale]; |
| 929 | + int64_t n = ggml_nelements(dst); |
| 930 | + |
| 931 | + if (n % 4 == 0) { |
| 932 | + n /= 4; |
| 933 | + [encoder setComputePipelineState:ctx->pipeline_scale_4]; |
| 934 | + } else { |
| 935 | + [encoder setComputePipelineState:ctx->pipeline_scale]; |
| 936 | + } |
| 937 | + |
927 | 938 | [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
928 | 939 | [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
929 | 940 | [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
930 | 941 |
|
931 |
| - const int64_t n = ggml_nelements(dst); |
932 |
| - GGML_ASSERT(n % 4 == 0); |
933 |
| - |
934 |
| - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; |
| 942 | + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; |
935 | 943 | } break;
|
936 | 944 | case GGML_OP_UNARY:
|
937 | 945 | switch (ggml_get_unary_op(gf->nodes[i])) {
|
|
0 commit comments