Skip to content

Commit aadb6b4

Browse files
Eliminated multiple uses of same literal constants in test_search_reduction_kernels
1 parent 2bc7939 commit aadb6b4

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

dpctl/tests/test_usm_ndarray_reductions.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,11 @@ def test_search_reduction_kernels(arg_dtype):
175175
q = get_queue_or_skip()
176176
skip_if_dtype_not_supported(arg_dtype, q)
177177

178-
x = dpt.ones((24 * 1025), dtype=arg_dtype, sycl_queue=q)
178+
x_shape = (24, 1024)
179+
x_size = np.prod(x_shape)
180+
x = dpt.ones(x_size, dtype=arg_dtype, sycl_queue=q)
179181
idx = randrange(x.size)
180-
idx_tup = np.unravel_index(idx, (24, 1025))
182+
idx_tup = np.unravel_index(idx, x_shape)
181183
x[idx] = 2
182184

183185
m = dpt.argmax(x)
@@ -194,7 +196,7 @@ def test_search_reduction_kernels(arg_dtype):
194196
m = dpt.argmax(y)
195197
assert m == 2 * idx
196198

197-
x = dpt.reshape(x, (24, 1025))
199+
x = dpt.reshape(x, x_shape)
198200

199201
x[idx_tup[0], :] = 3
200202
m = dpt.argmax(x, axis=0)
@@ -209,15 +211,15 @@ def test_search_reduction_kernels(arg_dtype):
209211
m = dpt.argmax(x, axis=1)
210212
assert dpt.all(m == idx)
211213

212-
x = dpt.ones((24 * 1025), dtype=arg_dtype, sycl_queue=q)
214+
x = dpt.ones(x_size, dtype=arg_dtype, sycl_queue=q)
213215
idx = randrange(x.size)
214-
idx_tup = np.unravel_index(idx, (24, 1025))
216+
idx_tup = np.unravel_index(idx, x_shape)
215217
x[idx] = 0
216218

217219
m = dpt.argmin(x)
218220
assert m == idx
219221

220-
x = dpt.reshape(x, (24, 1025))
222+
x = dpt.reshape(x, x_shape)
221223

222224
x[idx_tup[0], :] = -1
223225
m = dpt.argmin(x, axis=0)

0 commit comments

Comments
 (0)