Skip to content

Commit a6c33f9

Browse files
committed
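Remove the wrong in-kernel assert from the norm kernels (the work-group size is now checked on the host side instead)
Workaround (WA) for mul_mat with permute(0,1,3,2) inputs: report MUL_MAT as unsupported when either source tensor is permuted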
1 parent 958367b commit a6c33f9

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

ggml/src/ggml-sycl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5173,6 +5173,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
51735173
if (op->op == GGML_OP_MUL_MAT) {
51745174
a = op->src[0];
51755175
b = op->src[1];
5176+
if (ggml_is_permuted(a) || ggml_is_permuted(b)) {
5177+
// TODO: fix like https://github.com/ggerganov/llama.cpp/pull/10021
5178+
return false;
5179+
}
51765180
} else {
51775181
a = op->src[2];
51785182
b = op->src[1];

ggml/src/ggml-sycl/norm.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep
88

99
const int nthreads = item_ct1.get_local_range(2);
1010
const int nwarps = nthreads / WARP_SIZE;
11-
assert(nwarps % WARP_SIZE == 0);
1211
sycl::float2 mean_var = sycl::float2(0.f, 0.f);
1312

1413
for (int col = tid; col < ncols; col += block_size) {
@@ -55,7 +54,6 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
5554
int end = start + group_size;
5655
const int nthreads = item_ct1.get_local_range(2);
5756
const int nwarps = nthreads / WARP_SIZE;
58-
assert(nwarps % WARP_SIZE == 0);
5957
start += item_ct1.get_local_id(2);
6058
int nreduce = nwarps / WARP_SIZE;
6159

@@ -144,7 +142,6 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
144142
const int tid = item_ct1.get_local_id(2);
145143
const int nthreads = item_ct1.get_local_range(2);
146144
const int nwarps = nthreads / WARP_SIZE;
147-
assert(nwarps % WARP_SIZE == 0);
148145
float tmp = 0.0f; // partial sum for thread in warp
149146

150147
for (int col = tid; col < ncols; col += block_size) {
@@ -202,6 +199,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
202199
}
203200
else {
204201
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
202+
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
205203
const sycl::range<3> block_dims(1, 1, work_group_size);
206204
/*
207205
DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
@@ -244,6 +242,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
244242
}
245243
else {
246244
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
245+
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
247246
const sycl::range<3> block_dims(1, 1, work_group_size);
248247
/*
249248
DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
@@ -290,6 +289,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
290289
}
291290
else {
292291
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
292+
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
293293
const sycl::range<3> block_dims(1, 1, work_group_size);
294294
/*
295295
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed

0 commit comments

Comments
 (0)