@@ -8,7 +8,6 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep
8
8
9
9
const int nthreads = item_ct1.get_local_range (2 );
10
10
const int nwarps = nthreads / WARP_SIZE;
11
- assert (nwarps % WARP_SIZE == 0 );
12
11
sycl::float2 mean_var = sycl::float2 (0 .f , 0 .f );
13
12
14
13
for (int col = tid; col < ncols; col += block_size) {
@@ -55,7 +54,6 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
55
54
int end = start + group_size;
56
55
const int nthreads = item_ct1.get_local_range (2 );
57
56
const int nwarps = nthreads / WARP_SIZE;
58
- assert (nwarps % WARP_SIZE == 0 );
59
57
start += item_ct1.get_local_id (2 );
60
58
int nreduce = nwarps / WARP_SIZE;
61
59
@@ -144,7 +142,6 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
144
142
const int tid = item_ct1.get_local_id (2 );
145
143
const int nthreads = item_ct1.get_local_range (2 );
146
144
const int nwarps = nthreads / WARP_SIZE;
147
- assert (nwarps % WARP_SIZE == 0 );
148
145
float tmp = 0 .0f ; // partial sum for thread in warp
149
146
150
147
for (int col = tid; col < ncols; col += block_size) {
@@ -202,6 +199,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
202
199
}
203
200
else {
204
201
const int work_group_size = ggml_sycl_info ().max_work_group_sizes [device];
202
+ assert (work_group_size % (WARP_SIZE * WARP_SIZE) == 0 );
205
203
const sycl::range<3 > block_dims (1 , 1 , work_group_size);
206
204
/*
207
205
DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
@@ -244,6 +242,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
244
242
}
245
243
else {
246
244
const int work_group_size = ggml_sycl_info ().max_work_group_sizes [device];
245
+ assert (work_group_size % (WARP_SIZE * WARP_SIZE) == 0 );
247
246
const sycl::range<3 > block_dims (1 , 1 , work_group_size);
248
247
/*
249
248
DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
@@ -290,6 +289,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
290
289
}
291
290
else {
292
291
const int work_group_size = ggml_sycl_info ().max_work_group_sizes [device];
292
+ assert (work_group_size % (WARP_SIZE * WARP_SIZE) == 0 );
293
293
const sycl::range<3 > block_dims (1 , 1 , work_group_size);
294
294
/*
295
295
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
0 commit comments