@@ -175,9 +175,11 @@ def test_search_reduction_kernels(arg_dtype):
175175 q = get_queue_or_skip ()
176176 skip_if_dtype_not_supported (arg_dtype , q )
177177
178- x = dpt .ones ((24 * 1025 ), dtype = arg_dtype , sycl_queue = q )
178+ x_shape = (24 , 1024 )
179+ x_size = np .prod (x_shape )
180+ x = dpt .ones (x_size , dtype = arg_dtype , sycl_queue = q )
179181 idx = randrange (x .size )
180- idx_tup = np .unravel_index (idx , ( 24 , 1025 ) )
182+ idx_tup = np .unravel_index (idx , x_shape )
181183 x [idx ] = 2
182184
183185 m = dpt .argmax (x )
@@ -194,7 +196,7 @@ def test_search_reduction_kernels(arg_dtype):
194196 m = dpt .argmax (y )
195197 assert m == 2 * idx
196198
197- x = dpt .reshape (x , ( 24 , 1025 ) )
199+ x = dpt .reshape (x , x_shape )
198200
199201 x [idx_tup [0 ], :] = 3
200202 m = dpt .argmax (x , axis = 0 )
@@ -209,15 +211,15 @@ def test_search_reduction_kernels(arg_dtype):
209211 m = dpt .argmax (x , axis = 1 )
210212 assert dpt .all (m == idx )
211213
212- x = dpt .ones (( 24 * 1025 ) , dtype = arg_dtype , sycl_queue = q )
214+ x = dpt .ones (x_size , dtype = arg_dtype , sycl_queue = q )
213215 idx = randrange (x .size )
214- idx_tup = np .unravel_index (idx , ( 24 , 1025 ) )
216+ idx_tup = np .unravel_index (idx , x_shape )
215217 x [idx ] = 0
216218
217219 m = dpt .argmin (x )
218220 assert m == idx
219221
220- x = dpt .reshape (x , ( 24 , 1025 ) )
222+ x = dpt .reshape (x , x_shape )
221223
222224 x [idx_tup [0 ], :] = - 1
223225 m = dpt .argmin (x , axis = 0 )
0 commit comments