Skip to content

Commit d27b3ca

Browse files
authored
ggml : fix repack work size for mul_mat_id (#14292)
ggml-ci
1 parent 9230dbe commit d27b3ca

File tree

1 file changed

+26
-12
lines changed

1 file changed

+26
-12
lines changed

ggml/src/ggml-cpu/repack.cpp

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,13 +1163,24 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
11631163
// not really a GGML_TYPE_Q8_0 but same size.
11641164
switch (op->op) {
11651165
case GGML_OP_MUL_MAT:
1166-
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
1167-
return true;
1166+
{
1167+
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
1168+
return true;
1169+
}
11681170
case GGML_OP_MUL_MAT_ID:
1169-
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
1170-
size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
1171-
size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
1172-
return true;
1171+
{
1172+
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
1173+
size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
1174+
1175+
const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert
1176+
const int64_t ne12 = op->src[1]->ne[2]; // n_tokens
1177+
1178+
const size_t sizeof_mmid_row_mapping = sizeof(int64_t);
1179+
1180+
size += sizeof_mmid_row_mapping*ne02*(ne12 + 1);
1181+
1182+
return true;
1183+
}
11731184
default:
11741185
// GGML_ABORT("fatal error");
11751186
break;
@@ -1305,14 +1316,17 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
13051316
int32_t i2;
13061317
};
13071318

1308-
GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
1309-
n_as * ne12 * sizeof(mmid_row_mapping)));
1319+
GGML_ASSERT(params->wsize >=
1320+
(GGML_PAD(nbw3, sizeof(int64_t)) +
1321+
n_as*(ne12 + 1)*sizeof(mmid_row_mapping))
1322+
);
13101323

1311-
auto * wdata = (char *) params->wdata;
1312-
auto * wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
1313-
auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
1324+
auto * wdata = (char *)params->wdata;
1325+
auto * wdata_src1_end = (char *)wdata + GGML_PAD(nbw3, sizeof(int64_t));
13141326

1315-
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
1327+
// total of [n_as][ne12 + 1] elements of type mmid_row_mapping (2*int32_t = int64_t)
1328+
auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
1329+
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
13161330

13171331
// src1: float32 => param type
13181332
for (int64_t i12 = 0; i12 < ne12; ++i12) {

0 commit comments

Comments
 (0)