Skip to content

repeat with axis=None repeats flattened array #1427

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 48 additions & 34 deletions dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import operator

import numpy as np
from numpy import AxisError
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple

import dpctl
Expand Down Expand Up @@ -929,20 +928,26 @@ def repeat(x, repeats, axis=None):
Args:
x (usm_ndarray): input array

repeat (Union[int, Tuple[int, ...]]):
repeats (Union[int, Sequence[int, ...], usm_ndarray]):
The number of repetitions for each element.
`repeats` is broadcasted to fit the shape of the given axis.
`repeats` is broadcast to fit the shape of the given axis.
If `repeats` is an array, it must have an integer data type.
Otherwise, `repeats` must be a Python integer, tuple, list, or
range.

axis (Optional[int]):
The axis along which to repeat values. The `axis` is required
if input array has more than one dimension.
The axis along which to repeat values. If `axis` is `None`, the
function repeats elements of the flattened array.
Default: `None`.

Returns:
usm_narray:
Array with repeated elements.
The returned array must have the same data type as `x`,
is created on the same device as `x` and has the same USM
allocation type as `x`.
The returned array must have the same data type as `x`, is created
on the same device as `x` and has the same USM allocation type as
`x`. If `axis` is `None`, the returned array is one-dimensional,
otherwise, it has the same shape as `x`, except for the axis along
which elements were repeated.

Raises:
AxisError: if `axis` value is invalid.
Expand All @@ -951,20 +956,11 @@ def repeat(x, repeats, axis=None):
raise TypeError(f"Expected usm_ndarray type, got {type(x)}.")

x_ndim = x.ndim
if axis is None:
if x_ndim > 1:
raise ValueError(
f"`axis` cannot be `None` for array of dimension {x_ndim}"
)
axis = 0

x_shape = x.shape
if x_ndim > 0:
if axis is not None:
axis = normalize_axis_index(operator.index(axis), x_ndim)
axis_size = x_shape[axis]
else:
if axis != 0:
AxisError("`axis` must be `0` for input of dimension `0`")
axis_size = x.size

scalar = False
Expand All @@ -977,8 +973,8 @@ def repeat(x, repeats, axis=None):
elif isinstance(repeats, dpt.usm_ndarray):
if repeats.ndim > 1:
raise ValueError(
"`repeats` array must be 0- or 1-dimensional, got"
"{repeats.ndim}"
"`repeats` array must be 0- or 1-dimensional, got "
f"{repeats.ndim}"
)
exec_q = dpctl.utils.get_execution_queue(
(x.sycl_queue, repeats.sycl_queue)
Expand Down Expand Up @@ -1015,30 +1011,30 @@ def repeat(x, repeats, axis=None):
if not dpt.all(repeats >= 0):
raise ValueError("`repeats` elements must be positive")

elif isinstance(repeats, tuple):
elif isinstance(repeats, (tuple, list, range)):
usm_type = x.usm_type
exec_q = x.sycl_queue

len_reps = len(repeats)
if len_reps != axis_size:
raise ValueError(
"`repeats` tuple must have the same length as the repeated "
"axis"
)
elif len_reps == 1:
if len_reps == 1:
repeats = repeats[0]
if repeats < 0:
raise ValueError("`repeats` elements must be positive")
scalar = True
else:
if len_reps != axis_size:
raise ValueError(
"`repeats` sequence must have the same length as the "
"repeated axis"
)
repeats = dpt.asarray(
repeats, dtype=dpt.int64, usm_type=usm_type, sycl_queue=exec_q
)
if not dpt.all(repeats >= 0):
raise ValueError("`repeats` elements must be positive")
else:
raise TypeError(
"Expected int, tuple, or `usm_ndarray` for second argument,"
"Expected int, sequence, or `usm_ndarray` for second argument,"
f"got {type(repeats)}"
)

Expand All @@ -1047,7 +1043,10 @@ def repeat(x, repeats, axis=None):

if scalar:
res_axis_size = repeats * axis_size
res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
if axis is not None:
res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
else:
res_shape = (res_axis_size,)
res = dpt.empty(
res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
)
Expand Down Expand Up @@ -1081,9 +1080,17 @@ def repeat(x, repeats, axis=None):
res_axis_size = ti._cumsum_1d(
rep_buf, cumsum, sycl_queue=exec_q, depends=[copy_ev]
)
res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
if axis is not None:
res_shape = (
x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
)
else:
res_shape = (res_axis_size,)
res = dpt.empty(
res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
res_shape,
dtype=x.dtype,
usm_type=usm_type,
sycl_queue=exec_q,
)
if res_axis_size > 0:
ht_rep_ev, _ = ti._repeat_by_sequence(
Expand All @@ -1103,11 +1110,18 @@ def repeat(x, repeats, axis=None):
usm_type=usm_type,
sycl_queue=exec_q,
)
# _cumsum_1d synchronizes so `depends` ends here safely
res_axis_size = ti._cumsum_1d(repeats, cumsum, sycl_queue=exec_q)
res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
if axis is not None:
res_shape = (
x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
)
else:
res_shape = (res_axis_size,)
res = dpt.empty(
res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
res_shape,
dtype=x.dtype,
usm_type=usm_type,
sycl_queue=exec_q,
)
if res_axis_size > 0:
ht_rep_ev, _ = ti._repeat_by_sequence(
Expand Down
88 changes: 49 additions & 39 deletions dpctl/tensor/libtensor/include/kernels/repeat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,16 @@ namespace py = pybind11;
using namespace dpctl::tensor::offset_utils;

template <typename OrthogIndexer,
typename AxisIndexer,
typename SrcAxisIndexer,
typename DstAxisIndexer,
typename RepIndexer,
typename T,
typename repT>
class repeat_by_sequence_kernel;

template <typename OrthogIndexer,
typename AxisIndexer,
typename SrcAxisIndexer,
typename DstAxisIndexer,
typename RepIndexer,
typename T,
typename repT>
Expand All @@ -66,8 +68,8 @@ class RepeatSequenceFunctor
const repT *cumsum = nullptr;
size_t src_axis_nelems = 1;
OrthogIndexer orthog_strider;
AxisIndexer src_axis_strider;
AxisIndexer dst_axis_strider;
SrcAxisIndexer src_axis_strider;
DstAxisIndexer dst_axis_strider;
RepIndexer reps_strider;

public:
Expand All @@ -77,8 +79,8 @@ class RepeatSequenceFunctor
const repT *cumsum_,
size_t src_axis_nelems_,
OrthogIndexer orthog_strider_,
AxisIndexer src_axis_strider_,
AxisIndexer dst_axis_strider_,
SrcAxisIndexer src_axis_strider_,
DstAxisIndexer dst_axis_strider_,
RepIndexer reps_strider_)
: src(src_), dst(dst_), reps(reps_), cumsum(cumsum_),
src_axis_nelems(src_axis_nelems_), orthog_strider(orthog_strider_),
Expand Down Expand Up @@ -167,12 +169,12 @@ repeat_by_sequence_impl(sycl::queue &q,

const size_t gws = orthog_nelems * src_axis_nelems;

cgh.parallel_for<repeat_by_sequence_kernel<TwoOffsets_StridedIndexer,
Strided1DIndexer,
Strided1DIndexer, T, repT>>(
cgh.parallel_for<repeat_by_sequence_kernel<
TwoOffsets_StridedIndexer, Strided1DIndexer, Strided1DIndexer,
Strided1DIndexer, T, repT>>(
sycl::range<1>(gws),
RepeatSequenceFunctor<TwoOffsets_StridedIndexer, Strided1DIndexer,
Strided1DIndexer, T, repT>(
Strided1DIndexer, Strided1DIndexer, T, repT>(
src_tp, dst_tp, reps_tp, cumsum_tp, src_axis_nelems,
orthog_indexer, src_axis_indexer, dst_axis_indexer,
reps_indexer));
Expand All @@ -197,8 +199,8 @@ typedef sycl::event (*repeat_by_sequence_1d_fn_ptr_t)(
char *,
const char *,
const char *,
py::ssize_t,
py::ssize_t,
int,
const py::ssize_t *,
py::ssize_t,
py::ssize_t,
py::ssize_t,
Expand All @@ -212,8 +214,8 @@ sycl::event repeat_by_sequence_1d_impl(sycl::queue &q,
char *dst_cp,
const char *reps_cp,
const char *cumsum_cp,
py::ssize_t src_shape,
py::ssize_t src_stride,
int src_nd,
const py::ssize_t *src_shape_strides,
py::ssize_t dst_shape,
py::ssize_t dst_stride,
py::ssize_t reps_shape,
Expand All @@ -231,19 +233,19 @@ sycl::event repeat_by_sequence_1d_impl(sycl::queue &q,
// orthog ndim indexer
TwoZeroOffsets_Indexer orthog_indexer{};
// indexers along repeated axis
Strided1DIndexer src_indexer{0, src_shape, src_stride};
StridedIndexer src_indexer{src_nd, 0, src_shape_strides};
Strided1DIndexer dst_indexer{0, dst_shape, dst_stride};
// indexer along reps array
Strided1DIndexer reps_indexer{0, reps_shape, reps_stride};

const size_t gws = src_nelems;

cgh.parallel_for<
repeat_by_sequence_kernel<TwoZeroOffsets_Indexer, Strided1DIndexer,
Strided1DIndexer, T, repT>>(
cgh.parallel_for<repeat_by_sequence_kernel<
TwoZeroOffsets_Indexer, StridedIndexer, Strided1DIndexer,
Strided1DIndexer, T, repT>>(
sycl::range<1>(gws),
RepeatSequenceFunctor<TwoZeroOffsets_Indexer, Strided1DIndexer,
Strided1DIndexer, T, repT>(
RepeatSequenceFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
Strided1DIndexer, Strided1DIndexer, T, repT>(
src_tp, dst_tp, reps_tp, cumsum_tp, src_nelems, orthog_indexer,
src_indexer, dst_indexer, reps_indexer));
});
Expand All @@ -260,10 +262,16 @@ template <typename fnT, typename T> struct RepeatSequence1DFactory
}
};

template <typename OrthogIndexer, typename AxisIndexer, typename T>
template <typename OrthogIndexer,
typename SrcAxisIndexer,
typename DstAxisIndexer,
typename T>
class repeat_by_scalar_kernel;

template <typename OrthogIndexer, typename AxisIndexer, typename T>
template <typename OrthogIndexer,
typename SrcAxisIndexer,
typename DstAxisIndexer,
typename T>
class RepeatScalarFunctor
{
private:
Expand All @@ -272,17 +280,17 @@ class RepeatScalarFunctor
const py::ssize_t reps = 1;
size_t dst_axis_nelems = 0;
OrthogIndexer orthog_strider;
AxisIndexer src_axis_strider;
AxisIndexer dst_axis_strider;
SrcAxisIndexer src_axis_strider;
DstAxisIndexer dst_axis_strider;

public:
RepeatScalarFunctor(const T *src_,
T *dst_,
const py::ssize_t reps_,
size_t dst_axis_nelems_,
OrthogIndexer orthog_strider_,
AxisIndexer src_axis_strider_,
AxisIndexer dst_axis_strider_)
SrcAxisIndexer src_axis_strider_,
DstAxisIndexer dst_axis_strider_)
: src(src_), dst(dst_), reps(reps_), dst_axis_nelems(dst_axis_nelems_),
orthog_strider(orthog_strider_), src_axis_strider(src_axis_strider_),
dst_axis_strider(dst_axis_strider_)
Expand Down Expand Up @@ -354,10 +362,11 @@ sycl::event repeat_by_scalar_impl(sycl::queue &q,

const size_t gws = orthog_nelems * dst_axis_nelems;

cgh.parallel_for<repeat_by_scalar_kernel<TwoOffsets_StridedIndexer,
Strided1DIndexer, T>>(
cgh.parallel_for<repeat_by_scalar_kernel<
TwoOffsets_StridedIndexer, Strided1DIndexer, Strided1DIndexer, T>>(
sycl::range<1>(gws),
RepeatScalarFunctor<TwoOffsets_StridedIndexer, Strided1DIndexer, T>(
RepeatScalarFunctor<TwoOffsets_StridedIndexer, Strided1DIndexer,
Strided1DIndexer, T>(
src_tp, dst_tp, reps, dst_axis_nelems, orthog_indexer,
src_axis_indexer, dst_axis_indexer));
});
Expand All @@ -380,8 +389,8 @@ typedef sycl::event (*repeat_by_scalar_1d_fn_ptr_t)(
const char *,
char *,
const py::ssize_t,
py::ssize_t,
py::ssize_t,
int,
const py::ssize_t *,
py::ssize_t,
py::ssize_t,
const std::vector<sycl::event> &);
Expand All @@ -392,8 +401,8 @@ sycl::event repeat_by_scalar_1d_impl(sycl::queue &q,
const char *src_cp,
char *dst_cp,
const py::ssize_t reps,
py::ssize_t src_shape,
py::ssize_t src_stride,
int src_nd,
const py::ssize_t *src_shape_strides,
py::ssize_t dst_shape,
py::ssize_t dst_stride,
const std::vector<sycl::event> &depends)
Expand All @@ -407,17 +416,18 @@ sycl::event repeat_by_scalar_1d_impl(sycl::queue &q,
// orthog ndim indexer
TwoZeroOffsets_Indexer orthog_indexer{};
// indexers along repeated axis
Strided1DIndexer src_indexer(0, src_shape, src_stride);
StridedIndexer src_indexer(src_nd, 0, src_shape_strides);
Strided1DIndexer dst_indexer{0, dst_shape, dst_stride};

const size_t gws = dst_nelems;

cgh.parallel_for<repeat_by_scalar_kernel<TwoZeroOffsets_Indexer,
Strided1DIndexer, T>>(
cgh.parallel_for<repeat_by_scalar_kernel<
TwoZeroOffsets_Indexer, StridedIndexer, Strided1DIndexer, T>>(
sycl::range<1>(gws),
RepeatScalarFunctor<TwoZeroOffsets_Indexer, Strided1DIndexer, T>(
src_tp, dst_tp, reps, dst_nelems, orthog_indexer, src_indexer,
dst_indexer));
RepeatScalarFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
Strided1DIndexer, T>(src_tp, dst_tp, reps,
dst_nelems, orthog_indexer,
src_indexer, dst_indexer));
});

return repeat_ev;
Expand Down
Loading