@@ -175,9 +175,11 @@ def test_search_reduction_kernels(arg_dtype):
175
175
q = get_queue_or_skip ()
176
176
skip_if_dtype_not_supported (arg_dtype , q )
177
177
178
- x = dpt .ones ((24 * 1025 ), dtype = arg_dtype , sycl_queue = q )
178
+ x_shape = (24 , 1024 )
179
+ x_size = np .prod (x_shape )
180
+ x = dpt .ones (x_size , dtype = arg_dtype , sycl_queue = q )
179
181
idx = randrange (x .size )
180
- idx_tup = np .unravel_index (idx , ( 24 , 1025 ) )
182
+ idx_tup = np .unravel_index (idx , x_shape )
181
183
x [idx ] = 2
182
184
183
185
m = dpt .argmax (x )
@@ -194,7 +196,7 @@ def test_search_reduction_kernels(arg_dtype):
194
196
m = dpt .argmax (y )
195
197
assert m == 2 * idx
196
198
197
- x = dpt .reshape (x , ( 24 , 1025 ) )
199
+ x = dpt .reshape (x , x_shape )
198
200
199
201
x [idx_tup [0 ], :] = 3
200
202
m = dpt .argmax (x , axis = 0 )
@@ -209,15 +211,15 @@ def test_search_reduction_kernels(arg_dtype):
209
211
m = dpt .argmax (x , axis = 1 )
210
212
assert dpt .all (m == idx )
211
213
212
- x = dpt .ones (( 24 * 1025 ) , dtype = arg_dtype , sycl_queue = q )
214
+ x = dpt .ones (x_size , dtype = arg_dtype , sycl_queue = q )
213
215
idx = randrange (x .size )
214
- idx_tup = np .unravel_index (idx , ( 24 , 1025 ) )
216
+ idx_tup = np .unravel_index (idx , x_shape )
215
217
x [idx ] = 0
216
218
217
219
m = dpt .argmin (x )
218
220
assert m == idx
219
221
220
- x = dpt .reshape (x , ( 24 , 1025 ) )
222
+ x = dpt .reshape (x , x_shape )
221
223
222
224
x [idx_tup [0 ], :] = - 1
223
225
m = dpt .argmin (x , axis = 0 )
0 commit comments