Commit 26ab943
llamafile : improve moe prompt eval speed on cpu
This change introduces a llamafile_mixmul() API that allows tinyBLAS to speed up "Mixture of Experts" models. On my Threadripper, Mixtral's 8x7b F16 weights now process prompts 2x faster. I'm also seeing a 60 percent improvement with Mixtral 8x22b Q4_0. The same applies to Q8_0, which is also supported by tinyBLAS.

MoE models spend the majority of their time inside MUL_MAT_ID rather than MUL_MAT, which is why llamafile_sgemm was not able to help them before. llamafile_mixmul works by decomposing the mixmul operation into sgemm calls.
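To illustrate the decomposition the message describes, here is a minimal, self-contained sketch in plain C: tokens are grouped by the expert that the ids tensor routes them to, each group is multiplied by that expert's weights in one dense matmul (the role llamafile_sgemm plays in the real code), and the results are scattered back. Every name here, including the naive gemm stand-in, is illustrative rather than ggml's actual API.

/* Hypothetical sketch of the decomposition idea, not the commit's code:
 * group tokens by routed expert, run one dense matmul per expert
 * (standing in for an llamafile_sgemm call), scatter results back. */
#include <stdio.h>
#include <string.h>

enum { N_EXPERTS = 2, N_TOKENS = 4, K = 3, M = 2 };

/* naive stand-in for one llamafile_sgemm call:
 * out[t][i] = dot(row i of W, token t of X), for n gathered tokens */
static void gemm(const float *W, const float *X, float *out, int n) {
    for (int i = 0; i < M; i++)
        for (int t = 0; t < n; t++) {
            float acc = 0.0f;
            for (int j = 0; j < K; j++)
                acc += W[i * K + j] * X[t * K + j];
            out[t * M + i] = acc;
        }
}

static void mixmul(const float W[N_EXPERTS][M * K],
                   const float *X,   /* [N_TOKENS][K] activations */
                   const int *ids,   /* expert chosen per token   */
                   float *Y) {       /* [N_TOKENS][M] outputs     */
    for (int e = 0; e < N_EXPERTS; e++) {
        float gathered[N_TOKENS * K], out[N_TOKENS * M];
        int rows[N_TOKENS], n = 0;
        for (int t = 0; t < N_TOKENS; t++)  /* gather this expert's tokens */
            if (ids[t] == e) {
                memcpy(&gathered[n * K], &X[t * K], K * sizeof(float));
                rows[n++] = t;
            }
        if (n == 0)
            continue;
        gemm(W[e], gathered, out, n);       /* one big matmul per expert */
        for (int i = 0; i < n; i++)         /* scatter results back      */
            memcpy(&Y[rows[i] * M], &out[i * M], M * sizeof(float));
    }
}

int main(void) {
    const float W[N_EXPERTS][M * K] = {
        {1, 0, 0,  0, 1, 0},   /* expert 0 */
        {0, 0, 1,  1, 1, 1},   /* expert 1 */
    };
    const float X[N_TOKENS * K] = {1, 2, 3,  4, 5, 6,  7, 8, 9,  1, 1, 1};
    const int ids[N_TOKENS] = {0, 1, 0, 1};
    float Y[N_TOKENS * M];
    mixmul(W, X, ids, Y);
    for (int t = 0; t < N_TOKENS; t++)
        printf("token %d -> [%g, %g]\n", t, Y[t * M], Y[t * M + 1]);
    return 0;
}

Batching per expert like this is what lets the existing sgemm kernels, which previously only accelerated MUL_MAT, apply to MUL_MAT_ID as well.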
1 parent 4e96a81 commit 26ab943

File tree: 3 files changed, +457 -40 lines
ggml.c (7 additions, 1 deletion)

@@ -11003,11 +11003,14 @@ static void ggml_compute_forward_mul_mat_id(
     const struct ggml_tensor * src1 = dst->src[1];
     const struct ggml_tensor * ids = dst->src[2];
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    if (llamafile_mixmul(params, src0, src1, ids, dst))
+        return;
 
     const int ith = params->ith;
     const int nth = params->nth;
 
+    GGML_TENSOR_BINARY_OP_LOCALS
+
     const enum ggml_type type = src0->type;
 
     const bool src1_cont = ggml_is_contiguous(src1);
@@ -18504,6 +18507,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads)
                     cur = 0;
                     const struct ggml_tensor * src0 = node->src[0];
                     const struct ggml_tensor * src1 = node->src[1];
+                    const struct ggml_tensor * src2 = node->src[2];
                     const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
                     if (src1->type != vec_dot_type) {
                         cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
@@ -18512,6 +18516,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads)
                     cur += GGML_PAD(cur, sizeof(int64_t)); // align
                     cur += n_as * sizeof(int64_t); // matrix_row_counts
                     cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
+                    size_t cur2 = llamafile_mixmul_needs(src0, src1, src2);
+                    cur = cur > cur2 ? cur : cur2;
                 } break;
             case GGML_OP_OUT_PROD:
                 {
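The call sites above imply an interface roughly like the following. The exact declarations live in the commit's other files, which this excerpt does not show, so treat this as an inferred sketch rather than the authoritative prototypes.

// Hedged sketch inferred from the call sites above, not verbatim
// from the commit. Both functions would be implemented by tinyBLAS.
#include <stdbool.h>
#include <stddef.h>
#include "ggml.h"

// Returns true if tinyBLAS handled this MUL_MAT_ID node, false to
// fall through to ggml's generic implementation.
bool llamafile_mixmul(const struct ggml_compute_params * params,
                      const struct ggml_tensor * src0,  // expert weights
                      const struct ggml_tensor * src1,  // activations
                      const struct ggml_tensor * ids,   // expert routing
                      struct ggml_tensor * dst);

// Reports how many scratch bytes the mixmul path needs, so that
// ggml_graph_plan() can size the shared work buffer by taking the
// maximum of this and ggml's own requirement, as in the second hunk.
size_t llamafile_mixmul_needs(const struct ggml_tensor * src0,
                              const struct ggml_tensor * src1,
                              const struct ggml_tensor * src2);

The boolean early-return in the first hunk is the dispatch idiom: the fast path claims the node when it supports the types involved, and ggml's reference code remains as the fallback.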
