From e458c308f44b01822b98b49467f992c287093ea2 Mon Sep 17 00:00:00 2001 From: nicoboss Date: Sun, 8 Jun 2025 15:41:07 +0200 Subject: [PATCH] Fix imatrix calculation for MLA models --- tools/imatrix/imatrix.cpp | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp index 3d3c66f7c3f87..0e9bce97fcbea 100644 --- a/tools/imatrix/imatrix.cpp +++ b/tools/imatrix/imatrix.cpp @@ -178,23 +178,36 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * } else { auto & e = m_stats[wname]; if (e.values.empty()) { - e.values.resize(src1->ne[0], 0); - e.counts.resize(src1->ne[0], 0); + if (src0->ne[3] > 1) { + LOG_ERR("Unsupported 4D tensor %s\n", wname.c_str()); + exit(1); + } + // If we have a 3D tensor as it is the case for the attn_k_b and attn_v_b for DeepSeek MLA models, + // than we need to compute the imatrix for each head, and not just one imatrx for all heads. + // Hence, the storage we need is src0->ne[0]*src0->ne[2]. + e.values.resize(src0->ne[0]*src0->ne[2], 0); + e.counts.resize(src0->ne[0]*src0->ne[2], 0); } - else if (e.values.size() != (size_t)src1->ne[0]) { + else if (e.values.size() != (size_t)(src0->ne[0]*src0->ne[2])) { LOG_ERR("%s: inconsistent size for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.values.size(), (int)src1->ne[0]); exit(1); //GGML_ABORT("fatal error"); } ++e.ncall; LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type); - for (int row = 0; row < (int)src1->ne[1]; ++row) { - const float * x = (const float *) (data + row * src1->nb[1]); - for (int j = 0; j < (int)src1->ne[0]; ++j) { - e.values[j] += x[j]*x[j]; - e.counts[j]++; - if (!std::isfinite(e.values[j])) { - LOG_ERR("%f detected in %s\n", e.values[j], wname.c_str()); - exit(1); + int rk2 = src1->ne[2]/src0->ne[2]; + for (int i12 = 0; i12 < (int)src1->ne[2]; ++i12) { // i.e., loop over attention heads for MLA models + int i02 = i12/rk2; + auto values = e.values.data() + i02*src0->ne[0]; + auto counts = e.counts.data() + i02*src0->ne[0]; + for (int i11 = 0; i11 < (int)src1->ne[1]; ++i11) { + const float * x = (const float *)((const char *)data + i11*src1->nb[1] + i12*src1->nb[2]); + for (int j = 0; j < (int)src1->ne[0]; ++j) { + values[j] += x[j]*x[j]; + counts[j]++; + if (!std::isfinite(values[j])) { + LOG_ERR("%f detected in %s\n", e.values[j], wname.c_str()); + exit(1); + } } } }