@@ -46,14 +46,16 @@ namespace py = pybind11;
4646using namespace dpctl ::tensor::offset_utils;
4747
4848template <typename OrthogIndexer,
49- typename AxisIndexer,
49+ typename SrcAxisIndexer,
50+ typename DstAxisIndexer,
5051 typename RepIndexer,
5152 typename T,
5253 typename repT>
5354class repeat_by_sequence_kernel ;
5455
5556template <typename OrthogIndexer,
56- typename AxisIndexer,
57+ typename SrcAxisIndexer,
58+ typename DstAxisIndexer,
5759 typename RepIndexer,
5860 typename T,
5961 typename repT>
@@ -66,8 +68,8 @@ class RepeatSequenceFunctor
6668 const repT *cumsum = nullptr ;
6769 size_t src_axis_nelems = 1 ;
6870 OrthogIndexer orthog_strider;
69- AxisIndexer src_axis_strider;
70- AxisIndexer dst_axis_strider;
71+ SrcAxisIndexer src_axis_strider;
72+ DstAxisIndexer dst_axis_strider;
7173 RepIndexer reps_strider;
7274
7375public:
@@ -77,8 +79,8 @@ class RepeatSequenceFunctor
7779 const repT *cumsum_,
7880 size_t src_axis_nelems_,
7981 OrthogIndexer orthog_strider_,
80- AxisIndexer src_axis_strider_,
81- AxisIndexer dst_axis_strider_,
82+ SrcAxisIndexer src_axis_strider_,
83+ DstAxisIndexer dst_axis_strider_,
8284 RepIndexer reps_strider_)
8385 : src(src_), dst(dst_), reps(reps_), cumsum(cumsum_),
8486 src_axis_nelems (src_axis_nelems_), orthog_strider(orthog_strider_),
@@ -167,12 +169,12 @@ repeat_by_sequence_impl(sycl::queue &q,
167169
168170 const size_t gws = orthog_nelems * src_axis_nelems;
169171
170- cgh.parallel_for <repeat_by_sequence_kernel<TwoOffsets_StridedIndexer,
171- Strided1DIndexer,
172- Strided1DIndexer, T, repT>>(
172+ cgh.parallel_for <repeat_by_sequence_kernel<
173+ TwoOffsets_StridedIndexer, Strided1DIndexer, Strided1DIndexer,
174+ Strided1DIndexer, T, repT>>(
173175 sycl::range<1 >(gws),
174176 RepeatSequenceFunctor<TwoOffsets_StridedIndexer, Strided1DIndexer,
175- Strided1DIndexer, T, repT>(
177+ Strided1DIndexer, Strided1DIndexer, T, repT>(
176178 src_tp, dst_tp, reps_tp, cumsum_tp, src_axis_nelems,
177179 orthog_indexer, src_axis_indexer, dst_axis_indexer,
178180 reps_indexer));
@@ -197,8 +199,8 @@ typedef sycl::event (*repeat_by_sequence_1d_fn_ptr_t)(
197199 char *,
198200 const char *,
199201 const char *,
200- py:: ssize_t ,
201- py::ssize_t ,
202+ int ,
203+ const py::ssize_t * ,
202204 py::ssize_t ,
203205 py::ssize_t ,
204206 py::ssize_t ,
@@ -212,8 +214,8 @@ sycl::event repeat_by_sequence_1d_impl(sycl::queue &q,
212214 char *dst_cp,
213215 const char *reps_cp,
214216 const char *cumsum_cp,
215- py:: ssize_t src_shape ,
216- py::ssize_t src_stride ,
217+ int src_nd ,
218+ const py::ssize_t *src_shape_strides ,
217219 py::ssize_t dst_shape,
218220 py::ssize_t dst_stride,
219221 py::ssize_t reps_shape,
@@ -231,19 +233,19 @@ sycl::event repeat_by_sequence_1d_impl(sycl::queue &q,
231233 // orthog ndim indexer
232234 TwoZeroOffsets_Indexer orthog_indexer{};
233235 // indexers along repeated axis
234- Strided1DIndexer src_indexer{0 , src_shape, src_stride };
236+ StridedIndexer src_indexer{src_nd, 0 , src_shape_strides };
235237 Strided1DIndexer dst_indexer{0 , dst_shape, dst_stride};
236238 // indexer along reps array
237239 Strided1DIndexer reps_indexer{0 , reps_shape, reps_stride};
238240
239241 const size_t gws = src_nelems;
240242
241- cgh.parallel_for <
242- repeat_by_sequence_kernel< TwoZeroOffsets_Indexer, Strided1DIndexer,
243- Strided1DIndexer, T, repT>>(
243+ cgh.parallel_for <repeat_by_sequence_kernel<
244+ TwoZeroOffsets_Indexer, StridedIndexer , Strided1DIndexer,
245+ Strided1DIndexer, T, repT>>(
244246 sycl::range<1 >(gws),
245- RepeatSequenceFunctor<TwoZeroOffsets_Indexer, Strided1DIndexer ,
246- Strided1DIndexer, T, repT>(
247+ RepeatSequenceFunctor<TwoZeroOffsets_Indexer, StridedIndexer ,
248+ Strided1DIndexer, Strided1DIndexer, T, repT>(
247249 src_tp, dst_tp, reps_tp, cumsum_tp, src_nelems, orthog_indexer,
248250 src_indexer, dst_indexer, reps_indexer));
249251 });
@@ -260,10 +262,16 @@ template <typename fnT, typename T> struct RepeatSequence1DFactory
260262 }
261263};
262264
263- template <typename OrthogIndexer, typename AxisIndexer, typename T>
265+ template <typename OrthogIndexer,
266+ typename SrcAxisIndexer,
267+ typename DstAxisIndexer,
268+ typename T>
264269class repeat_by_scalar_kernel ;
265270
266- template <typename OrthogIndexer, typename AxisIndexer, typename T>
271+ template <typename OrthogIndexer,
272+ typename SrcAxisIndexer,
273+ typename DstAxisIndexer,
274+ typename T>
267275class RepeatScalarFunctor
268276{
269277private:
@@ -272,17 +280,17 @@ class RepeatScalarFunctor
272280 const py::ssize_t reps = 1 ;
273281 size_t dst_axis_nelems = 0 ;
274282 OrthogIndexer orthog_strider;
275- AxisIndexer src_axis_strider;
276- AxisIndexer dst_axis_strider;
283+ SrcAxisIndexer src_axis_strider;
284+ DstAxisIndexer dst_axis_strider;
277285
278286public:
279287 RepeatScalarFunctor (const T *src_,
280288 T *dst_,
281289 const py::ssize_t reps_,
282290 size_t dst_axis_nelems_,
283291 OrthogIndexer orthog_strider_,
284- AxisIndexer src_axis_strider_,
285- AxisIndexer dst_axis_strider_)
292+ SrcAxisIndexer src_axis_strider_,
293+ DstAxisIndexer dst_axis_strider_)
286294 : src(src_), dst(dst_), reps(reps_), dst_axis_nelems(dst_axis_nelems_),
287295 orthog_strider (orthog_strider_), src_axis_strider(src_axis_strider_),
288296 dst_axis_strider(dst_axis_strider_)
@@ -354,10 +362,11 @@ sycl::event repeat_by_scalar_impl(sycl::queue &q,
354362
355363 const size_t gws = orthog_nelems * dst_axis_nelems;
356364
357- cgh.parallel_for <repeat_by_scalar_kernel<TwoOffsets_StridedIndexer,
358- Strided1DIndexer, T>>(
365+ cgh.parallel_for <repeat_by_scalar_kernel<
366+ TwoOffsets_StridedIndexer, Strided1DIndexer, Strided1DIndexer, T>>(
359367 sycl::range<1 >(gws),
360- RepeatScalarFunctor<TwoOffsets_StridedIndexer, Strided1DIndexer, T>(
368+ RepeatScalarFunctor<TwoOffsets_StridedIndexer, Strided1DIndexer,
369+ Strided1DIndexer, T>(
361370 src_tp, dst_tp, reps, dst_axis_nelems, orthog_indexer,
362371 src_axis_indexer, dst_axis_indexer));
363372 });
@@ -380,8 +389,8 @@ typedef sycl::event (*repeat_by_scalar_1d_fn_ptr_t)(
380389 const char *,
381390 char *,
382391 const py::ssize_t ,
383- py:: ssize_t ,
384- py::ssize_t ,
392+ int ,
393+ const py::ssize_t * ,
385394 py::ssize_t ,
386395 py::ssize_t ,
387396 const std::vector<sycl::event> &);
@@ -392,8 +401,8 @@ sycl::event repeat_by_scalar_1d_impl(sycl::queue &q,
392401 const char *src_cp,
393402 char *dst_cp,
394403 const py::ssize_t reps,
395- py:: ssize_t src_shape ,
396- py::ssize_t src_stride ,
404+ int src_nd ,
405+ const py::ssize_t *src_shape_strides ,
397406 py::ssize_t dst_shape,
398407 py::ssize_t dst_stride,
399408 const std::vector<sycl::event> &depends)
@@ -407,17 +416,18 @@ sycl::event repeat_by_scalar_1d_impl(sycl::queue &q,
407416 // orthog ndim indexer
408417 TwoZeroOffsets_Indexer orthog_indexer{};
409418 // indexers along repeated axis
410- Strided1DIndexer src_indexer (0 , src_shape, src_stride );
419+ StridedIndexer src_indexer (src_nd, 0 , src_shape_strides );
411420 Strided1DIndexer dst_indexer{0 , dst_shape, dst_stride};
412421
413422 const size_t gws = dst_nelems;
414423
415- cgh.parallel_for <repeat_by_scalar_kernel<TwoZeroOffsets_Indexer,
416- Strided1DIndexer, T>>(
424+ cgh.parallel_for <repeat_by_scalar_kernel<
425+ TwoZeroOffsets_Indexer, StridedIndexer, Strided1DIndexer, T>>(
417426 sycl::range<1 >(gws),
418- RepeatScalarFunctor<TwoZeroOffsets_Indexer, Strided1DIndexer, T>(
419- src_tp, dst_tp, reps, dst_nelems, orthog_indexer, src_indexer,
420- dst_indexer));
427+ RepeatScalarFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
428+ Strided1DIndexer, T>(src_tp, dst_tp, reps,
429+ dst_nelems, orthog_indexer,
430+ src_indexer, dst_indexer));
421431 });
422432
423433 return repeat_ev;
0 commit comments