Skip to content

Commit f686102

Browse files
Merge pull request #1467 from IntelPython/fix-usm-ndarray-ctor-when-shape-is-integral-numpy-scalar
Fix usm_ndarray ctor when shape is integral numpy scalar
2 parents dbab3fe + da59476 commit f686102

File tree

3 files changed

+23
-13
lines changed

3 files changed

+23
-13
lines changed

dpctl/tensor/_usmarray.pyx

+14-7
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,20 @@ cdef class usm_ndarray:
182182
cdef bint is_fp16 = False
183183

184184
self._reset()
185-
if (not isinstance(shape, (list, tuple))
186-
and not hasattr(shape, 'tolist')):
187-
try:
188-
<Py_ssize_t> shape
189-
shape = [shape, ]
190-
except Exception:
191-
raise TypeError("Argument shape must be a list or a tuple.")
185+
if not isinstance(shape, (list, tuple)):
186+
if hasattr(shape, 'tolist'):
187+
fn = getattr(shape, 'tolist')
188+
if callable(fn):
189+
shape = shape.tolist()
190+
if not isinstance(shape, (list, tuple)):
191+
try:
192+
<Py_ssize_t> shape
193+
shape = [shape, ]
194+
except Exception as e:
195+
raise TypeError(
196+
"Argument shape must a non-negative integer, "
197+
"or a list/tuple of such integers."
198+
) from e
192199
nd = len(shape)
193200
if dtype is None:
194201
if isinstance(buffer, (dpmem._memory._Memory, usm_ndarray)):

dpctl/tests/test_usm_ndarray_ctor.py

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
(2, 5, 2),
4040
(2, 2, 2, 2, 2, 2, 2, 2),
4141
5,
42+
np.int32(7),
4243
],
4344
)
4445
@pytest.mark.parametrize("usm_type", ["shared", "host", "device"])

dpctl/tests/test_usm_ndarray_reductions.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,11 @@ def test_search_reduction_kernels(arg_dtype):
175175
q = get_queue_or_skip()
176176
skip_if_dtype_not_supported(arg_dtype, q)
177177

178-
x = dpt.ones((24 * 1025), dtype=arg_dtype, sycl_queue=q)
178+
x_shape = (24, 1024)
179+
x_size = np.prod(x_shape)
180+
x = dpt.ones(x_size, dtype=arg_dtype, sycl_queue=q)
179181
idx = randrange(x.size)
180-
idx_tup = np.unravel_index(idx, (24, 1025))
182+
idx_tup = np.unravel_index(idx, x_shape)
181183
x[idx] = 2
182184

183185
m = dpt.argmax(x)
@@ -194,7 +196,7 @@ def test_search_reduction_kernels(arg_dtype):
194196
m = dpt.argmax(y)
195197
assert m == 2 * idx
196198

197-
x = dpt.reshape(x, (24, 1025))
199+
x = dpt.reshape(x, x_shape)
198200

199201
x[idx_tup[0], :] = 3
200202
m = dpt.argmax(x, axis=0)
@@ -209,15 +211,15 @@ def test_search_reduction_kernels(arg_dtype):
209211
m = dpt.argmax(x, axis=1)
210212
assert dpt.all(m == idx)
211213

212-
x = dpt.ones((24 * 1025), dtype=arg_dtype, sycl_queue=q)
214+
x = dpt.ones(x_size, dtype=arg_dtype, sycl_queue=q)
213215
idx = randrange(x.size)
214-
idx_tup = np.unravel_index(idx, (24, 1025))
216+
idx_tup = np.unravel_index(idx, x_shape)
215217
x[idx] = 0
216218

217219
m = dpt.argmin(x)
218220
assert m == idx
219221

220-
x = dpt.reshape(x, (24, 1025))
222+
x = dpt.reshape(x, x_shape)
221223

222224
x[idx_tup[0], :] = -1
223225
m = dpt.argmin(x, axis=0)

0 commit comments

Comments
 (0)