Skip to content

Commit a0e584d

Browse files
ggerganovslaren
andauthored
imatrix : fix wname for mul_mat_id ops (#6271)
* imatrix : fix wname for mul_mat_id ops * also filter tensor names in mul_mat_id ops --------- Co-authored-by: slaren <[email protected]>
1 parent 7aed0ff commit a0e584d

File tree

1 file changed

+21
-18
lines changed

1 file changed

+21
-18
lines changed

examples/imatrix/imatrix.cpp

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -50,29 +50,31 @@ class IMatrixCollector {
5050
void keep_imatrix(int ncall) const;
5151
};
5252

53+
// remove any prefix and suffixes from the name
54+
// CUDA0#blk.0.attn_k.weight#0 => blk.0.attn_k.weight
55+
static std::string filter_tensor_name(const char * name) {
56+
std::string wname;
57+
const char * p = strchr(name, '#');
58+
if (p != NULL) {
59+
p = p + 1;
60+
const char * q = strchr(p, '#');
61+
if (q != NULL) {
62+
wname = std::string(p, q - p);
63+
} else {
64+
wname = p;
65+
}
66+
} else {
67+
wname = name;
68+
}
69+
return wname;
70+
}
71+
5372
bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) {
5473
GGML_UNUSED(user_data);
5574

5675
const struct ggml_tensor * src0 = t->src[0];
5776
const struct ggml_tensor * src1 = t->src[1];
58-
59-
std::string wname;
60-
{
61-
// remove any prefix and suffixes from the name
62-
// CUDA0#blk.0.attn_k.weight#0 => blk.0.attn_k.weight
63-
const char * p = strchr(src0->name, '#');
64-
if (p != NULL) {
65-
p = p + 1;
66-
const char * q = strchr(p, '#');
67-
if (q != NULL) {
68-
wname = std::string(p, q - p);
69-
} else {
70-
wname = p;
71-
}
72-
} else {
73-
wname = src0->name;
74-
}
75-
}
77+
std::string wname = filter_tensor_name(src0->name);
7678

7779
// when ask is true, the scheduler wants to know if we are interested in data from this tensor
7880
// if we return true, a follow-up call will be made with ask=false in which we can do the actual collection
@@ -112,6 +114,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
112114
// this is necessary to guarantee equal number of "ncall" for each tensor
113115
for (int ex = 0; ex < n_as; ++ex) {
114116
src0 = t->src[2 + ex];
117+
wname = filter_tensor_name(src0->name);
115118
auto& e = m_stats[wname];
116119
if (e.values.empty()) {
117120
e.values.resize(src1->ne[0], 0);

0 commit comments

Comments
 (0)