@@ -746,7 +746,85 @@ template <typename fnT, typename Ty> struct CopyForReshapeGenericFactory
746746 }
747747};
748748
749- // =============== Copying for reshape ================== //
749+ // ================== Copying for roll ================== //
750+
/*! @brief Functor mapping a flat global id to its position after a cyclic
 *  left roll by a fixed offset */
struct LeftRolled1DTransformer
{
    LeftRolled1DTransformer(size_t offset, size_t size)
        : offset_(offset), size_(size)
    {
    }

    /* Ids below offset_ wrap around to the tail of the range, the rest
     * simply shift down by offset_. Assumes gid < size_, offset_ <= size_. */
    size_t operator()(size_t gid) const
    {
        if (gid < offset_) {
            return gid + (size_ - offset_);
        }
        return gid - offset_;
    }

private:
    size_t offset_ = 0; // left-roll amount, in elements
    size_t size_ = 1;   // total number of elements being rolled
};
770+
/*! @brief Indexer functor composing an indexer with an id transformer:
 *  evaluates `f(t(gid))` */
template <typename IndexerT, typename TransformerT> struct CompositionIndexer
{
    CompositionIndexer(IndexerT f, TransformerT t) : f_(f), t_(t) {}

    auto operator()(size_t gid) const
    {
        // transform the id first, then index
        const auto transformed_gid = t_(gid);
        return f_(transformed_gid);
    }

private:
    IndexerT f_;     // outer indexer
    TransformerT t_; // inner id transformer
};
785+
786+ /* ! @brief Indexer functor to find offset for nd-shifted indices lifted from
787+ * iteration id */
788+ struct RolledNDIndexer
789+ {
790+ RolledNDIndexer (int nd,
791+ const py::ssize_t *shape,
792+ const py::ssize_t *strides,
793+ const py::ssize_t *ndshifts,
794+ py::ssize_t starting_offset)
795+ : nd_(nd), shape_(shape), strides_(strides), ndshifts_(ndshifts),
796+ starting_offset_ (starting_offset)
797+ {
798+ }
799+
800+ py::ssize_t operator ()(size_t gid) const
801+ {
802+ return compute_offset (gid);
803+ }
804+
805+ private:
806+ int nd_ = -1 ;
807+ const py::ssize_t *shape_ = nullptr ;
808+ const py::ssize_t *strides_ = nullptr ;
809+ const py::ssize_t *ndshifts_ = nullptr ;
810+ py::ssize_t starting_offset_ = 0 ;
811+
812+ py::ssize_t compute_offset (py::ssize_t gid) const
813+ {
814+ using dpctl::tensor::strides::CIndexer_vector;
815+
816+ CIndexer_vector _ind (nd_);
817+ py::ssize_t relative_offset_ (0 );
818+ _ind.get_left_rolled_displacement <const py::ssize_t *,
819+ const py::ssize_t *>(
820+ gid,
821+ shape_, // shape ptr
822+ strides_, // strides ptr
823+ ndshifts_, // shifts ptr
824+ relative_offset_);
825+ return starting_offset_ + relative_offset_;
826+ }
827+ };
750828
751829template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
752830class copy_for_roll_strided_kernel ;
@@ -755,32 +833,26 @@ template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
755833class StridedCopyForRollFunctor
756834{
757835private:
758- size_t offset = 0 ;
759- size_t size = 1 ;
760836 const Ty *src_p = nullptr ;
761837 Ty *dst_p = nullptr ;
762838 SrcIndexerT src_indexer_;
763839 DstIndexerT dst_indexer_;
764840
765841public:
766- StridedCopyForRollFunctor (size_t shift,
767- size_t nelems,
768- const Ty *src_ptr,
842+ StridedCopyForRollFunctor (const Ty *src_ptr,
769843 Ty *dst_ptr,
770844 SrcIndexerT src_indexer,
771845 DstIndexerT dst_indexer)
772- : offset(shift), size(nelems), src_p(src_ptr), dst_p(dst_ptr),
773- src_indexer_ (src_indexer), dst_indexer_(dst_indexer)
846+ : src_p(src_ptr), dst_p(dst_ptr), src_indexer_(src_indexer ),
847+ dst_indexer_ (dst_indexer)
774848 {
775849 }
776850
777851 void operator ()(sycl::id<1 > wiid) const
778852 {
779853 const size_t gid = wiid.get (0 );
780- const size_t shifted_gid =
781- ((gid < offset) ? gid + size - offset : gid - offset);
782854
783- const py::ssize_t src_offset = src_indexer_ (shifted_gid );
855+ const py::ssize_t src_offset = src_indexer_ (gid );
784856 const py::ssize_t dst_offset = dst_indexer_ (gid);
785857
786858 dst_p[dst_offset] = src_p[src_offset];
@@ -800,8 +872,6 @@ typedef sycl::event (*copy_for_roll_strided_fn_ptr_t)(
800872 py::ssize_t , // dst_offset
801873 const std::vector<sycl::event> &);
802874
803- template <typename Ty> class copy_for_roll_contig_kernel ;
804-
805875/* !
806876 * @brief Function to copy content of array with a shift.
807877 *
@@ -812,8 +882,8 @@ template <typename Ty> class copy_for_roll_contig_kernel;
812882 * @param shift The shift in flat indexing, must be non-negative.
813883 * @param nelems The number of elements to copy
814884 * @param nd Array dimensionality of the destination and source arrays
815- * @param packed_shapes_and_strides Kernel accessible USM array of size
816- * `3*nd` with content `[common_shape, src_strides, dst_strides]`.
885+ * @param packed_shapes_and_strides Kernel accessible USM array
886+ * of size `3*nd` with content `[common_shape, src_strides, dst_strides]`.
817887 * @param src_p Typeless USM pointer to the buffer of the source array
818888 * @param src_offset Displacement of first element of src relative src_p in
819889 * elements
@@ -849,21 +919,29 @@ copy_for_roll_strided_impl(sycl::queue q,
849919 // [ common_shape; src_strides; dst_strides ]
850920
851921 StridedIndexer src_indexer{nd, src_offset, packed_shapes_and_strides};
922+ LeftRolled1DTransformer left_roll_transformer{shift, nelems};
923+
924+ using CompositeIndexerT =
925+ CompositionIndexer<StridedIndexer, LeftRolled1DTransformer>;
926+
927+ CompositeIndexerT rolled_src_indexer (src_indexer,
928+ left_roll_transformer);
929+
852930 UnpackedStridedIndexer dst_indexer{nd, dst_offset,
853931 packed_shapes_and_strides,
854932 packed_shapes_and_strides + 2 * nd};
855933
856- using KernelName = copy_for_roll_strided_kernel<Ty, StridedIndexer ,
934+ using KernelName = copy_for_roll_strided_kernel<Ty, CompositeIndexerT ,
857935 UnpackedStridedIndexer>;
858936
859937 const Ty *src_tp = reinterpret_cast <const Ty *>(src_p);
860938 Ty *dst_tp = reinterpret_cast <Ty *>(dst_p);
861939
862940 cgh.parallel_for <KernelName>(
863941 sycl::range<1 >(nelems),
864- StridedCopyForRollFunctor<Ty, StridedIndexer ,
942+ StridedCopyForRollFunctor<Ty, CompositeIndexerT ,
865943 UnpackedStridedIndexer>(
866- shift, nelems, src_tp, dst_tp, src_indexer , dst_indexer));
944+ src_tp, dst_tp, rolled_src_indexer , dst_indexer));
867945 });
868946
869947 return copy_for_roll_ev;
@@ -880,6 +958,8 @@ typedef sycl::event (*copy_for_roll_contig_fn_ptr_t)(
880958 py::ssize_t , // dst_offset
881959 const std::vector<sycl::event> &);
882960
961+ template <typename Ty> class copy_for_roll_contig_kernel ;
962+
883963/* !
884964 * @brief Function to copy content of array with a shift.
885965 *
@@ -917,6 +997,10 @@ sycl::event copy_for_roll_contig_impl(sycl::queue q,
917997 cgh.depends_on (depends);
918998
919999 NoOpIndexer src_indexer{};
1000+ LeftRolled1DTransformer roller{shift, nelems};
1001+
1002+ CompositionIndexer<NoOpIndexer, LeftRolled1DTransformer>
1003+ left_rolled_src_indexer{src_indexer, roller};
9201004 NoOpIndexer dst_indexer{};
9211005
9221006 using KernelName = copy_for_roll_contig_kernel<Ty>;
@@ -926,8 +1010,10 @@ sycl::event copy_for_roll_contig_impl(sycl::queue q,
9261010
9271011 cgh.parallel_for <KernelName>(
9281012 sycl::range<1 >(nelems),
929- StridedCopyForRollFunctor<Ty, NoOpIndexer, NoOpIndexer>(
930- shift, nelems, src_tp, dst_tp, src_indexer, dst_indexer));
1013+ StridedCopyForRollFunctor<
1014+ Ty, CompositionIndexer<NoOpIndexer, LeftRolled1DTransformer>,
1015+ NoOpIndexer>(src_tp, dst_tp, left_rolled_src_indexer,
1016+ dst_indexer));
9311017 });
9321018
9331019 return copy_for_roll_ev;
@@ -961,6 +1047,86 @@ template <typename fnT, typename Ty> struct CopyForRollContigFactory
9611047 }
9621048};
9631049
1050+ template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
1051+ class copy_for_roll_ndshift_strided_kernel ;
1052+
1053+ // define function type
1054+ typedef sycl::event (*copy_for_roll_ndshift_strided_fn_ptr_t )(
1055+ sycl::queue,
1056+ size_t , // num_elements
1057+ int , // common_nd
1058+ const py::ssize_t *, // packed shape, strides, shifts
1059+ const char *, // src_data_ptr
1060+ py::ssize_t , // src_offset
1061+ char *, // dst_data_ptr
1062+ py::ssize_t , // dst_offset
1063+ const std::vector<sycl::event> &);
1064+
1065+ template <typename Ty>
1066+ sycl::event copy_for_roll_ndshift_strided_impl (
1067+ sycl::queue q,
1068+ size_t nelems,
1069+ int nd,
1070+ const py::ssize_t *packed_shapes_and_strides_and_shifts,
1071+ const char *src_p,
1072+ py::ssize_t src_offset,
1073+ char *dst_p,
1074+ py::ssize_t dst_offset,
1075+ const std::vector<sycl::event> &depends)
1076+ {
1077+ dpctl::tensor::type_utils::validate_type_for_device<Ty>(q);
1078+
1079+ sycl::event copy_for_roll_ev = q.submit ([&](sycl::handler &cgh) {
1080+ cgh.depends_on (depends);
1081+
1082+ // packed_shapes_and_strides_and_shifts:
1083+ // USM array of size 4 * nd
1084+ // [ common_shape; src_strides; dst_strides; shifts ]
1085+
1086+ const py::ssize_t *shape_ptr = packed_shapes_and_strides_and_shifts;
1087+ const py::ssize_t *src_strides_ptr =
1088+ packed_shapes_and_strides_and_shifts + nd;
1089+ const py::ssize_t *dst_strides_ptr =
1090+ packed_shapes_and_strides_and_shifts + 2 * nd;
1091+ const py::ssize_t *shifts_ptr =
1092+ packed_shapes_and_strides_and_shifts + 3 * nd;
1093+
1094+ RolledNDIndexer src_indexer{nd, shape_ptr, src_strides_ptr, shifts_ptr,
1095+ src_offset};
1096+
1097+ UnpackedStridedIndexer dst_indexer{nd, dst_offset, shape_ptr,
1098+ dst_strides_ptr};
1099+
1100+ using KernelName = copy_for_roll_strided_kernel<Ty, RolledNDIndexer,
1101+ UnpackedStridedIndexer>;
1102+
1103+ const Ty *src_tp = reinterpret_cast <const Ty *>(src_p);
1104+ Ty *dst_tp = reinterpret_cast <Ty *>(dst_p);
1105+
1106+ cgh.parallel_for <KernelName>(
1107+ sycl::range<1 >(nelems),
1108+ StridedCopyForRollFunctor<Ty, RolledNDIndexer,
1109+ UnpackedStridedIndexer>(
1110+ src_tp, dst_tp, src_indexer, dst_indexer));
1111+ });
1112+
1113+ return copy_for_roll_ev;
1114+ }
1115+
1116+ /* !
1117+ * @brief Factory to get function pointer of type `fnT` for given array data
1118+ * type `Ty`.
1119+ * @ingroup CopyAndCastKernels
1120+ */
1121+ template <typename fnT, typename Ty> struct CopyForRollNDShiftFactory
1122+ {
1123+ fnT get ()
1124+ {
1125+ fnT f = copy_for_roll_ndshift_strided_impl<Ty>;
1126+ return f;
1127+ }
1128+ };
1129+
9641130} // namespace copy_and_cast
9651131} // namespace kernels
9661132} // namespace tensor
0 commit comments