Commit 5910ea9

[SYCL] Fix DMMV dequantization (#9279)
Fixed dmmv dequant for ncols == GGML_SYCL_DMMV_X
1 parent c8671ae commit 5910ea9

1 file changed: +2 -2 lines changed

ggml/src/ggml-sycl/dmmv.cpp

Lines changed: 2 additions & 2 deletions
@@ -76,8 +76,8 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
     }
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
+    for (int mask = mask_start; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
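
The effect of changing the starting mask can be seen in a small host-side model of the XOR butterfly loop in this hunk. This is only a sketch in plain C++, not SYCL code. The names `butterfly_reduce` and `pick_mask_start` are hypothetical, and the 8-lane group is an assumption chosen for readability. It shows that starting at half the lane count leaves the full total in every lane, while starting at a quarter sums each half of the group independently.

```cpp
#include <cstddef>
#include <vector>

// Host-side model of the reduction loop: at each step, lane i adds the value
// held by lane (i ^ mask). Assumes lanes.size() is a power of two and
// mask_start < lanes.size().
std::vector<float> butterfly_reduce(std::vector<float> lanes, int mask_start) {
    for (int mask = mask_start; mask > 0; mask >>= 1) {
        std::vector<float> next(lanes.size());
        for (std::size_t i = 0; i < lanes.size(); ++i) {
            next[i] = lanes[i] + lanes[i ^ static_cast<std::size_t>(mask)];
        }
        lanes = next;
    }
    return lanes;
}

// Hypothetical helper mirroring the selection added by the commit:
// half the warp size when ncols exceeds the DMMV_X constant, a quarter otherwise.
int pick_mask_start(int ncols, int dmmv_x, int warp_size) {
    return ncols > dmmv_x ? warp_size >> 1 : warp_size >> 2;
}
```

With lanes holding 1 through 8, `butterfly_reduce(lanes, 4)` leaves 36 in every lane, whereas `butterfly_reduce(lanes, 2)` leaves 10 in lanes 0 to 3 and 26 in lanes 4 to 7. The model does not explain why the smaller group is the right choice when `ncols` equals `GGML_SYCL_DMMV_X`. It only illustrates what the changed starting mask does to the reduction width.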
