@@ -1163,13 +1163,24 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
1163
1163
// not really a GGML_TYPE_Q8_0 but same size.
1164
1164
switch (op->op ) {
1165
1165
case GGML_OP_MUL_MAT:
1166
- size = ggml_row_size (PARAM_TYPE, ggml_nelements (op->src [1 ]));
1167
- return true ;
1166
+ {
1167
+ size = ggml_row_size (PARAM_TYPE, ggml_nelements (op->src [1 ]));
1168
+ return true ;
1169
+ }
1168
1170
case GGML_OP_MUL_MAT_ID:
1169
- size = ggml_row_size (PARAM_TYPE, ggml_nelements (op->src [1 ]));
1170
- size = GGML_PAD (size, sizeof (int64_t )); // + padding for next bloc.
1171
- size += sizeof (int64_t ) * (1 +op->src [0 ]->ne [2 ]) * op->src [1 ]->ne [2 ];
1172
- return true ;
1171
+ {
1172
+ size = ggml_row_size (PARAM_TYPE, ggml_nelements (op->src [1 ]));
1173
+ size = GGML_PAD (size, sizeof (int64_t )); // + padding for next bloc.
1174
+
1175
+ const int64_t ne02 = op->src [0 ]->ne [2 ]; // n_as, n_expert
1176
+ const int64_t ne12 = op->src [1 ]->ne [2 ]; // n_tokens
1177
+
1178
+ const size_t sizeof_mmid_row_mapping = sizeof (int64_t );
1179
+
1180
+ size += sizeof_mmid_row_mapping*ne02*(ne12 + 1 );
1181
+
1182
+ return true ;
1183
+ }
1173
1184
default :
1174
1185
// GGML_ABORT("fatal error");
1175
1186
break ;
@@ -1305,14 +1316,17 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
1305
1316
int32_t i2;
1306
1317
};
1307
1318
1308
- GGML_ASSERT (params->wsize >= (GGML_PAD (nbw3, sizeof (int64_t )) + n_as * sizeof (int64_t ) +
1309
- n_as * ne12 * sizeof (mmid_row_mapping)));
1319
+ GGML_ASSERT (params->wsize >=
1320
+ (GGML_PAD (nbw3, sizeof (int64_t )) +
1321
+ n_as*(ne12 + 1 )*sizeof (mmid_row_mapping))
1322
+ );
1310
1323
1311
- auto * wdata = (char *) params->wdata ;
1312
- auto * wdata_src1_end = (char *) wdata + GGML_PAD (nbw3, sizeof (int64_t ));
1313
- auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
1324
+ auto * wdata = (char *)params->wdata ;
1325
+ auto * wdata_src1_end = (char *)wdata + GGML_PAD (nbw3, sizeof (int64_t ));
1314
1326
1315
- struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
1327
+ // total of [n_as][ne12 + 1] elements of type mmid_row_mapping (2*int32_t = int64_t)
1328
+ auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
1329
+ struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
1316
1330
1317
1331
// src1: float32 => param type
1318
1332
for (int64_t i12 = 0 ; i12 < ne12; ++i12) {
0 commit comments