Skip to content

Commit 469c9ad

Browse files
committed
metal : handle ggml_scale for n%4 != 0 (close #3754)
ggml-ci
1 parent e393259 commit 469c9ad

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

ggml-metal.m

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
GGML_METAL_DECL_KERNEL(mul);
6363
GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
6464
GGML_METAL_DECL_KERNEL(scale);
65+
GGML_METAL_DECL_KERNEL(scale_4);
6566
GGML_METAL_DECL_KERNEL(silu);
6667
GGML_METAL_DECL_KERNEL(relu);
6768
GGML_METAL_DECL_KERNEL(gelu);
@@ -249,6 +250,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
249250
GGML_METAL_ADD_KERNEL(mul);
250251
GGML_METAL_ADD_KERNEL(mul_row);
251252
GGML_METAL_ADD_KERNEL(scale);
253+
GGML_METAL_ADD_KERNEL(scale_4);
252254
GGML_METAL_ADD_KERNEL(silu);
253255
GGML_METAL_ADD_KERNEL(relu);
254256
GGML_METAL_ADD_KERNEL(gelu);
@@ -347,6 +349,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
347349
GGML_METAL_DEL_KERNEL(mul);
348350
GGML_METAL_DEL_KERNEL(mul_row);
349351
GGML_METAL_DEL_KERNEL(scale);
352+
GGML_METAL_DEL_KERNEL(scale_4);
350353
GGML_METAL_DEL_KERNEL(silu);
351354
GGML_METAL_DEL_KERNEL(relu);
352355
GGML_METAL_DEL_KERNEL(gelu);
@@ -923,15 +926,20 @@ void ggml_metal_graph_compute(
923926

924927
const float scale = *(const float *) src1->data;
925928

926-
[encoder setComputePipelineState:ctx->pipeline_scale];
929+
int64_t n = ggml_nelements(dst);
930+
931+
if (n % 4 == 0) {
932+
n /= 4;
933+
[encoder setComputePipelineState:ctx->pipeline_scale_4];
934+
} else {
935+
[encoder setComputePipelineState:ctx->pipeline_scale];
936+
}
937+
927938
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
928939
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
929940
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
930941

931-
const int64_t n = ggml_nelements(dst);
932-
GGML_ASSERT(n % 4 == 0);
933-
934-
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
942+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
935943
} break;
936944
case GGML_OP_UNARY:
937945
switch (ggml_get_unary_op(gf->nodes[i])) {

ggml-metal.metal

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,17 @@ kernel void kernel_mul_row(
125125
}
126126

127127
kernel void kernel_scale(
128+
device const float * src0,
129+
device float * dst,
130+
constant float & scale,
131+
uint tpig[[thread_position_in_grid]]) {
132+
dst[tpig] = src0[tpig] * scale;
133+
}
134+
135+
kernel void kernel_scale_4(
128136
device const float4 * src0,
129137
device float4 * dst,
130-
constant float & scale,
138+
constant float & scale,
131139
uint tpig[[thread_position_in_grid]]) {
132140
dst[tpig] = src0[tpig] * scale;
133141
}

0 commit comments

Comments
 (0)