@@ -33,12 +33,15 @@
 struct ggml_metal_context {
     int n_cb;
 
-    float * logits;
-
     id<MTLDevice>       device;
     id<MTLCommandQueue> queue;
     id<MTLLibrary>      library;
 
+    id<MTLCommandBuffer>         command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
+    id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
+
+    dispatch_queue_t d_queue;
+
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
@@ -114,12 +117,13 @@ @implementation GGMLMetalClass
 
     struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
 
-    ctx->n_cb   = n_cb;
+    ctx->n_cb   = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
     ctx->device = MTLCreateSystemDefaultDevice();
     ctx->queue  = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
     ctx->concur_list_len = 0;
 
+    ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+
 #if 0
     // compile from source string and show compile log
@@ -239,9 +243,67 @@ @implementation GGMLMetalClass
 
 void ggml_metal_free(struct ggml_metal_context * ctx) {
     fprintf(stderr, "%s: deallocating\n", __func__);
+#define GGML_METAL_DEL_KERNEL(name) \
+    [ctx->function_##name release]; \
+    [ctx->pipeline_##name release];
+
+    GGML_METAL_DEL_KERNEL(add);
+    GGML_METAL_DEL_KERNEL(add_row);
+    GGML_METAL_DEL_KERNEL(mul);
+    GGML_METAL_DEL_KERNEL(mul_row);
+    GGML_METAL_DEL_KERNEL(scale);
+    GGML_METAL_DEL_KERNEL(silu);
+    GGML_METAL_DEL_KERNEL(relu);
+    GGML_METAL_DEL_KERNEL(gelu);
+    GGML_METAL_DEL_KERNEL(soft_max);
+    GGML_METAL_DEL_KERNEL(diag_mask_inf);
+    GGML_METAL_DEL_KERNEL(get_rows_f16);
+    GGML_METAL_DEL_KERNEL(get_rows_q4_0);
+    GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DEL_KERNEL(get_rows_q8_0);
+    GGML_METAL_DEL_KERNEL(get_rows_q2_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q3_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q4_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q5_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q6_K);
+    GGML_METAL_DEL_KERNEL(rms_norm);
+    GGML_METAL_DEL_KERNEL(norm);
+    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(rope);
+    GGML_METAL_DEL_KERNEL(alibi_f32);
+    GGML_METAL_DEL_KERNEL(cpy_f32_f16);
+    GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+    GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+
+#undef GGML_METAL_DEL_KERNEL
+
     for (int i = 0; i < ctx->n_buffers; ++i) {
         [ctx->buffers[i].metal release];
     }
+
+    [ctx->library release];
+    [ctx->queue release];
+    [ctx->device release];
+
+    dispatch_release(ctx->d_queue);
+
     free(ctx);
 }
 
@@ -261,7 +323,7 @@ void ggml_metal_host_free(void * data) {
 }
 
 void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
-    ctx->n_cb = n_cb;
+    ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
 }
 
 int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
@@ -507,6 +569,8 @@ void ggml_metal_graph_compute(
                struct ggml_cgraph * gf) {
     metal_printf("%s: evaluating graph\n", __func__);
 
+    @autoreleasepool {
+
     // if there is ctx->concur_list, dispatch concurrently
     // else fallback to serial dispatch
     MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
@@ -521,29 +585,25 @@ void ggml_metal_graph_compute(
 
     const int n_cb = ctx->n_cb;
 
-    NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
-
     for (int i = 0; i < n_cb; ++i) {
-        command_buffers[i] = [ctx->queue commandBuffer];
+        ctx->command_buffers[i] = [ctx->queue commandBuffer];
 
         // enqueue the command buffers in order to specify their execution order
-        [command_buffers[i] enqueue];
-    }
+        [ctx->command_buffers[i] enqueue];
 
-    // TODO: is this the best way to start threads?
-    dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+        ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
+    }
 
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
         const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
-        dispatch_async(queue, ^{
+        dispatch_async(ctx->d_queue, ^{
             size_t offs_src0 = 0;
             size_t offs_src1 = 0;
             size_t offs_dst  = 0;
 
-            id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
-
-            id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+            id<MTLCommandBuffer>         command_buffer = ctx->command_buffers[cb_idx];
+            id<MTLComputeCommandEncoder> encoder        = ctx->command_encoders[cb_idx];
 
             const int node_start =                                      (cb_idx + 0) * n_nodes_per_cb;
             const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
@@ -1117,17 +1177,19 @@ void ggml_metal_graph_compute(
     }
 
     // wait for all threads to finish
-    dispatch_barrier_sync(queue, ^{});
-
-    [command_buffers[n_cb - 1] waitUntilCompleted];
+    dispatch_barrier_sync(ctx->d_queue, ^{});
 
     // check status of command buffers
     // needed to detect if the device ran out-of-memory for example (#1881)
     for (int i = 0; i < n_cb; i++) {
-        MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
+        [ctx->command_buffers[i] waitUntilCompleted];
+
+        MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
         if (status != MTLCommandBufferStatusCompleted) {
             fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
             GGML_ASSERT(false);
         }
     }
+
+    }
 }
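For reference, a minimal standalone sketch of the threading pattern this patch settles on, in plain C with GCD (compilable on macOS with clang). It is not part of the patch: MAX_CB, the node counts, and the main() harness are invented for illustration. It mirrors what ggml_metal_graph_compute does above — clamp n_cb, split the graph nodes into contiguous slices, dispatch one block per slice onto the concurrent d_queue, then use a barrier block to wait for all workers before checking results.

// sketch.c — illustrative only, not part of ggml-metal.m
#include <dispatch/dispatch.h>
#include <stdio.h>

#define MAX_CB 16                               // stands in for GGML_METAL_MAX_COMMAND_BUFFERS
#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int n_nodes = 100;                    // pretend graph size
    const int n_cb    = MIN(4, MAX_CB);         // clamp, as ggml_metal_set_n_cb now does

    // one concurrent queue shared by all workers, like ctx->d_queue
    dispatch_queue_t d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);

    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
        const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;

        dispatch_async(d_queue, ^{
            // each worker encodes its contiguous slice of graph nodes
            const int node_start = (cb_idx + 0) * n_nodes_per_cb;
            const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
            printf("worker %d: nodes [%d, %d)\n", cb_idx, node_start, node_end);
        });
    }

    // a barrier block runs only after every previously submitted block has finished
    dispatch_barrier_sync(d_queue, ^{});

    dispatch_release(d_queue);                  // manual release; fine in plain C (no ARC)
    return 0;
}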