Skip to content

Commit f3f62f0

Browse files
authored
metal : optimize ggml_mul_mat_id (faster Mixtral PP) (#4725)
* ggml : disable fast-math for Metal (cmake build only) ggml-ci * metal : fix Metal API debug warnings * cmake : add -fno-inline for Metal build (#4545) * metal : fix API debug warnings * metal : fix compile warnings * metal : use uint64_t for strides * cmake : rename option to LLAMA_METAL_SHADER_DEBUG * metal : fix mat-vec Q8_0 kernel for BS > 1 * metal : normalize mat-vec kernel signatures * cmake : respect LLAMA_QKK_64 option * metal : fix mat-vec Q4_K kernel for QK_K == 64 * metal : optimizing ggml_mul_mat_id (wip) * metal : minor fix * metal : opt mul_mm_id
1 parent 0ef3ca2 commit f3f62f0

File tree

2 files changed

+190
-46
lines changed

2 files changed

+190
-46
lines changed

ggml-metal.m

Lines changed: 20 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1657,6 +1657,10 @@ void ggml_metal_graph_compute(
16571657
}
16581658
};
16591659

1660+
if (ggml_is_quantized(src0t)) {
1661+
GGML_ASSERT(ne00 >= nth0*nth1);
1662+
}
1663+
16601664
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
16611665
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
16621666
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -1715,6 +1719,9 @@ void ggml_metal_graph_compute(
17151719
// TODO: make this more general
17161720
GGML_ASSERT(n_as <= 8);
17171721

1722+
// max size of the src1ids array in the kernel stack
1723+
GGML_ASSERT(ne11 <= 512);
1724+
17181725
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
17191726

17201727
const int64_t ne20 = src2 ? src2->ne[0] : 0;
@@ -1732,32 +1739,29 @@ void ggml_metal_graph_compute(
17321739
GGML_ASSERT(!ggml_is_transposed(src2));
17331740
GGML_ASSERT(!ggml_is_transposed(src1));
17341741

1735-
GGML_ASSERT(ne20 % 32 == 0);
1736-
// !!!!!!!!! TODO: this assert is probably required but not sure!
1737-
//GGML_ASSERT(ne20 >= 64);
17381742
GGML_ASSERT(src1t == GGML_TYPE_F32);
17391743

17401744
const uint r2 = ne12/ne22;
17411745
const uint r3 = ne13/ne23;
17421746

17431747
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
17441748
// to the matrix-vector kernel
1745-
int ne11_mm_min = 1;
1749+
int ne11_mm_min = n_as;
17461750

17471751
const int idx = ((int32_t *) dst->op_params)[0];
17481752

17491753
// batch size
17501754
GGML_ASSERT(ne01 == ne11);
17511755

1752-
const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
1753-
17541756
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
17551757
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
17561758
// !!!
17571759
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
17581760
// indirect matrix multiplication
17591761
// !!!
1760-
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
1762+
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1763+
ne20 % 32 == 0 && ne20 >= 64 &&
1764+
ne11 > ne11_mm_min) {
17611765
switch (src2->type) {
17621766
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
17631767
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
@@ -1787,7 +1791,7 @@ void ggml_metal_graph_compute(
17871791
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
17881792
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
17891793
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1790-
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
1794+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
17911795
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
17921796
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
17931797
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
@@ -1805,8 +1809,7 @@ void ggml_metal_graph_compute(
18051809

18061810
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
18071811

1808-
// TODO: processing one row at a time (ne11 -> 1) is not efficient
1809-
[encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1812+
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
18101813
} else {
18111814
int nth0 = 32;
18121815
int nth1 = 1;
@@ -1889,11 +1892,17 @@ void ggml_metal_graph_compute(
18891892
} break;
18901893
default:
18911894
{
1892-
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1895+
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
18931896
GGML_ASSERT(false && "not implemented");
18941897
}
18951898
};
18961899

1900+
if (ggml_is_quantized(src2t)) {
1901+
GGML_ASSERT(ne20 >= nth0*nth1);
1902+
}
1903+
1904+
const int64_t _ne1 = 1; // kernels needs a reference in constant memory
1905+
18971906
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
18981907
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
18991908
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];

0 commit comments

Comments (0)