Skip to content
Merged
1 change: 1 addition & 0 deletions dpctl/tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pybind11_add_module(${python_module_name} MODULE
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_and_cast_usm_to_usm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_reshape.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_roll.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/integer_advanced_indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_advanced_indexing.cpp
Expand Down
36 changes: 13 additions & 23 deletions dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
# limitations under the License.


from itertools import chain, product, repeat
import operator
from itertools import chain, repeat

import numpy as np
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
Expand Down Expand Up @@ -426,10 +427,11 @@ def roll(X, shift, axis=None):
if not isinstance(X, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
if axis is None:
shift = operator.index(shift)
res = dpt.empty(
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue
)
hev, _ = ti._copy_usm_ndarray_for_reshape(
hev, _ = ti._copy_usm_ndarray_for_roll_1d(
src=X, dst=res, shift=shift, sycl_queue=X.sycl_queue
)
hev.wait()
Expand All @@ -438,31 +440,20 @@ def roll(X, shift, axis=None):
broadcasted = np.broadcast(shift, axis)
if broadcasted.ndim > 1:
raise ValueError("'shift' and 'axis' should be scalars or 1D sequences")
shifts = {ax: 0 for ax in range(X.ndim)}
shifts = [
0,
] * X.ndim
for sh, ax in broadcasted:
shifts[ax] += sh
rolls = [((np.s_[:], np.s_[:]),)] * X.ndim
for ax, offset in shifts.items():
offset %= X.shape[ax] or 1
if offset:
# (original, result), (original, result)
rolls[ax] = (
(np.s_[:-offset], np.s_[offset:]),
(np.s_[-offset:], np.s_[:offset]),
)

exec_q = X.sycl_queue
res = dpt.empty(
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=exec_q
)
hev_list = []
for indices in product(*rolls):
arr_index, res_index = zip(*indices)
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
src=X[arr_index], dst=res[res_index], sycl_queue=X.sycl_queue
)
hev_list.append(hev)

dpctl.SyclEvent.wait_for(hev_list)
ht_e, _ = ti._copy_usm_ndarray_for_roll_nd(
src=X, dst=res, shifts=shifts, sycl_queue=exec_q
)
ht_e.wait()
return res


Expand Down Expand Up @@ -550,7 +541,6 @@ def _concat_axis_None(arrays):
hev, _ = ti._copy_usm_ndarray_for_reshape(
src=src_,
dst=res[fill_start:fill_end],
shift=0,
sycl_queue=exec_q,
)
fill_start = fill_end
Expand Down
2 changes: 1 addition & 1 deletion dpctl/tensor/_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def reshape(X, shape, order="C", copy=None):
)
if order == "C":
hev, _ = _copy_usm_ndarray_for_reshape(
src=X, dst=flat_res, shift=0, sycl_queue=X.sycl_queue
src=X, dst=flat_res, sycl_queue=X.sycl_queue
)
hev.wait()
else:
Expand Down
Loading