Skip to content

Commit 3516681

Browse files
committed
Fixes gh-1432
Caused by a typo in the Python binding changes made in #1427 Added a test for correct behavior
1 parent ebf118a commit 3516681

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

dpctl/tensor/libtensor/source/repeat.cpp

+1-3
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,6 @@ py_repeat_by_sequence(const dpctl::tensor::usm_ndarray &src,
136136
const py::ssize_t *dst_shape = dst.get_shape_raw();
137137
bool same_orthog_dims(true);
138138
size_t orthog_nelems(1); // number of orthogonal iterations
139-
140139
for (auto i = 0; i < axis; ++i) {
141140
auto src_sh_i = src_shape[i];
142141
orthog_nelems *= src_sh_i;
@@ -554,7 +553,6 @@ py_repeat_by_scalar(const dpctl::tensor::usm_ndarray &src,
554553
const py::ssize_t *dst_shape = dst.get_shape_raw();
555554
bool same_orthog_dims(true);
556555
size_t orthog_nelems(1); // number of orthogonal iterations
557-
558556
for (auto i = 0; i < axis; ++i) {
559557
auto src_sh_i = src_shape[i];
560558
orthog_nelems *= src_sh_i;
@@ -634,7 +632,7 @@ py_repeat_by_scalar(const dpctl::tensor::usm_ndarray &src,
634632
assert(dst_shape_vec.size() == 1);
635633
assert(dst_strides_vec.size() == 1);
636634

637-
if (src_nd > 0) {
635+
if (src_nd == 0) {
638636
src_shape_vec = {0};
639637
src_strides_vec = {0};
640638
}

dpctl/tests/test_usm_ndarray_manipulation.py

+6
Original file line numberDiff line numberDiff line change
@@ -1170,6 +1170,12 @@ def test_repeat_axes():
11701170
res = dpt.repeat(x, reps, axis=1)
11711171
assert dpt.all(res == expected_res)
11721172

1173+
x = dpt.arange(10, dtype="i4")
1174+
expected_res = dpt.empty(x.shape[0] * reps, x.dtype)
1175+
expected_res[::2], expected_res[1::2] = x, x
1176+
res = dpt.repeat(x, reps, axis=0)
1177+
assert dpt.all(res == expected_res)
1178+
11731179

11741180
def test_repeat_size_0_outputs():
11751181
get_queue_or_skip()

0 commit comments

Comments
 (0)