diff --git a/dpctl/tensor/libtensor/source/repeat.cpp b/dpctl/tensor/libtensor/source/repeat.cpp index 391f995feb..f3a20cbbaa 100644 --- a/dpctl/tensor/libtensor/source/repeat.cpp +++ b/dpctl/tensor/libtensor/source/repeat.cpp @@ -136,7 +136,6 @@ py_repeat_by_sequence(const dpctl::tensor::usm_ndarray &src, const py::ssize_t *dst_shape = dst.get_shape_raw(); bool same_orthog_dims(true); size_t orthog_nelems(1); // number of orthogonal iterations - for (auto i = 0; i < axis; ++i) { auto src_sh_i = src_shape[i]; orthog_nelems *= src_sh_i; @@ -554,7 +553,6 @@ py_repeat_by_scalar(const dpctl::tensor::usm_ndarray &src, const py::ssize_t *dst_shape = dst.get_shape_raw(); bool same_orthog_dims(true); size_t orthog_nelems(1); // number of orthogonal iterations - for (auto i = 0; i < axis; ++i) { auto src_sh_i = src_shape[i]; orthog_nelems *= src_sh_i; @@ -634,7 +632,7 @@ py_repeat_by_scalar(const dpctl::tensor::usm_ndarray &src, assert(dst_shape_vec.size() == 1); assert(dst_strides_vec.size() == 1); - if (src_nd > 0) { + if (src_nd == 0) { src_shape_vec = {0}; src_strides_vec = {0}; } diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index ae32afdba9..f3704274d4 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -1170,6 +1170,12 @@ def test_repeat_axes(): res = dpt.repeat(x, reps, axis=1) assert dpt.all(res == expected_res) + x = dpt.arange(10, dtype="i4") + expected_res = dpt.empty(x.shape[0] * reps, x.dtype) + expected_res[::2], expected_res[1::2] = x, x + res = dpt.repeat(x, reps, axis=0) + assert dpt.all(res == expected_res) + def test_repeat_size_0_outputs(): get_queue_or_skip()