@@ -563,7 +563,9 @@ static bool ggml_metal_heap_resize(struct ggml_metal_heap * heap, size_t size) {
     return true;
 }

-static id<MTLBuffer> ggml_metal_heap_alloc(struct ggml_metal_heap * heap, size_t size, size_t alignment) {
+static id<MTLBuffer> ggml_metal_heap_alloc(struct ggml_metal_heap * heap, size_t size) {
+    const size_t alignment = 1024*1024;
+
     const size_t size_aligned = GGML_PAD(size, alignment);

     heap->need += size_aligned;
@@ -1582,7 +1584,8 @@ static bool ggml_metal_encode_node(
         ggml_backend_t backend,
         int idx,
         id<MTLComputeCommandEncoder> encoder,
-        struct ggml_metal_heap * heap) {
+        struct ggml_metal_heap * heap,
+        bool no_compute) {
     struct ggml_backend_metal_context * ctx = backend->context;
     struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;

@@ -1620,6 +1623,28 @@ static bool ggml_metal_encode_node(
         GGML_ABORT("unsupported op");
     }

+    id<MTLBuffer> h_src0 = nil;
+    switch (dst->op) {
+        case GGML_OP_SOFT_MAX:
+            {
+                h_src0 = ggml_metal_heap_alloc(heap, ggml_nbytes(src0));
+                if (!h_src0) {
+                    // GGML_LOG_ERROR("%s: failed to allocate buffer, idx = %4d, size = %8zu, need = %8zu, max available = %9zu, heap size = %9zu, heap used = %zu\n",
+                    //     __func__, idx, ggml_nbytes(src0), heap->need, [heap->obj maxAvailableSizeWithAlignment:0], [heap->obj size], [heap->obj usedSize]);
+                    return false;
+                } else {
+                    // GGML_LOG_ERROR("%s: allocated %zu\n", __func__, ggml_nbytes(src0));
+                }
+            } break;
+        default:
+            {
+            } break;
+    }
+
+    if (no_compute) {
+        return true;
+    }
+
     const int64_t ne00 = src0 ? src0->ne[0] : 0;
     const int64_t ne01 = src0 ? src0->ne[1] : 0;
     const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -2277,23 +2302,14 @@ static bool ggml_metal_encode_node(
                     /*.nb3 =*/ nb03,
                 };

-                id<MTLBuffer> id_src0h = ggml_metal_heap_alloc(heap, ggml_nbytes(src0), 64*1024);
-                if (!id_src0h) {
-                    // GGML_LOG_ERROR("%s: failed to allocate buffer, idx = %4d, size = %8zu, need = %8zu, max available = %9zu, heap size = %9zu, heap used = %zu\n",
-                    //     __func__, idx, ggml_nbytes(src0), heap->need, [heap->obj maxAvailableSizeWithAlignment:0], [heap->obj size], [heap->obj usedSize]);
-                    return true;
-                } else {
-                    // GGML_LOG_ERROR("%s: allocated %zu\n", __func__, ggml_nbytes(src0));
-                }
-
                 if (src0->type == GGML_TYPE_F16) {
                     [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
                 } else {
                     [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
                 }
                 [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_src0h offset:0 atIndex:2];
+                [encoder setBuffer:h_src0 offset:0 atIndex:2];

                 GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
                 int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
@@ -2314,11 +2330,11 @@ static bool ggml_metal_encode_node(
                 };

                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0h offset:0 atIndex:0];
+                [encoder setBuffer:h_src0 offset:0 atIndex:0];
                 if (id_src1) {
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                 } else {
-                    [encoder setBuffer:id_src0h offset:0 atIndex:1];
+                    [encoder setBuffer:h_src0 offset:0 atIndex:1];
                 }
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                 [encoder setBytes:&args length:sizeof(args) atIndex:3];
@@ -4731,6 +4747,12 @@ static enum ggml_status ggml_metal_graph_compute(
         }
     }

+    for (int i = 0; i <= n_cb; ++i) {
+        struct ggml_metal_heap * heap = ctx->cmd_bufs[i].heap;
+
+        [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
+    }
+
     // the main thread commits the first few commands immediately
     // cmd_buf[n_cb]
     {
@@ -4823,6 +4845,7 @@ static enum ggml_status ggml_metal_graph_compute(

             if (heap->fail == 0) {
                 ggml_metal_heap_reset(ctx->cmd_bufs[i].heap);
+                [heap->obj setPurgeableState:MTLPurgeableStateEmpty];

                 continue;
             }
@@ -5233,19 +5256,21 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {

         const bool should_capture = ctx->capture_next_compute;

+        bool no_compute = false;
+
         for (int idx = node_start; idx < node_end; ++idx) {
             if (should_capture) {
                 [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
             }

-            const bool res = ggml_metal_encode_node(backend, idx, encoder, heap);
+            const bool res = ggml_metal_encode_node(backend, idx, encoder, heap, no_compute);

             if (should_capture) {
                 [encoder popDebugGroup];
             }

             if (!res) {
-                break;
+                no_compute = true;
             }
         }