Skip to content

Commit 53315a9

Browse files
Fixed bug introduced in 5c1a961
This commit short-circuited broadcastability of shapes for zero-size rhs in __setitem__. Refix array API test failure without introducing this regression and reuse _manipulation_functions._broadcast_strides as suggested by @ndgrigorian
1 parent ef5fe17 commit 53315a9

File tree

3 files changed

+21
-23
lines changed

3 files changed

+21
-23
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,23 @@ def _broadcast_shapes(sh1, sh2):
246246
).shape
247247

248248

249+
def _broadcast_strides(X_shape, X_strides, res_ndim):
250+
"""
251+
Broadcasts strides to match the given dimensions;
252+
returns tuple type strides.
253+
"""
254+
out_strides = [0] * res_ndim
255+
X_shape_len = len(X_shape)
256+
str_dim = -X_shape_len
257+
for i in range(X_shape_len):
258+
shape_value = X_shape[i]
259+
if not shape_value == 1:
260+
out_strides[str_dim] = X_strides[i]
261+
str_dim += 1
262+
263+
return tuple(out_strides)
264+
265+
249266
def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
250267
if any(
251268
not isinstance(arg, dpt.usm_ndarray)
@@ -268,7 +285,7 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
268285
except ValueError as exc:
269286
raise ValueError("Shapes of two arrays are not compatible") from exc
270287

271-
if dst.size < src.size:
288+
if dst.size < src.size and dst.size < np.prod(common_shape):
272289
raise ValueError("Destination is smaller ")
273290

274291
if len(common_shape) > dst.ndim:
@@ -279,9 +296,8 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
279296
common_shape = common_shape[ones_count:]
280297

281298
if src.ndim < len(common_shape):
282-
pad_count = len(common_shape) - src.ndim
283-
new_src_strides = (0,) * pad_count + tuple(
284-
s if d > 1 else 0 for s, d in zip(src.strides, src.shape)
299+
new_src_strides = _broadcast_strides(
300+
src.shape, src.strides, len(common_shape)
285301
)
286302
src_same_shape = dpt.usm_ndarray(
287303
common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides

dpctl/tensor/_manipulation_functions.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import dpctl.tensor._tensor_impl as ti
2626
import dpctl.utils as dputils
2727

28+
from ._copy_utils import _broadcast_strides
2829
from ._type_utils import _to_device_supported_dtype
2930

3031
__doc__ = (
@@ -120,23 +121,6 @@ def __repr__(self):
120121
return self._finfo.__repr__()
121122

122123

123-
def _broadcast_strides(X_shape, X_strides, res_ndim):
124-
"""
125-
Broadcasts strides to match the given dimensions;
126-
returns tuple type strides.
127-
"""
128-
out_strides = [0] * res_ndim
129-
X_shape_len = len(X_shape)
130-
str_dim = -X_shape_len
131-
for i in range(X_shape_len):
132-
shape_value = X_shape[i]
133-
if not shape_value == 1:
134-
out_strides[str_dim] = X_strides[i]
135-
str_dim += 1
136-
137-
return tuple(out_strides)
138-
139-
140124
def _broadcast_shape_impl(shapes):
141125
if len(set(shapes)) == 1:
142126
return shapes[0]

dpctl/tensor/_usmarray.pyx

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,8 +1175,6 @@ cdef class usm_ndarray:
11751175
if adv_ind_start_p < 0:
11761176
# basic slicing
11771177
if isinstance(rhs, usm_ndarray):
1178-
if Xv.size == 0:
1179-
return
11801178
_copy_from_usm_ndarray_to_usm_ndarray(Xv, rhs)
11811179
else:
11821180
if hasattr(rhs, "__sycl_usm_array_interface__"):

0 commit comments

Comments
 (0)