Commit 5b2b2dc

ggml : sync (unary ops refactor, static-correctness) (#2370)
* ggml : sync (unary ops, tests) ggml-ci
* tests : remove unnecessary funcs
1 parent 42f70cb commit 5b2b2dc


6 files changed: +870 −575 lines changed


ggml-cuda.cu

Lines changed: 17 additions & 12 deletions
@@ -3962,18 +3962,23 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_mul;
             break;
-        case GGML_OP_GELU:
-            if (!any_on_device) {
-                return false;
-            }
-            func = ggml_cuda_gelu;
-            break;
-        case GGML_OP_SILU:
-            if (!any_on_device) {
-                return false;
-            }
-            func = ggml_cuda_silu;
-            break;
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(tensor)) {
+                case GGML_UNARY_OP_GELU:
+                    if (!any_on_device) {
+                        return false;
+                    }
+                    func = ggml_cuda_gelu;
+                    break;
+                case GGML_UNARY_OP_SILU:
+                    if (!any_on_device) {
+                        return false;
+                    }
+                    func = ggml_cuda_silu;
+                    break;
+                default:
+                    return false;
+            } break;
         case GGML_OP_NORM:
             if (!any_on_device) {
                 return false;
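
For orientation: the refactor collapses the per-activation CUDA cases into a single GGML_OP_UNARY case that inspects the tensor via ggml_get_unary_op(); any unary sub-op without a device kernel returns false so the caller can fall back to the CPU path. Below is a minimal C sketch of that dispatch shape; tensor_t, kernel_fn and the kernel_* functions are hypothetical stand-ins, not the actual ggml types or CUDA kernels.

/* Minimal sketch of the nested unary-op dispatch pattern from the diff above.
 * tensor_t, kernel_fn and the kernel_* functions are hypothetical stand-ins,
 * not the actual ggml types or CUDA kernels. */
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>

typedef enum { OP_MUL, OP_UNARY, OP_NORM } op_t;
typedef enum { UNARY_GELU, UNARY_SILU, UNARY_RELU } unary_op_t;

typedef struct { op_t op; unary_op_t unary_op; } tensor_t;
typedef void (*kernel_fn)(const tensor_t * src, tensor_t * dst);

static void kernel_gelu(const tensor_t * src, tensor_t * dst) { (void) src; (void) dst; /* device launch would go here */ }
static void kernel_silu(const tensor_t * src, tensor_t * dst) { (void) src; (void) dst; }

/* Returns false when the op (or unary sub-op) has no device kernel, so the
 * caller can fall back to the CPU implementation -- mirroring the CUDA diff. */
static bool select_kernel(const tensor_t * t, bool any_on_device, kernel_fn * out) {
    switch (t->op) {
        case OP_UNARY:
            if (!any_on_device) {
                return false;
            }
            switch (t->unary_op) {
                case UNARY_GELU: *out = kernel_gelu; return true;
                case UNARY_SILU: *out = kernel_silu; return true;
                default:         return false; /* e.g. RELU: no device kernel in this sketch */
            }
        default:
            return false;
    }
}

int main(void) {
    tensor_t t = { OP_UNARY, UNARY_SILU };
    kernel_fn fn = NULL;
    if (select_kernel(&t, /*any_on_device=*/true, &fn)) {
        fn(&t, &t); /* in real code this would run the selected device kernel */
        printf("dispatched a unary device kernel\n");
    }
    return 0;
}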

ggml-metal.m

Lines changed: 53 additions & 43 deletions
@@ -519,48 +519,56 @@ void ggml_metal_graph_compute(
 
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
-                case GGML_OP_SILU:
-                    {
-                        if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
-                        }
-
-                        [encoder setComputePipelineState:ctx->pipeline_silu];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
-                        const int64_t n = ggml_nelements(dst);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                    } break;
-                case GGML_OP_RELU:
-                    {
-                        if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
-                        }
-
-                        [encoder setComputePipelineState:ctx->pipeline_relu];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
-                        const int64_t n = ggml_nelements(dst);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                case GGML_OP_UNARY:
+                    switch (ggml_get_unary_op(gf->nodes[i])) {
+                        case GGML_UNARY_OP_SILU:
+                            {
+                                if (encoder == nil) {
+                                    encoder = [command_buffer computeCommandEncoder];
+                                }
+
+                                [encoder setComputePipelineState:ctx->pipeline_silu];
+                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                                const int64_t n = ggml_nelements(dst);
+
+                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            } break;
+                        case GGML_UNARY_OP_RELU:
+                            {
+                                if (encoder == nil) {
+                                    encoder = [command_buffer computeCommandEncoder];
+                                }
+
+                                [encoder setComputePipelineState:ctx->pipeline_relu];
+                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                                const int64_t n = ggml_nelements(dst);
+
+                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            } break;
+                        case GGML_UNARY_OP_GELU:
+                            {
+                                if (encoder == nil) {
+                                    encoder = [command_buffer computeCommandEncoder];
+                                }
+
+                                [encoder setComputePipelineState:ctx->pipeline_gelu];
+                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                                const int64_t n = ggml_nelements(dst);
+
+                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            } break;
+                        default:
+                            {
+                                fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                                GGML_ASSERT(false);
+                            }
                     } break;
-                case GGML_OP_GELU:
-                    {
-                        if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
-                        }
-
-                        [encoder setComputePipelineState:ctx->pipeline_gelu];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
-                        const int64_t n = ggml_nelements(dst);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                    } break;
                 case GGML_OP_SOFT_MAX:
                     {
                         if (encoder == nil) {
@@ -979,8 +987,10 @@ void ggml_metal_graph_compute(
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
                 default:
-                    fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
-                    GGML_ASSERT(false);
+                    {
+                        fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                        GGML_ASSERT(false);
+                    }
             }
         }
 
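
The Metal path applies the same nesting, reading the sub-op from the current graph node via ggml_get_unary_op(gf->nodes[i]); the second hunk is the static-correctness part, wrapping the default case in braces so it forms a compound statement like the other cases. As a small illustration, here is a hedged C sketch that maps a unary sub-op to a pipeline name; the enum and the returned names are stand-ins, not the actual ctx->pipeline_* selection in ggml-metal.m.

/* Illustrative lookup from a unary sub-op to a compute-pipeline name.
 * The enum and the returned strings are hypothetical stand-ins for the real
 * ctx->pipeline_* selection; the point is the braced default case used in
 * the diff above. */
#include <stdio.h>

typedef enum { UNARY_OP_SILU, UNARY_OP_RELU, UNARY_OP_GELU } unary_op_t;

static const char * pipeline_name(unary_op_t op) {
    switch (op) {
        case UNARY_OP_SILU: return "kernel_silu";
        case UNARY_OP_RELU: return "kernel_relu";
        case UNARY_OP_GELU: return "kernel_gelu";
        default:
            { /* braces keep the default case a single compound statement */
                fprintf(stderr, "unary op %d not implemented\n", (int) op);
                return NULL;
            }
    }
}

int main(void) {
    const char * name = pipeline_name(UNARY_OP_GELU);
    if (name != NULL) {
        printf("selected pipeline: %s\n", name);
    }
    return 0;
}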
