diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py
index 7201cd96fb..7135304b58 100644
--- a/dpctl/tensor/_manipulation_functions.py
+++ b/dpctl/tensor/_manipulation_functions.py
@@ -19,7 +19,6 @@
 import operator
 
 import numpy as np
-from numpy import AxisError
 from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
 
 import dpctl
@@ -929,20 +928,26 @@ def repeat(x, repeats, axis=None):
 
     Args:
         x (usm_ndarray): input array
-        repeat (Union[int, Tuple[int, ...]]):
+        repeats (Union[int, Sequence[int], usm_ndarray]):
             The number of repetitions for each element.
-            `repeats` is broadcasted to fit the shape of the given axis.
+            `repeats` is broadcast to fit the shape of the given axis.
+            If `repeats` is an array, it must have an integer data type.
+            Otherwise, `repeats` must be a Python integer, tuple, list, or
+            range.
         axis (Optional[int]):
-            The axis along which to repeat values. The `axis` is required
-            if input array has more than one dimension.
+            The axis along which to repeat values. If `axis` is `None`, the
+            function repeats elements of the flattened array.
+            Default: `None`.
 
     Returns:
         usm_narray: Array with repeated elements.
-            The returned array must have the same data type as `x`,
-            is created on the same device as `x` and has the same USM
-            allocation type as `x`.
+            The returned array must have the same data type as `x`, is created
+            on the same device as `x` and has the same USM allocation type as
+            `x`. If `axis` is `None`, the returned array is one-dimensional,
+            otherwise, it has the same shape as `x`, except for the axis along
+            which elements were repeated.
 
     Raises:
         AxisError: if `axis` value is invalid.
@@ -951,20 +956,11 @@ def repeat(x, repeats, axis=None):
         raise TypeError(f"Expected usm_ndarray type, got {type(x)}.")
 
     x_ndim = x.ndim
-    if axis is None:
-        if x_ndim > 1:
-            raise ValueError(
-                f"`axis` cannot be `None` for array of dimension {x_ndim}"
-            )
-        axis = 0
-
     x_shape = x.shape
-    if x_ndim > 0:
+    if axis is not None:
         axis = normalize_axis_index(operator.index(axis), x_ndim)
         axis_size = x_shape[axis]
     else:
-        if axis != 0:
-            AxisError("`axis` must be `0` for input of dimension `0`")
         axis_size = x.size
 
     scalar = False
@@ -977,8 +973,8 @@ def repeat(x, repeats, axis=None):
     elif isinstance(repeats, dpt.usm_ndarray):
         if repeats.ndim > 1:
             raise ValueError(
-                "`repeats` array must be 0- or 1-dimensional, got"
-                "{repeats.ndim}"
+                "`repeats` array must be 0- or 1-dimensional, got "
+                f"{repeats.ndim}"
             )
         exec_q = dpctl.utils.get_execution_queue(
             (x.sycl_queue, repeats.sycl_queue)
@@ -1015,22 +1011,22 @@ def repeat(x, repeats, axis=None):
             if not dpt.all(repeats >= 0):
                 raise ValueError("`repeats` elements must be positive")
 
-    elif isinstance(repeats, tuple):
+    elif isinstance(repeats, (tuple, list, range)):
         usm_type = x.usm_type
         exec_q = x.sycl_queue
 
         len_reps = len(repeats)
-        if len_reps != axis_size:
-            raise ValueError(
-                "`repeats` tuple must have the same length as the repeated "
-                "axis"
-            )
-        elif len_reps == 1:
+        if len_reps == 1:
            repeats = repeats[0]
            if repeats < 0:
                raise ValueError("`repeats` elements must be positive")
            scalar = True
         else:
+            if len_reps != axis_size:
+                raise ValueError(
+                    "`repeats` sequence must have the same length as the "
+                    "repeated axis"
+                )
             repeats = dpt.asarray(
                 repeats, dtype=dpt.int64, usm_type=usm_type, sycl_queue=exec_q
             )
@@ -1038,7 +1034,7 @@ def repeat(x, repeats, axis=None):
                 raise ValueError("`repeats` elements must be positive")
     else:
         raise TypeError(
-            "Expected int, tuple, or `usm_ndarray` for second argument,"
+            "Expected int, sequence, or `usm_ndarray` for second argument, "
             f"got {type(repeats)}"
         )
 
@@ -1047,7 +1043,10 @@ def repeat(x, repeats, axis=None):
 
     if scalar:
         res_axis_size = repeats * axis_size
-        res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
+        if axis is not None:
+            res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
+        else:
+            res_shape = (res_axis_size,)
         res = dpt.empty(
             res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
         )
@@ -1081,9 +1080,17 @@ def repeat(x, repeats, axis=None):
             res_axis_size = ti._cumsum_1d(
                 rep_buf, cumsum, sycl_queue=exec_q, depends=[copy_ev]
             )
-            res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
+            if axis is not None:
+                res_shape = (
+                    x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
+                )
+            else:
+                res_shape = (res_axis_size,)
             res = dpt.empty(
-                res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
+                res_shape,
+                dtype=x.dtype,
+                usm_type=usm_type,
+                sycl_queue=exec_q,
             )
             if res_axis_size > 0:
                 ht_rep_ev, _ = ti._repeat_by_sequence(
@@ -1103,11 +1110,18 @@ def repeat(x, repeats, axis=None):
                 usm_type=usm_type,
                 sycl_queue=exec_q,
             )
-            # _cumsum_1d synchronizes so `depends` ends here safely
             res_axis_size = ti._cumsum_1d(repeats, cumsum, sycl_queue=exec_q)
-            res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
+            if axis is not None:
+                res_shape = (
+                    x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
+                )
+            else:
+                res_shape = (res_axis_size,)
             res = dpt.empty(
-                res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
+                res_shape,
+                dtype=x.dtype,
+                usm_type=usm_type,
+                sycl_queue=exec_q,
             )
             if res_axis_size > 0:
                 ht_rep_ev, _ = ti._repeat_by_sequence(
diff --git a/dpctl/tensor/libtensor/include/kernels/repeat.hpp b/dpctl/tensor/libtensor/include/kernels/repeat.hpp
index da1989fc3c..1f2335fc6c 100644
--- a/dpctl/tensor/libtensor/include/kernels/repeat.hpp
+++ b/dpctl/tensor/libtensor/include/kernels/repeat.hpp
@@ -46,14 +46,16 @@
 namespace py = pybind11;
 
 using namespace dpctl::tensor::offset_utils;
 
 template <typename OrthogIndexer,
-          typename AxisIndexer,
+          typename SrcAxisIndexer,
+          typename DstAxisIndexer,
           typename RepIndexer,
           typename T,
           typename repT>
 class repeat_by_sequence_kernel;
 
 template <typename OrthogIndexer,
-          typename AxisIndexer,
+          typename SrcAxisIndexer,
+          typename DstAxisIndexer,
           typename RepIndexer,
@@ -66,8 +68,8 @@ class RepeatSequenceFunctor
     const repT *cumsum = nullptr;
     size_t src_axis_nelems = 1;
     OrthogIndexer orthog_strider;
-    AxisIndexer src_axis_strider;
-    AxisIndexer dst_axis_strider;
+    SrcAxisIndexer src_axis_strider;
+    DstAxisIndexer dst_axis_strider;
     RepIndexer reps_strider;
 
 public:
@@ -77,8 +79,8 @@
                           const repT *cumsum_,
                           size_t src_axis_nelems_,
                           OrthogIndexer orthog_strider_,
-                          AxisIndexer src_axis_strider_,
-                          AxisIndexer dst_axis_strider_,
+                          SrcAxisIndexer src_axis_strider_,
+                          DstAxisIndexer dst_axis_strider_,
                           RepIndexer reps_strider_)
         : src(src_), dst(dst_), reps(reps_), cumsum(cumsum_),
           src_axis_nelems(src_axis_nelems_), orthog_strider(orthog_strider_),
@@ -167,12 +169,12 @@ repeat_by_sequence_impl(sycl::queue &q,
 
         const size_t gws = orthog_nelems * src_axis_nelems;
 
-        cgh.parallel_for<repeat_by_sequence_kernel<StridedIndexer,
-            Strided1DIndexer, Strided1DIndexer, T, repT>>(
+        cgh.parallel_for<repeat_by_sequence_kernel<StridedIndexer,
+            Strided1DIndexer, Strided1DIndexer, Strided1DIndexer, T, repT>>(
             sycl::range<1>(gws),
-            RepeatSequenceFunctor<StridedIndexer, Strided1DIndexer,
-                                  Strided1DIndexer, T, repT>(
+            RepeatSequenceFunctor<StridedIndexer, Strided1DIndexer,
+                                  Strided1DIndexer, Strided1DIndexer, T, repT>(
                 src_tp, dst_tp, reps_tp, cumsum_tp, src_axis_nelems,
                 orthog_indexer, src_axis_indexer, dst_axis_indexer,
                 reps_indexer));
@@ -197,8 +199,8 @@ typedef sycl::event (*repeat_by_sequence_1d_fn_ptr_t)(
     char *,
     const char *,
     const char *,
-    py::ssize_t,
-    py::ssize_t,
+    int,
+    const py::ssize_t *,
     py::ssize_t,
     py::ssize_t,
     py::ssize_t,
@@ -212,8 +214,8 @@ sycl::event repeat_by_sequence_1d_impl(sycl::queue &q,
                                        char *dst_cp,
                                        const char *reps_cp,
                                        const char *cumsum_cp,
-                                       py::ssize_t src_shape,
-                                       py::ssize_t src_stride,
+                                       int src_nd,
+                                       const py::ssize_t *src_shape_strides,
                                        py::ssize_t dst_shape,
                                        py::ssize_t dst_stride,
                                        py::ssize_t reps_shape,
@@ -231,19 +233,19 @@ sycl::event repeat_by_sequence_1d_impl(sycl::queue &q,
         // orthog ndim indexer
         TwoZeroOffsets_Indexer orthog_indexer{};
         // indexers along repeated axis
-        Strided1DIndexer src_indexer{0, src_shape, src_stride};
+        StridedIndexer src_indexer{src_nd, 0, src_shape_strides};
         Strided1DIndexer dst_indexer{0, dst_shape, dst_stride};
         // indexer along reps array
         Strided1DIndexer reps_indexer{0, reps_shape, reps_stride};
 
         const size_t gws = src_nelems;
 
-        cgh.parallel_for<
-            repeat_by_sequence_kernel<TwoZeroOffsets_Indexer, Strided1DIndexer,
-                                      Strided1DIndexer, T, repT>>(
+        cgh.parallel_for<repeat_by_sequence_kernel<
+            TwoZeroOffsets_Indexer, StridedIndexer, Strided1DIndexer,
+            Strided1DIndexer, T, repT>>(
             sycl::range<1>(gws),
-            RepeatSequenceFunctor<TwoZeroOffsets_Indexer, Strided1DIndexer,
-                                  Strided1DIndexer, T, repT>(
+            RepeatSequenceFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
+                                  Strided1DIndexer, Strided1DIndexer, T, repT>(
                 src_tp, dst_tp, reps_tp, cumsum_tp, src_nelems, orthog_indexer,
                 src_indexer, dst_indexer, reps_indexer));
     });
@@ -260,10 +262,16 @@ template <typename fnT, typename T> struct RepeatSequence1DFactory
     }
 };
 
-template <typename OrthogIndexer, typename AxisIndexer, typename T>
+template <typename OrthogIndexer,
+          typename SrcAxisIndexer,
+          typename DstAxisIndexer,
+          typename T>
 class repeat_by_scalar_kernel;
 
-template <typename OrthogIndexer, typename AxisIndexer, typename T>
+template <typename OrthogIndexer,
+          typename SrcAxisIndexer,
+          typename DstAxisIndexer,
+          typename T>
 class RepeatScalarFunctor
 {
 private:
@@ -272,8 +280,8 @@ class RepeatScalarFunctor
     const py::ssize_t reps = 1;
     size_t dst_axis_nelems = 0;
     OrthogIndexer orthog_strider;
-    AxisIndexer src_axis_strider;
-    AxisIndexer dst_axis_strider;
+    SrcAxisIndexer src_axis_strider;
+    DstAxisIndexer dst_axis_strider;
 
 public:
     RepeatScalarFunctor(const T *src_,
@@ -281,8 +289,8 @@
                         T *dst_,
                         const py::ssize_t reps_,
                         size_t dst_axis_nelems_,
                         OrthogIndexer orthog_strider_,
-                        AxisIndexer src_axis_strider_,
-                        AxisIndexer dst_axis_strider_)
+                        SrcAxisIndexer src_axis_strider_,
+                        DstAxisIndexer dst_axis_strider_)
         : src(src_), dst(dst_), reps(reps_), dst_axis_nelems(dst_axis_nelems_),
           orthog_strider(orthog_strider_), src_axis_strider(src_axis_strider_),
          dst_axis_strider(dst_axis_strider_)
@@ -354,10 +362,11 @@ sycl::event repeat_by_scalar_impl(sycl::queue &q,
 
         const size_t gws = orthog_nelems * dst_axis_nelems;
 
-        cgh.parallel_for<repeat_by_scalar_kernel<StridedIndexer,
-            Strided1DIndexer, T>>(
+        cgh.parallel_for<repeat_by_scalar_kernel<
+            StridedIndexer, Strided1DIndexer, Strided1DIndexer, T>>(
             sycl::range<1>(gws),
-            RepeatScalarFunctor<StridedIndexer, Strided1DIndexer, T>(
+            RepeatScalarFunctor<StridedIndexer, Strided1DIndexer,
+                                Strided1DIndexer, T>(
                 src_tp, dst_tp, reps, dst_axis_nelems, orthog_indexer,
                 src_axis_indexer, dst_axis_indexer));
     });
@@ -380,8 +389,8 @@ typedef sycl::event (*repeat_by_scalar_1d_fn_ptr_t)(
     const char *,
     char *,
     const py::ssize_t,
-    py::ssize_t,
-    py::ssize_t,
+    int,
+    const py::ssize_t *,
     py::ssize_t,
     py::ssize_t,
     const std::vector<sycl::event> &);
@@ -392,8 +401,8 @@ sycl::event repeat_by_scalar_1d_impl(sycl::queue &q,
                                      const char *src_cp,
                                      char *dst_cp,
                                      const py::ssize_t reps,
-                                     py::ssize_t src_shape,
-                                     py::ssize_t src_stride,
+                                     int src_nd,
+                                     const py::ssize_t *src_shape_strides,
                                      py::ssize_t dst_shape,
                                      py::ssize_t dst_stride,
                                      const std::vector<sycl::event> &depends)
@@ -407,17 +416,18 @@ sycl::event repeat_by_scalar_1d_impl(sycl::queue &q,
         // orthog ndim indexer
         TwoZeroOffsets_Indexer orthog_indexer{};
         // indexers along repeated axis
-        Strided1DIndexer src_indexer(0, src_shape, src_stride);
+        StridedIndexer src_indexer(src_nd, 0, src_shape_strides);
         Strided1DIndexer dst_indexer{0, dst_shape, dst_stride};
 
         const size_t gws = dst_nelems;
 
-        cgh.parallel_for<repeat_by_scalar_kernel<TwoZeroOffsets_Indexer,
-            Strided1DIndexer, T>>(
+        cgh.parallel_for<repeat_by_scalar_kernel<
+            TwoZeroOffsets_Indexer, StridedIndexer, Strided1DIndexer, T>>(
             sycl::range<1>(gws),
-            RepeatScalarFunctor<TwoZeroOffsets_Indexer, Strided1DIndexer, T>(
-                src_tp, dst_tp, reps, dst_nelems, orthog_indexer, src_indexer,
-                dst_indexer));
+            RepeatScalarFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
+                                Strided1DIndexer, T>(src_tp, dst_tp, reps,
+                                                     dst_nelems, orthog_indexer,
+                                                     src_indexer, dst_indexer));
     });
 
     return repeat_ev;
diff --git a/dpctl/tensor/libtensor/source/repeat.cpp b/dpctl/tensor/libtensor/source/repeat.cpp
index 0dbfb17a5d..3b1c956dd4 100644
--- a/dpctl/tensor/libtensor/source/repeat.cpp
+++ b/dpctl/tensor/libtensor/source/repeat.cpp
@@ -237,18 +237,42 @@ py_repeat_by_sequence(const dpctl::tensor::usm_ndarray &src,
         assert(dst_shape_vec.size() == 1);
         assert(dst_strides_vec.size() == 1);
 
-        py::ssize_t src_shape(0);
-        py::ssize_t src_stride(0);
-        if (src_nd > 0) {
-            src_shape = src_shape_vec[0];
-            src_stride = src_strides_vec[0];
+        if (src_nd == 0) {
+            src_shape_vec = {0};
+            src_strides_vec = {0};
+        }
+
+        using dpctl::tensor::offset_utils::device_allocate_and_pack;
+        const auto &ptr_size_event_tuple1 =
+            device_allocate_and_pack<py::ssize_t>(
+                exec_q, host_task_events, src_shape_vec, src_strides_vec);
+        py::ssize_t *packed_src_shape_strides =
+            std::get<0>(ptr_size_event_tuple1);
+        if (packed_src_shape_strides == nullptr) {
+            throw std::runtime_error("Unable to allocate device memory");
         }
+        sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1);
+
+        std::vector<sycl::event> all_deps;
+        all_deps.reserve(depends.size() + 1);
+        all_deps.insert(all_deps.end(), depends.begin(), depends.end());
+        all_deps.push_back(copy_shapes_strides_ev);
 
-        sycl::event repeat_ev =
+        repeat_ev =
             fn(exec_q, src_axis_nelems, src_data_p, dst_data_p, reps_data_p,
-               cumsum_data_p, src_shape, src_stride, dst_shape_vec[0],
-               dst_strides_vec[0], reps_shape_vec[0], reps_strides_vec[0],
-               depends);
+               cumsum_data_p, src_nd, packed_src_shape_strides,
+               dst_shape_vec[0], dst_strides_vec[0], reps_shape_vec[0],
+               reps_strides_vec[0], all_deps);
+
+        sycl::event cleanup_tmp_allocations_ev =
+            exec_q.submit([&](sycl::handler &cgh) {
+                cgh.depends_on(repeat_ev);
+                const auto &ctx = exec_q.get_context();
+                cgh.host_task([ctx, packed_src_shape_strides] {
+                    sycl::free(packed_src_shape_strides, ctx);
+                });
+            });
+        host_task_events.push_back(cleanup_tmp_allocations_ev);
     }
     else {
         // non-empty othogonal directions
@@ -343,6 +362,162 @@ py_repeat_by_sequence(const dpctl::tensor::usm_ndarray &src,
     return std::make_pair(py_obj_management_host_task_ev, repeat_ev);
 }
 
+std::pair<sycl::event, sycl::event>
+py_repeat_by_sequence(const dpctl::tensor::usm_ndarray &src,
+                      const dpctl::tensor::usm_ndarray &dst,
+                      const dpctl::tensor::usm_ndarray &reps,
+                      const dpctl::tensor::usm_ndarray &cumsum,
+                      sycl::queue &exec_q,
+                      const std::vector<sycl::event> &depends)
+{
+
+    int dst_nd = dst.get_ndim();
+    if (dst_nd != 1) {
+        throw py::value_error(
+            "`dst` array must be 1-dimensional when repeating a full array");
+    }
+
+    int reps_nd = reps.get_ndim();
+    if (reps_nd != 1) {
+        throw py::value_error("`reps` array must be 1-dimensional");
+    }
+
+    if (cumsum.get_ndim() != 1) {
+        throw py::value_error("`cumsum` array must be 1-dimensional.");
+    }
+
+    if (!cumsum.is_c_contiguous()) {
+        throw py::value_error("Expecting `cumsum` array to be C-contiguous.");
+    }
+
+    if (!dpctl::utils::queues_are_compatible(exec_q, {src, reps, cumsum, dst}))
+    {
+        throw py::value_error(
+            "Execution queue is not compatible with allocation queues");
+    }
+
+    size_t src_sz = src.get_size();
+    size_t reps_sz = reps.get_size();
+    size_t cumsum_sz = cumsum.get_size();
+
+    // shape at repeated axis must be equal to the sum of reps
+    if (src_sz != reps_sz || src_sz != cumsum_sz) {
+        throw py::value_error("Inconsistent array dimensions");
+    }
+
+    if (src_sz == 0) {
+        return std::make_pair(sycl::event(), sycl::event());
+    }
+
+    // ensure that dst is sufficiently ample
+    auto dst_offsets = dst.get_minmax_offsets();
+    // destination must be ample enough to accommodate all elements
+    {
+        size_t range =
+            static_cast<size_t>(dst_offsets.second - dst_offsets.first);
+        if (range + 1 < static_cast<size_t>(dst.get_size())) {
+            throw py::value_error(
+                "Memory addressed by the destination array can not "
+                "accommodate all the "
+                "array elements.");
+        }
+    }
+
+    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
+    // check that dst does not intersect with src, cumsum, or reps
+    if (overlap(dst, src) || overlap(dst, reps) || overlap(dst, cumsum)) {
+        throw py::value_error("Destination array overlaps with inputs");
+    }
+
+    int src_typenum = src.get_typenum();
+    int dst_typenum = dst.get_typenum();
+    int reps_typenum = reps.get_typenum();
+    int cumsum_typenum = cumsum.get_typenum();
+
+    auto const &array_types = td_ns::usm_ndarray_types();
+    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
+    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
+    int reps_typeid = array_types.typenum_to_lookup_id(reps_typenum);
+    int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum);
+
+    if (src_typeid != dst_typeid) {
+        throw py::value_error(
+            "Destination array must have the same elemental data type");
+    }
+
+    constexpr int int64_typeid = static_cast<int>(td_ns::typenum_t::INT64);
+    if (cumsum_typeid != int64_typeid) {
+        throw py::value_error(
+            "Unexpected data type of `cumsum` array, expecting "
+            "'int64'");
+    }
+
+    if (reps_typeid != cumsum_typeid) {
+        throw py::value_error("`reps` array must have the same elemental "
+                              "data type as cumsum");
+    }
+
+    const char *src_data_p = src.get_data();
+    const char *reps_data_p = reps.get_data();
+    const char *cumsum_data_p = cumsum.get_data();
+    char *dst_data_p = dst.get_data();
+
+    int src_nd = src.get_ndim();
+    auto src_shape_vec = src.get_shape_vector();
+    auto src_strides_vec = src.get_strides_vector();
+    if (src_nd == 0) {
+        src_shape_vec = {0};
+        src_strides_vec = {0};
+    }
+
+    auto dst_shape_vec = dst.get_shape_vector();
+    auto dst_strides_vec = dst.get_strides_vector();
+
+    auto reps_shape_vec = reps.get_shape_vector();
+    auto reps_strides_vec = reps.get_strides_vector();
+
+    std::vector<sycl::event> host_task_events{};
+
+    auto fn = repeat_by_sequence_1d_dispatch_vector[src_typeid];
+
+    using dpctl::tensor::offset_utils::device_allocate_and_pack;
+    const auto &ptr_size_event_tuple1 = device_allocate_and_pack<py::ssize_t>(
+        exec_q, host_task_events, src_shape_vec, src_strides_vec);
+    py::ssize_t *packed_src_shapes_strides = std::get<0>(ptr_size_event_tuple1);
+    if (packed_src_shapes_strides == nullptr) {
+        throw std::runtime_error("Unable to allocate device memory");
+    }
+    sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1);
+
+    std::vector<sycl::event> all_deps;
+    all_deps.reserve(depends.size() + 1);
+    all_deps.insert(all_deps.end(), depends.begin(), depends.end());
+    all_deps.push_back(copy_shapes_strides_ev);
+
+    assert(all_deps.size() == depends.size() + 1);
+
+    sycl::event repeat_ev =
+        fn(exec_q, src_sz, src_data_p, dst_data_p, reps_data_p, cumsum_data_p,
+           src_nd, packed_src_shapes_strides, dst_shape_vec[0],
+           dst_strides_vec[0], reps_shape_vec[0], reps_strides_vec[0],
+           all_deps);
+
+    sycl::event cleanup_tmp_allocations_ev =
+        exec_q.submit([&](sycl::handler &cgh) {
+            cgh.depends_on(repeat_ev);
+            const auto &ctx = exec_q.get_context();
+            cgh.host_task([ctx, packed_src_shapes_strides] {
+                sycl::free(packed_src_shapes_strides, ctx);
+            });
+        });
+    host_task_events.push_back(cleanup_tmp_allocations_ev);
+    host_task_events.push_back(repeat_ev);
+
+    sycl::event py_obj_management_host_task_ev = dpctl::utils::keep_args_alive(
+        exec_q, {src, reps, cumsum, dst}, host_task_events);
+
+    return std::make_pair(py_obj_management_host_task_ev, repeat_ev);
+}
+
 std::pair<sycl::event, sycl::event>
 py_repeat_by_scalar(const dpctl::tensor::usm_ndarray &src,
                     const dpctl::tensor::usm_ndarray &dst,
@@ -452,15 +627,42 @@ py_repeat_by_scalar(const dpctl::tensor::usm_ndarray &src,
         assert(dst_shape_vec.size() == 1);
         assert(dst_strides_vec.size() == 1);
 
-        py::ssize_t src_shape(0);
-        py::ssize_t src_stride(0);
-        if (src_nd > 0) {
-            src_shape = src_shape_vec[0];
-            src_stride = src_strides_vec[0];
+        if (src_nd == 0) {
+            src_shape_vec = {0};
+            src_strides_vec = {0};
         }
-        sycl::event repeat_ev =
-            fn(exec_q, dst_axis_nelems, src_data_p, dst_data_p, reps, src_shape,
-               src_stride, dst_shape_vec[0], dst_strides_vec[0], depends);
+
+        using dpctl::tensor::offset_utils::device_allocate_and_pack;
+        const auto &ptr_size_event_tuple1 =
+            device_allocate_and_pack<py::ssize_t>(
+                exec_q, host_task_events, src_shape_vec, src_strides_vec);
+        py::ssize_t *packed_src_shape_strides =
+            std::get<0>(ptr_size_event_tuple1);
+        if (packed_src_shape_strides == nullptr) {
+            throw std::runtime_error("Unable to allocate device memory");
+        }
+        sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1);
+
+        std::vector<sycl::event> all_deps;
+        all_deps.reserve(depends.size() + 1);
+        all_deps.insert(all_deps.end(), depends.begin(), depends.end());
+        all_deps.push_back(copy_shapes_strides_ev);
+
+        assert(all_deps.size() == depends.size() + 1);
+
+        repeat_ev = fn(exec_q, dst_axis_nelems, src_data_p, dst_data_p, reps,
+                       src_nd, packed_src_shape_strides, dst_shape_vec[0],
+                       dst_strides_vec[0], all_deps);
+
+        sycl::event cleanup_tmp_allocations_ev =
+            exec_q.submit([&](sycl::handler &cgh) {
+                cgh.depends_on(repeat_ev);
+                const auto &ctx = exec_q.get_context();
+                cgh.host_task([ctx, packed_src_shape_strides] {
+                    sycl::free(packed_src_shape_strides, ctx);
+                });
+            });
+        host_task_events.push_back(cleanup_tmp_allocations_ev);
     }
     else {
         // non-empty othogonal directions
@@ -554,6 +756,126 @@ py_repeat_by_scalar(const dpctl::tensor::usm_ndarray &src,
     return std::make_pair(py_obj_management_host_task_ev, repeat_ev);
 }
 
+std::pair<sycl::event, sycl::event>
+py_repeat_by_scalar(const dpctl::tensor::usm_ndarray &src,
+                    const dpctl::tensor::usm_ndarray &dst,
+                    const py::ssize_t reps,
+                    sycl::queue &exec_q,
+                    const std::vector<sycl::event> &depends)
+{
+    int dst_nd = dst.get_ndim();
+    if (dst_nd != 1) {
+        throw py::value_error(
+            "`dst` array must be 1-dimensional when repeating a full array");
+    }
+
+    if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
+        throw py::value_error(
+            "Execution queue is not compatible with allocation queues");
+    }
+
+    size_t src_sz = src.get_size();
+    size_t dst_sz = dst.get_size();
+
+    // shape at repeated axis must be equal to the shape of src at the axis *
+    // reps
+    if ((src_sz * reps) != dst_sz) {
+        throw py::value_error("Inconsistent array dimensions");
+    }
+
+    if (src_sz == 0) {
+        return std::make_pair(sycl::event(), sycl::event());
+    }
+
+    // ensure that dst is sufficiently ample
+    auto dst_offsets = dst.get_minmax_offsets();
+    // destination must be ample enough to accommodate all elements
+    {
+        size_t range =
+            static_cast<size_t>(dst_offsets.second - dst_offsets.first);
+        if (range + 1 < static_cast<size_t>(src_sz * reps)) {
+            throw py::value_error(
+                "Memory addressed by the destination array can not "
+                "accommodate all the "
+                "array elements.");
+        }
+    }
+
+    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
+    // check that dst does not intersect with src
+    if (overlap(dst, src)) {
+        throw py::value_error("Destination array overlaps with inputs");
+    }
+
+    int src_typenum = src.get_typenum();
+    int dst_typenum = dst.get_typenum();
+
+    auto const &array_types = td_ns::usm_ndarray_types();
+    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
+    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
+
+    if (src_typeid != dst_typeid) {
+        throw py::value_error(
+            "Destination array must have the same elemental data type");
+    }
+
+    const char *src_data_p = src.get_data();
+    char *dst_data_p = dst.get_data();
+
+    int src_nd = src.get_ndim();
+    auto src_shape_vec = src.get_shape_vector();
+    auto src_strides_vec = src.get_strides_vector();
+
+    if (src_nd == 0) {
+        src_shape_vec = {0};
+        src_strides_vec = {0};
+    }
+
+    auto dst_shape_vec = dst.get_shape_vector();
+    auto dst_strides_vec = dst.get_strides_vector();
+
+    std::vector<sycl::event> host_task_events{};
+
+    auto fn = repeat_by_scalar_1d_dispatch_vector[src_typeid];
+
+    using dpctl::tensor::offset_utils::device_allocate_and_pack;
+    const auto &ptr_size_event_tuple1 = device_allocate_and_pack<py::ssize_t>(
+        exec_q, host_task_events, src_shape_vec, src_strides_vec);
+    py::ssize_t *packed_src_shape_strides = std::get<0>(ptr_size_event_tuple1);
+    if (packed_src_shape_strides == nullptr) {
+        throw std::runtime_error("Unable to allocate device memory");
+    }
+    sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1);
+
+    std::vector<sycl::event> all_deps;
+    all_deps.reserve(depends.size() + 1);
+    all_deps.insert(all_deps.end(), depends.begin(), depends.end());
+    all_deps.push_back(copy_shapes_strides_ev);
+
+    assert(all_deps.size() == depends.size() + 1);
+
+    sycl::event repeat_ev = fn(exec_q, dst_sz, src_data_p, dst_data_p, reps,
+                               src_nd, packed_src_shape_strides,
+                               dst_shape_vec[0], dst_strides_vec[0], all_deps);
+
+    sycl::event cleanup_tmp_allocations_ev =
+        exec_q.submit([&](sycl::handler &cgh) {
+            cgh.depends_on(repeat_ev);
+            const auto &ctx = exec_q.get_context();
+            cgh.host_task([ctx, packed_src_shape_strides] {
+                sycl::free(packed_src_shape_strides, ctx);
+            });
+        });
+
+    host_task_events.push_back(cleanup_tmp_allocations_ev);
+    host_task_events.push_back(repeat_ev);
+
+    sycl::event py_obj_management_host_task_ev =
+        dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events);
+
+    return std::make_pair(py_obj_management_host_task_ev, repeat_ev);
+}
+
 } // namespace py_internal
 } // namespace tensor
 } // namespace dpctl
diff --git a/dpctl/tensor/libtensor/source/repeat.hpp b/dpctl/tensor/libtensor/source/repeat.hpp
index 87fb0a0847..65ace36516 100644
--- a/dpctl/tensor/libtensor/source/repeat.hpp
+++ b/dpctl/tensor/libtensor/source/repeat.hpp
@@ -48,6 +48,14 @@ py_repeat_by_sequence(const dpctl::tensor::usm_ndarray &src,
                       sycl::queue &exec_q,
                       const std::vector<sycl::event> &depends);
 
+extern std::pair<sycl::event, sycl::event>
+py_repeat_by_sequence(const dpctl::tensor::usm_ndarray &src,
+                      const dpctl::tensor::usm_ndarray &dst,
+                      const dpctl::tensor::usm_ndarray &reps,
+                      const dpctl::tensor::usm_ndarray &cumsum,
+                      sycl::queue &exec_q,
+                      const std::vector<sycl::event> &depends);
+
 extern std::pair<sycl::event, sycl::event>
 py_repeat_by_scalar(const dpctl::tensor::usm_ndarray &src,
                     const dpctl::tensor::usm_ndarray &dst,
@@ -56,6 +64,13 @@ py_repeat_by_scalar(const dpctl::tensor::usm_ndarray &src,
                     sycl::queue &exec_q,
                     const std::vector<sycl::event> &depends);
 
+extern std::pair<sycl::event, sycl::event>
+py_repeat_by_scalar(const dpctl::tensor::usm_ndarray &src,
+                    const dpctl::tensor::usm_ndarray &dst,
+                    const py::ssize_t reps,
+                    sycl::queue &exec_q,
+                    const std::vector<sycl::event> &depends);
+
 } // namespace py_internal
 } // namespace tensor
 } // namespace dpctl
diff --git a/dpctl/tensor/libtensor/source/tensor_py.cpp b/dpctl/tensor/libtensor/source/tensor_py.cpp
index 2ce7c72add..b7614b8bf1 100644
--- a/dpctl/tensor/libtensor/source/tensor_py.cpp
+++ b/dpctl/tensor/libtensor/source/tensor_py.cpp
@@ -30,6 +30,7 @@
 #include <pybind11/numpy.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
+#include <optional>
 
 #include "dpctl4pybind11.hpp"
@@ -402,13 +403,43 @@ PYBIND11_MODULE(_tensor_impl, m)
           py::arg("x2"), py::arg("dst"), py::arg("sycl_queue"),
           py::arg("depends") = py::list());
 
-    m.def("_repeat_by_sequence", &py_repeat_by_sequence, "", py::arg("src"),
+    auto repeat_sequence = [](const dpctl::tensor::usm_ndarray &src,
+                              const dpctl::tensor::usm_ndarray &dst,
+                              const dpctl::tensor::usm_ndarray &reps,
+                              const dpctl::tensor::usm_ndarray &cumsum,
+                              std::optional<int> axis, sycl::queue &exec_q,
+                              const std::vector<sycl::event> depends)
+        -> std::pair<sycl::event, sycl::event> {
+        if (axis) {
+            return py_repeat_by_sequence(src, dst, reps, cumsum, axis.value(),
+                                         exec_q, depends);
+        }
+        else {
+            return py_repeat_by_sequence(src, dst, reps, cumsum, exec_q,
+                                         depends);
+        }
+    };
+    m.def("_repeat_by_sequence", repeat_sequence, py::arg("src"),
           py::arg("dst"), py::arg("reps"), py::arg("cumsum"), py::arg("axis"),
           py::arg("sycl_queue"), py::arg("depends") = py::list());
 
-    m.def("_repeat_by_scalar", &py_repeat_by_scalar, "", py::arg("src"),
-          py::arg("dst"), py::arg("reps"), py::arg("axis"),
-          py::arg("sycl_queue"), py::arg("depends") = py::list());
+    auto repeat_scalar = [](const dpctl::tensor::usm_ndarray &src,
+                            const dpctl::tensor::usm_ndarray &dst,
+                            const py::ssize_t reps, std::optional<int> axis,
+                            sycl::queue &exec_q,
+                            const std::vector<sycl::event> depends)
+        -> std::pair<sycl::event, sycl::event> {
+        if (axis) {
+            return py_repeat_by_scalar(src, dst, reps, axis.value(), exec_q,
+                                       depends);
+        }
+        else {
+            return py_repeat_by_scalar(src, dst, reps, exec_q, depends);
+        }
+    };
+    m.def("_repeat_by_scalar", repeat_scalar, py::arg("src"), py::arg("dst"),
+          py::arg("reps"), py::arg("axis"), py::arg("sycl_queue"),
+          py::arg("depends") = py::list());
 
     dpctl::tensor::py_internal::init_elementwise_functions(m);
     dpctl::tensor::py_internal::init_boolean_reduction_functions(m);
diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py
index 2126727d5b..ae32afdba9 100644
--- a/dpctl/tests/test_usm_ndarray_manipulation.py
+++ b/dpctl/tests/test_usm_ndarray_manipulation.py
@@ -1193,11 +1193,17 @@ def test_repeat_size_0_outputs():
     assert res.size == 0
     assert res.shape == (3, 0, 5)
 
-    x = dpt.ones((3, 2, 5))
     res = dpt.repeat(x, (0, 0), axis=1)
     assert res.size == 0
     assert res.shape == (3, 0, 5)
 
+    # axis=None cases
+    res = dpt.repeat(x, 0)
+    assert res.size == 0
+
+    res = dpt.repeat(x, (0,) * x.size)
+    assert res.size == 0
+
 
 def test_repeat_strides():
     get_queue_or_skip()
@@ -1220,6 +1226,17 @@ def test_repeat_strides():
     res = dpt.repeat(x1, (reps,) * x1.shape[0], axis=0)
     assert dpt.all(res == expected_res)
 
+    # axis=None
+    x = dpt.reshape(dpt.arange(10 * 10), (10, 10))
+    x1 = dpt.reshape(x[::-2, :], -1)
+    x2 = x[::-2, :]
+    expected_res = dpt.empty(10 * 10, dtype="i4")
+    expected_res[::2], expected_res[1::2] = x1, x1
+    res = dpt.repeat(x2, reps)
+    assert dpt.all(res == expected_res)
+    res = dpt.repeat(x2, (reps,) * x1.size)
+    assert dpt.all(res == expected_res)
+
 
 def test_repeat_casting():
     get_queue_or_skip()
@@ -1256,11 +1273,6 @@ def test_repeat_arg_validation():
     with pytest.raises(ValueError):
         dpt.repeat(x, 2, axis=1)
 
-    # x.ndim cannot be > 1 for axis=None
-    x = dpt.empty((5, 10))
-    with pytest.raises(ValueError):
-        dpt.repeat(x, 2, axis=None)
-
     # repeats must be positive
     x = dpt.empty(5)
    with pytest.raises(ValueError):
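
For context, a minimal usage sketch of the behavior this patch enables, derived from the updated docstring and the tests above (illustrative only; the array values assume a dpctl build with this change applied):

    # dpt.repeat with the new axis=None support, mirroring numpy.repeat
    import dpctl.tensor as dpt

    x = dpt.asarray([[1, 2], [3, 4]])

    # axis=None repeats elements of the flattened array -> 1-D result
    r0 = dpt.repeat(x, 2)  # [1, 1, 2, 2, 3, 3, 4, 4]

    # a sequence of repeats must match the length of the repeated axis
    r1 = dpt.repeat(x, (1, 3), axis=0)  # shape (4, 2)

    # a length-1 sequence takes the scalar path and broadcasts over the axis
    r2 = dpt.repeat(x, (2,), axis=1)  # shape (2, 4)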