Skip to content

Commit 8c570c9

Browse files
Minor arithmetic improvement to mmvq wrapper kernel (#7172)
1 parent eaf4bd8 commit 8c570c9

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

ggml-sycl.cpp

Lines changed: 11 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -8330,24 +8330,26 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
83308330
const int blocks_per_row = ncols / qk;
83318331
const int blocks_per_warp = vdr * WARP_SIZE / qi;
83328332

8333-
// partial sum for each thread
8333+
const int qi_vdr = (qi / vdr); // N_threads processing 1 qk block
8334+
8335+
// partial sum for each thread
83348336
float tmp = 0.0f;
83358337

83368338
const block_q_t * x = (const block_q_t *) vx;
83378339
const block_q8_1 * y = (const block_q8_1 *) vy;
83388340

8339-
for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
8341+
for (int i = item_ct1.get_local_id(2) / qi_vdr; i < blocks_per_row;
83408342
i += blocks_per_warp) {
8341-
const int ibx = row*blocks_per_row + i; // x block index
8343+
const int ibx = row * blocks_per_row + i; // x block index
83428344

8343-
const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
8345+
const int iby = i * (qk / QK8_1); // y block index that aligns with ibx
83448346

8345-
const int iqs =
8346-
vdr *
8347-
(item_ct1.get_local_id(2) %
8348-
(qi / vdr)); // x block quant index when casting the quants to int
8347+
const int iqs =
8348+
vdr *
8349+
(item_ct1.get_local_id(2) -
8350+
i * qi_vdr); // x block quant index when casting the quants to int
83498351

8350-
tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
8352+
tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
83518353
}
83528354

83538355
// sum up partial sums and write back result

0 commit comments

Comments
 (0)