diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index 81928692a6..ae261c50c1 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -300,14 +300,22 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src): src.shape, src.strides, len(common_shape) ) src_same_shape = dpt.usm_ndarray( - common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides + common_shape, + dtype=src.dtype, + buffer=src, + strides=new_src_strides, + offset=src._element_offset, ) elif src.ndim == len(common_shape): new_src_strides = _broadcast_strides( src.shape, src.strides, len(common_shape) ) src_same_shape = dpt.usm_ndarray( - common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides + common_shape, + dtype=src.dtype, + buffer=src, + strides=new_src_strides, + offset=src._element_offset, ) else: # since broadcasting succeeded, src.ndim is greater because of diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index 7227e687af..5c0707e619 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -1035,6 +1035,7 @@ def test_setitem_same_dtype(dtype, src_usm_type, dst_usm_type): def test_setitem_broadcasting(): + "See gh-1503" get_queue_or_skip() dst = dpt.ones((2, 3, 4), dtype="u4") src = dpt.zeros((3, 1), dtype=dst.dtype) @@ -1043,6 +1044,16 @@ def test_setitem_broadcasting(): assert np.array_equal(dpt.asnumpy(dst), expected) +def test_setitem_broadcasting_offset(): + get_queue_or_skip() + dt = dpt.int32 + x = dpt.asarray([[1, 2, 3], [6, 7, 8]], dtype=dt) + y = dpt.asarray([4, 5], dtype=dt) + x[0] = y[1] + expected = dpt.asarray([[5, 5, 5], [6, 7, 8]], dtype=dt) + assert dpt.all(x == expected) + + def test_setitem_broadcasting_empty_dst_validation(): "Broadcasting rules apply, except exception" get_queue_or_skip()