Skip to content

Commit 7dc0a49

Browse files
committed
Fix Metal API validation errors
1 parent 45bde02 commit 7dc0a49

File tree

1 file changed

+50
-50
lines changed

1 file changed

+50
-50
lines changed

ggml-metal.m

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -966,9 +966,9 @@ void ggml_metal_graph_compute(
966966
const int64_t nb = ne00;
967967

968968
[encoder setComputePipelineState:ctx->pipeline_concat];
969-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
970-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
971-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
969+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
970+
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
971+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
972972
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
973973
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
974974
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
@@ -1031,9 +1031,9 @@ void ggml_metal_graph_compute(
10311031
default: GGML_ASSERT(false);
10321032
}
10331033
}
1034-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1035-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1036-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1034+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1035+
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1036+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
10371037
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
10381038
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
10391039
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
@@ -1085,8 +1085,8 @@ void ggml_metal_graph_compute(
10851085
[encoder setComputePipelineState:ctx->pipeline_scale];
10861086
}
10871087

1088-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1089-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1088+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1089+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
10901090
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
10911091

10921092
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
@@ -1096,8 +1096,8 @@ void ggml_metal_graph_compute(
10961096
case GGML_UNARY_OP_SILU:
10971097
{
10981098
[encoder setComputePipelineState:ctx->pipeline_silu];
1099-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1100-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1099+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1100+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
11011101

11021102
const int64_t n = ggml_nelements(dst);
11031103
GGML_ASSERT(n % 4 == 0);
@@ -1107,8 +1107,8 @@ void ggml_metal_graph_compute(
11071107
case GGML_UNARY_OP_RELU:
11081108
{
11091109
[encoder setComputePipelineState:ctx->pipeline_relu];
1110-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1111-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1110+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1111+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
11121112

11131113
const int64_t n = ggml_nelements(dst);
11141114

@@ -1117,8 +1117,8 @@ void ggml_metal_graph_compute(
11171117
case GGML_UNARY_OP_GELU:
11181118
{
11191119
[encoder setComputePipelineState:ctx->pipeline_gelu];
1120-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1121-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1120+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1121+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
11221122

11231123
const int64_t n = ggml_nelements(dst);
11241124
GGML_ASSERT(n % 4 == 0);
@@ -1136,8 +1136,8 @@ void ggml_metal_graph_compute(
11361136
GGML_ASSERT(ggml_is_contiguous(src0));
11371137

11381138
[encoder setComputePipelineState:ctx->pipeline_sqr];
1139-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1140-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1139+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1140+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
11411141

11421142
const int64_t n = ggml_nelements(dst);
11431143
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
@@ -1147,8 +1147,8 @@ void ggml_metal_graph_compute(
11471147
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
11481148

11491149
[encoder setComputePipelineState:ctx->pipeline_sum_rows];
1150-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1151-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1150+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1151+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
11521152
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
11531153
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
11541154
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
@@ -1194,9 +1194,9 @@ void ggml_metal_graph_compute(
11941194

11951195
const float scale = ((float *) dst->op_params)[0];
11961196

1197-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1198-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1199-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1197+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1198+
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1199+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
12001200
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
12011201
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
12021202
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
@@ -1214,8 +1214,8 @@ void ggml_metal_graph_compute(
12141214
} else {
12151215
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
12161216
}
1217-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1218-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1217+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1218+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
12191219
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
12201220
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
12211221
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
@@ -1288,9 +1288,9 @@ void ggml_metal_graph_compute(
12881288
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
12891289
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
12901290
}
1291-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1292-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1293-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1291+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1292+
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1293+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
12941294
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
12951295
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
12961296
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
@@ -1405,9 +1405,9 @@ void ggml_metal_graph_compute(
14051405
}
14061406
};
14071407

1408-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1409-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1410-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1408+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1409+
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1410+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
14111411
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
14121412
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
14131413
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
@@ -1513,9 +1513,9 @@ void ggml_metal_graph_compute(
15131513
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
15141514
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
15151515
}
1516-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1517-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1518-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1516+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1517+
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1518+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
15191519
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
15201520
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
15211521
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
@@ -1561,9 +1561,9 @@ void ggml_metal_graph_compute(
15611561
default: GGML_ASSERT(false && "not implemented");
15621562
}
15631563

1564-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1565-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1566-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1564+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1565+
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1566+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
15671567
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
15681568
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
15691569
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
@@ -1586,8 +1586,8 @@ void ggml_metal_graph_compute(
15861586
}
15871587

15881588
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
1589-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1590-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1589+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1590+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
15911591
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
15921592
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
15931593
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
@@ -1605,8 +1605,8 @@ void ggml_metal_graph_compute(
16051605
const int nth = MIN(256, ne00);
16061606

16071607
[encoder setComputePipelineState:ctx->pipeline_norm];
1608-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1609-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1608+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1609+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
16101610
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
16111611
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
16121612
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
@@ -1632,8 +1632,8 @@ void ggml_metal_graph_compute(
16321632
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
16331633

16341634
[encoder setComputePipelineState:ctx->pipeline_alibi_f32];
1635-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1636-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1635+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1636+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
16371637
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
16381638
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
16391639
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
@@ -1682,9 +1682,9 @@ void ggml_metal_graph_compute(
16821682
default: GGML_ASSERT(false);
16831683
};
16841684

1685-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1686-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1687-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1685+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1686+
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1687+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
16881688
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
16891689
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
16901690
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
@@ -1750,8 +1750,8 @@ void ggml_metal_graph_compute(
17501750
default: GGML_ASSERT(false);
17511751
};
17521752

1753-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
1754-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1753+
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
1754+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
17551755
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
17561756
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
17571757
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
@@ -1781,8 +1781,8 @@ void ggml_metal_graph_compute(
17811781
default: GGML_ASSERT(false);
17821782
};
17831783

1784-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1785-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1784+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1785+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
17861786
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
17871787

17881788
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
@@ -1822,8 +1822,8 @@ void ggml_metal_graph_compute(
18221822
default: GGML_ASSERT(false && "not implemented");
18231823
}
18241824

1825-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1826-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1825+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1826+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
18271827
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
18281828
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
18291829
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];

0 commit comments

Comments
 (0)