Skip to content

Commit f55538c

Browse files
authored
metal : fix memory leak (#2762)
* metal : fix memory leak
* metal : fix encoders memory leak
* metal : clean up more memory resources
* metal : fix more leaks
* metal : reuse dispatch queue + autoreleasepool
* metal : reuse array for command buffers and encoders
* ggml : assert for odd number of blocks on ARM

15M tinyllama is an example
1 parent ebcee20 commit f55538c

File tree

3 files changed

+88
-24
lines changed

3 files changed

+88
-24
lines changed

ggml-metal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
// max memory buffers that can be mapped to the device
2626
#define GGML_METAL_MAX_BUFFERS 16
27+
#define GGML_METAL_MAX_COMMAND_BUFFERS 32
2728

2829
struct ggml_tensor;
2930
struct ggml_cgraph;

ggml-metal.m

Lines changed: 81 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,15 @@
3333
struct ggml_metal_context {
3434
int n_cb;
3535

36-
float * logits;
37-
3836
id<MTLDevice> device;
3937
id<MTLCommandQueue> queue;
4038
id<MTLLibrary> library;
4139

40+
id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
41+
id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
42+
43+
dispatch_queue_t d_queue;
44+
4245
int n_buffers;
4346
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
4447

@@ -114,12 +117,13 @@ @implementation GGMLMetalClass
114117

115118
struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
116119

117-
ctx->n_cb = n_cb;
120+
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
118121
ctx->device = MTLCreateSystemDefaultDevice();
119122
ctx->queue = [ctx->device newCommandQueue];
120123
ctx->n_buffers = 0;
121124
ctx->concur_list_len = 0;
122125

126+
ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
123127

124128
#if 0
125129
// compile from source string and show compile log
@@ -239,9 +243,67 @@ @implementation GGMLMetalClass
239243

240244
void ggml_metal_free(struct ggml_metal_context * ctx) {
241245
fprintf(stderr, "%s: deallocating\n", __func__);
246+
#define GGML_METAL_DEL_KERNEL(name) \
247+
[ctx->function_##name release]; \
248+
[ctx->pipeline_##name release];
249+
250+
GGML_METAL_DEL_KERNEL(add);
251+
GGML_METAL_DEL_KERNEL(add_row);
252+
GGML_METAL_DEL_KERNEL(mul);
253+
GGML_METAL_DEL_KERNEL(mul_row);
254+
GGML_METAL_DEL_KERNEL(scale);
255+
GGML_METAL_DEL_KERNEL(silu);
256+
GGML_METAL_DEL_KERNEL(relu);
257+
GGML_METAL_DEL_KERNEL(gelu);
258+
GGML_METAL_DEL_KERNEL(soft_max);
259+
GGML_METAL_DEL_KERNEL(diag_mask_inf);
260+
GGML_METAL_DEL_KERNEL(get_rows_f16);
261+
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
262+
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
263+
GGML_METAL_DEL_KERNEL(get_rows_q8_0);
264+
GGML_METAL_DEL_KERNEL(get_rows_q2_K);
265+
GGML_METAL_DEL_KERNEL(get_rows_q3_K);
266+
GGML_METAL_DEL_KERNEL(get_rows_q4_K);
267+
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
268+
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
269+
GGML_METAL_DEL_KERNEL(rms_norm);
270+
GGML_METAL_DEL_KERNEL(norm);
271+
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
272+
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
273+
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
274+
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
275+
GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
276+
GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
277+
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
278+
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
279+
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
280+
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
281+
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
282+
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
283+
GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
284+
GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
285+
GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
286+
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
287+
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
288+
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
289+
GGML_METAL_DEL_KERNEL(rope);
290+
GGML_METAL_DEL_KERNEL(alibi_f32);
291+
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
292+
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
293+
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
294+
295+
#undef GGML_METAL_DEL_KERNEL
296+
242297
for (int i = 0; i < ctx->n_buffers; ++i) {
243298
[ctx->buffers[i].metal release];
244299
}
300+
301+
[ctx->library release];
302+
[ctx->queue release];
303+
[ctx->device release];
304+
305+
dispatch_release(ctx->d_queue);
306+
245307
free(ctx);
246308
}
247309

@@ -261,7 +323,7 @@ void ggml_metal_host_free(void * data) {
261323
}
262324

263325
void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
264-
ctx->n_cb = n_cb;
326+
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
265327
}
266328

267329
int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
@@ -507,6 +569,8 @@ void ggml_metal_graph_compute(
507569
struct ggml_cgraph * gf) {
508570
metal_printf("%s: evaluating graph\n", __func__);
509571

572+
@autoreleasepool {
573+
510574
// if there is ctx->concur_list, dispatch concurrently
511575
// else fallback to serial dispatch
512576
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
@@ -521,29 +585,25 @@ void ggml_metal_graph_compute(
521585

522586
const int n_cb = ctx->n_cb;
523587

524-
NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
525-
526588
for (int i = 0; i < n_cb; ++i) {
527-
command_buffers[i] = [ctx->queue commandBuffer];
589+
ctx->command_buffers[i] = [ctx->queue commandBuffer];
528590

529591
// enqueue the command buffers in order to specify their execution order
530-
[command_buffers[i] enqueue];
531-
}
592+
[ctx->command_buffers[i] enqueue];
532593

533-
// TODO: is this the best way to start threads?
534-
dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
594+
ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
595+
}
535596

536597
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
537598
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
538599

539-
dispatch_async(queue, ^{
600+
dispatch_async(ctx->d_queue, ^{
540601
size_t offs_src0 = 0;
541602
size_t offs_src1 = 0;
542603
size_t offs_dst = 0;
543604

544-
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
545-
546-
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
605+
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
606+
id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
547607

548608
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
549609
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
@@ -1117,17 +1177,19 @@ void ggml_metal_graph_compute(
11171177
}
11181178

11191179
// wait for all threads to finish
1120-
dispatch_barrier_sync(queue, ^{});
1121-
1122-
[command_buffers[n_cb - 1] waitUntilCompleted];
1180+
dispatch_barrier_sync(ctx->d_queue, ^{});
11231181

11241182
// check status of command buffers
11251183
// needed to detect if the device ran out-of-memory for example (#1881)
11261184
for (int i = 0; i < n_cb; i++) {
1127-
MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
1185+
[ctx->command_buffers[i] waitUntilCompleted];
1186+
1187+
MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
11281188
if (status != MTLCommandBufferStatusCompleted) {
11291189
fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
11301190
GGML_ASSERT(false);
11311191
}
11321192
}
1193+
1194+
}
11331195
}

ggml.c

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2436,7 +2436,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
24362436
const int nb = n / qk;
24372437

24382438
assert(n % qk == 0);
2439-
assert(nb % 2 == 0);
24402439

24412440
const block_q4_0 * restrict x = vx;
24422441
const block_q8_0 * restrict y = vy;
@@ -2445,6 +2444,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
24452444
float32x4_t sumv0 = vdupq_n_f32(0.0f);
24462445
float32x4_t sumv1 = vdupq_n_f32(0.0f);
24472446

2447+
GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
24482448
for (int i = 0; i < nb; i += 2) {
24492449
const block_q4_0 * restrict x0 = &x[i + 0];
24502450
const block_q4_0 * restrict x1 = &x[i + 1];
@@ -2623,6 +2623,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
26232623
}
26242624

26252625
// Main loop
2626+
GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
26262627
for (int i = 2; i < nb; i+=2) {
26272628
_mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
26282629
_mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
@@ -2706,7 +2707,6 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
27062707
const int nb = n / qk;
27072708

27082709
assert(n % qk == 0);
2709-
assert(nb % 2 == 0);
27102710

27112711
const block_q4_1 * restrict x = vx;
27122712
const block_q8_1 * restrict y = vy;
@@ -2718,6 +2718,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
27182718

27192719
float summs = 0;
27202720

2721+
GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
27212722
for (int i = 0; i < nb; i += 2) {
27222723
const block_q4_1 * restrict x0 = &x[i + 0];
27232724
const block_q4_1 * restrict x1 = &x[i + 1];
@@ -2832,7 +2833,6 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
28322833
const int nb = n / qk;
28332834

28342835
assert(n % qk == 0);
2835-
assert(nb % 2 == 0);
28362836
assert(qk == QK5_0);
28372837

28382838
const block_q5_0 * restrict x = vx;
@@ -2848,6 +2848,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
28482848
uint64_t tmp0[4];
28492849
uint64_t tmp1[4];
28502850

2851+
GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
28512852
for (int i = 0; i < nb; i += 2) {
28522853
const block_q5_0 * restrict x0 = &x[i];
28532854
const block_q5_0 * restrict x1 = &x[i + 1];
@@ -3072,7 +3073,6 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
30723073
const int nb = n / qk;
30733074

30743075
assert(n % qk == 0);
3075-
assert(nb % 2 == 0);
30763076
assert(qk == QK5_1);
30773077

30783078
const block_q5_1 * restrict x = vx;
@@ -3091,6 +3091,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
30913091
uint64_t tmp0[4];
30923092
uint64_t tmp1[4];
30933093

3094+
GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
30943095
for (int i = 0; i < nb; i += 2) {
30953096
const block_q5_1 * restrict x0 = &x[i];
30963097
const block_q5_1 * restrict x1 = &x[i + 1];
@@ -3328,7 +3329,6 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
33283329
const int nb = n / qk;
33293330

33303331
assert(n % qk == 0);
3331-
assert(nb % 2 == 0);
33323332

33333333
const block_q8_0 * restrict x = vx;
33343334
const block_q8_0 * restrict y = vy;
@@ -3337,6 +3337,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
33373337
float32x4_t sumv0 = vdupq_n_f32(0.0f);
33383338
float32x4_t sumv1 = vdupq_n_f32(0.0f);
33393339

3340+
GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
33403341
for (int i = 0; i < nb; i += 2) {
33413342
const block_q8_0 * restrict x0 = &x[i + 0];
33423343
const block_q8_0 * restrict x1 = &x[i + 1];

0 commit comments

Comments (0)