Skip to content

Commit 508a181

Browse files
committed
Misc: improve conditions for vec_dot handling
1 parent 95735a6 commit 508a181

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

ggml.c

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8127,7 +8127,7 @@ static void ggml_compute_forward_mul_mat(
81278127
#endif
81288128

81298129
if (params->type == GGML_TASK_INIT) {
8130-
if (vec_dot_type != GGML_TYPE_F32) {
8130+
if (src1->type != vec_dot_type) {
81318131
char * wdata = params->wdata;
81328132
const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
81338133

@@ -8159,7 +8159,7 @@ static void ggml_compute_forward_mul_mat(
81598159
const int ir0 = dr*ith;
81608160
const int ir1 = MIN(ir0 + dr, nr);
81618161

8162-
void * wdata = (vec_dot_type == GGML_TYPE_F32) ? src1->data : params->wdata;
8162+
void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
81638163
const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
81648164

81658165
for (int ir = ir0; ir < ir1; ++ir) {
@@ -11041,6 +11041,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1104111041
//printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks);
1104211042

1104311043
size_t cur = 0;
11044+
const enum ggml_type vec_dot_type = type_handling[node->src0->type].vec_dot_type;
1104411045

1104511046
if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
1104611047
cur = 0;
@@ -11049,7 +11050,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1104911050
node->n_tasks = 1;
1105011051
}
1105111052
#endif
11052-
} else if (type_handling[node->src0->type].vec_dot && node->src1->type == GGML_TYPE_F32) {
11053+
} else if (node->src1->type != vec_dot_type) {
1105311054
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
1105411055
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
1105511056
node->n_tasks = 1;
@@ -11058,8 +11059,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1105811059
} else
1105911060
#endif
1106011061
{
11061-
const enum ggml_type type_q = type_handling[node->src0->type].vec_dot_type;
11062-
cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q];
11062+
cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
1106311063
}
1106411064
} else {
1106511065
GGML_ASSERT(false);

tests/test-quantize-perf.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ int main(int argc, char * argv[]) {
142142
break;
143143
}
144144
std::string op {argv[i]};
145-
if (op == "from_float_reference") {
145+
if (op == "quantize_row_q_reference") {
146146
params.op_quantize_row_q_reference = true;
147147
} else if (op == "quantize_row_q") {
148148
params.op_quantize_row_q = true;
@@ -229,7 +229,7 @@ int main(int argc, char * argv[]) {
229229
printf("%s\n", ggml_type_name(type));
230230

231231
if (params.op_quantize_row_q_reference) {
232-
printf(" from_float_reference\n");
232+
printf(" quantize_row_q_reference\n");
233233
for (size_t size : params.test_sizes) {
234234
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
235235
auto quantize_fn = [&](void ) {

0 commit comments

Comments
 (0)