@@ -56,9 +56,6 @@ class copy_cast_contig_kernel;
 template <typename srcT, typename dstT, typename IndexerT>
 class copy_cast_from_host_kernel;
 
-template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
-class copy_for_reshape_generic_kernel;
-
 template <typename srcTy, typename dstTy> class Caster
 {
 public:
@@ -629,68 +626,56 @@ struct CopyAndCastFromHostFactory
 
 // =============== Copying for reshape ================== //
 
+template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
+class copy_for_reshape_generic_kernel;
+
 template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
 class GenericCopyForReshapeFunctor
 {
 private:
-    py::ssize_t offset = 0;
-    py::ssize_t size = 1;
-    // USM array of size 2*(src_nd + dst_nd)
-    //   [ src_shape; src_strides; dst_shape; dst_strides ]
-    Ty *src_p = nullptr;
+    const Ty *src_p = nullptr;
     Ty *dst_p = nullptr;
     SrcIndexerT src_indexer_;
     DstIndexerT dst_indexer_;
 
 public:
-    GenericCopyForReshapeFunctor(py::ssize_t shift,
-                                 py::ssize_t nelems,
-                                 char *src_ptr,
+    GenericCopyForReshapeFunctor(const char *src_ptr,
                                  char *dst_ptr,
                                  SrcIndexerT src_indexer,
                                  DstIndexerT dst_indexer)
-        : offset(shift), size(nelems), src_p(reinterpret_cast<Ty *>(src_ptr)),
+        : src_p(reinterpret_cast<const Ty *>(src_ptr)),
           dst_p(reinterpret_cast<Ty *>(dst_ptr)), src_indexer_(src_indexer),
           dst_indexer_(dst_indexer)
     {
     }
 
     void operator()(sycl::id<1> wiid) const
     {
-        py::ssize_t this_src_offset = src_indexer_(wiid.get(0));
-        const Ty *in = src_p + this_src_offset;
-
-        py::ssize_t shifted_wiid =
-            (static_cast<py::ssize_t>(wiid.get(0)) + offset) % size;
-        shifted_wiid = (shifted_wiid >= 0) ? shifted_wiid : shifted_wiid + size;
+        const py::ssize_t src_offset = src_indexer_(wiid.get(0));
+        const py::ssize_t dst_offset = dst_indexer_(wiid.get(0));
 
-        py::ssize_t this_dst_offset = dst_indexer_(shifted_wiid);
-
-        Ty *out = dst_p + this_dst_offset;
-        *out = *in;
+        dst_p[dst_offset] = src_p[src_offset];
     }
 };
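+// Editor's illustrative note (not part of the original sources): each
+// indexer maps the flat work-item id to an element offset using its own
+// shape and strides. E.g., reshaping C-contiguous src.shape = (2, 3)
+// (strides (3, 1)) into dst.shape = (3, 2) (strides (2, 1)), work-item
+// i = 4 unravels to src index (1, 1) and dst index (2, 0); both give
+// element offset 4 here, but if src were F-contiguous with strides
+// (1, 2) the same work-item would read offset 1*1 + 1*2 = 3 instead,
+// which is why each side needs its own indexer.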
 
 // define function type
 typedef sycl::event (*copy_for_reshape_fn_ptr_t)(
     sycl::queue,
-    py::ssize_t, // shift
-    size_t,      // num_elements
-    int,
-    int,         // src_nd, dst_nd
+    size_t, // num_elements
+    int,    // src_nd
+    int,    // dst_nd
     py::ssize_t *, // packed shapes and strides
-    char *, // src_data_ptr
+    const char *, // src_data_ptr
     char *, // dst_data_ptr
     const std::vector<sycl::event> &);
 
 /*!
  * @brief Function to copy content of array while reshaping.
  *
- * Submits a kernel to perform a copy `dst[unravel_index((i + shift) % nelems,
+ * Submits a kernel to perform a copy `dst[unravel_index(i,
  * dst.shape)] = src[unravel_index(i, src.shape)]`.
  *
  * @param q The execution queue where kernel is submitted.
- * @param shift The shift in flat indexing.
  * @param nelems The number of elements to copy
  * @param src_nd Array dimension of the source array
  * @param dst_nd Array dimension of the destination array
@@ -708,31 +693,40 @@ typedef sycl::event (*copy_for_reshape_fn_ptr_t)(
 template <typename Ty>
 sycl::event
 copy_for_reshape_generic_impl(sycl::queue q,
-                              py::ssize_t shift,
                               size_t nelems,
                               int src_nd,
                               int dst_nd,
                               py::ssize_t *packed_shapes_and_strides,
-                              char *src_p,
+                              const char *src_p,
                               char *dst_p,
                               const std::vector<sycl::event> &depends)
 {
     dpctl::tensor::type_utils::validate_type_for_device<Ty>(q);
 
     sycl::event copy_for_reshape_ev = q.submit([&](sycl::handler &cgh) {
-        StridedIndexer src_indexer{
-            src_nd, 0,
-            const_cast<const py::ssize_t *>(packed_shapes_and_strides)};
-        StridedIndexer dst_indexer{
-            dst_nd, 0,
-            const_cast<const py::ssize_t *>(packed_shapes_and_strides +
-                                            (2 * src_nd))};
         cgh.depends_on(depends);
-        cgh.parallel_for<copy_for_reshape_generic_kernel<Ty, StridedIndexer,
-                                                         StridedIndexer>>(
+
+        // packed_shapes_and_strides:
+        //   USM array of size 2*(src_nd + dst_nd)
+        //   [ src_shape; src_strides; dst_shape; dst_strides ]
+
+        const py::ssize_t *src_shape_and_strides =
+            const_cast<const py::ssize_t *>(packed_shapes_and_strides);
+
+        const py::ssize_t *dst_shape_and_strides =
+            const_cast<const py::ssize_t *>(packed_shapes_and_strides +
+                                            (2 * src_nd));
+
+        StridedIndexer src_indexer{src_nd, 0, src_shape_and_strides};
+        StridedIndexer dst_indexer{dst_nd, 0, dst_shape_and_strides};
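+
+        // Editor's worked example (illustrative, not from the original
+        // sources): for src_nd = 2 and dst_nd = 1 the packed array holds
+        // 2*(2 + 1) = 6 entries; reshaping shape (2, 3) to (6,) gives
+        //   [ 2, 3,   (src_shape)
+        //     3, 1,   (src_strides)
+        //     6,      (dst_shape)
+        //     1 ]     (dst_strides)
+        // so dst_shape_and_strides starts at offset 2 * src_nd == 4.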
+
+        using KernelName =
+            copy_for_reshape_generic_kernel<Ty, StridedIndexer, StridedIndexer>;
+
+        cgh.parallel_for<KernelName>(
             sycl::range<1>(nelems),
             GenericCopyForReshapeFunctor<Ty, StridedIndexer, StridedIndexer>(
-                shift, nelems, src_p, dst_p, src_indexer, dst_indexer));
+                src_p, dst_p, src_indexer, dst_indexer));
     });
 
     return copy_for_reshape_ev;
@@ -752,6 +746,221 @@ template <typename fnT, typename Ty> struct CopyForReshapeGenericFactory
     }
 };
 
+// =============== Copying for roll ================== //
+
+template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
+class copy_for_roll_strided_kernel;
+
+template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
+class StridedCopyForRollFunctor
+{
+private:
+    size_t offset = 0;
+    size_t size = 1;
+    const Ty *src_p = nullptr;
+    Ty *dst_p = nullptr;
+    SrcIndexerT src_indexer_;
+    DstIndexerT dst_indexer_;
+
+public:
+    StridedCopyForRollFunctor(size_t shift,
+                              size_t nelems,
+                              const Ty *src_ptr,
+                              Ty *dst_ptr,
+                              SrcIndexerT src_indexer,
+                              DstIndexerT dst_indexer)
+        : offset(shift), size(nelems), src_p(src_ptr), dst_p(dst_ptr),
+          src_indexer_(src_indexer), dst_indexer_(dst_indexer)
+    {
+    }
+
+    void operator()(sycl::id<1> wiid) const
+    {
+        const size_t gid = wiid.get(0);
+        const size_t shifted_gid =
+            ((gid < offset) ? gid + size - offset : gid - offset);
+
+        const py::ssize_t src_offset = src_indexer_(shifted_gid);
+        const py::ssize_t dst_offset = dst_indexer_(gid);
+
+        dst_p[dst_offset] = src_p[src_offset];
+    }
+};
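+// Editor's worked example (illustrative, not from the original sources):
+// with size = 5 and offset (shift) = 2, gid -> shifted_gid maps
+//   0 -> 3, 1 -> 4, 2 -> 0, 3 -> 1, 4 -> 2,
+// so a contiguous 1-D dst becomes [s3, s4, s0, s1, s2], matching
+// np.roll(src, 2). The explicit branch also avoids the wrap-around that
+// (gid - offset) % size would produce on unsigned values.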
+
+// define function type
+typedef sycl::event (*copy_for_roll_strided_fn_ptr_t)(
+    sycl::queue,
+    size_t,              // shift
+    size_t,              // num_elements
+    int,                 // common_nd
+    const py::ssize_t *, // packed shapes and strides
+    const char *,        // src_data_ptr
+    py::ssize_t,         // src_offset
+    char *,              // dst_data_ptr
+    py::ssize_t,         // dst_offset
+    const std::vector<sycl::event> &);
+
+template <typename Ty> class copy_for_roll_contig_kernel;
+
+/*!
+ * @brief Function to copy content of array with a shift.
+ *
+ * Submits a kernel to perform a copy `dst[unravel_index((i + shift) % nelems,
+ * dst.shape)] = src[unravel_index(i, src.shape)]`.
+ *
+ * @param q The execution queue where kernel is submitted.
+ * @param shift The shift in flat indexing, must be non-negative.
+ * @param nelems The number of elements to copy
+ * @param nd Array dimensionality of the destination and source arrays
+ * @param packed_shapes_and_strides Kernel accessible USM array of size
+ *     `3*nd` with content `[common_shape, src_strides, dst_strides]`.
+ * @param src_p Typeless USM pointer to the buffer of the source array
+ * @param src_offset Displacement of the first element of src relative to
+ *     src_p in elements
+ * @param dst_p Typeless USM pointer to the buffer of the destination array
+ * @param dst_offset Displacement of the first element of dst relative to
+ *     dst_p in elements
+ * @param depends List of events to wait for before starting computations, if
+ *     any.
+ *
+ * @return Event to wait on to ensure that computation completes.
+ * @ingroup CopyAndCastKernels
+ */
+template <typename Ty>
+sycl::event
+copy_for_roll_strided_impl(sycl::queue q,
+                           size_t shift,
+                           size_t nelems,
+                           int nd,
+                           const py::ssize_t *packed_shapes_and_strides,
+                           const char *src_p,
+                           py::ssize_t src_offset,
+                           char *dst_p,
+                           py::ssize_t dst_offset,
+                           const std::vector<sycl::event> &depends)
+{
+    dpctl::tensor::type_utils::validate_type_for_device<Ty>(q);
+
+    sycl::event copy_for_roll_ev = q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        // packed_shapes_and_strides:
+        //   USM array of size 3 * nd
+        //   [ common_shape; src_strides; dst_strides ]
+
+        StridedIndexer src_indexer{nd, src_offset, packed_shapes_and_strides};
+        UnpackedStridedIndexer dst_indexer{nd, dst_offset,
+                                           packed_shapes_and_strides,
+                                           packed_shapes_and_strides + 2 * nd};
+
+        using KernelName = copy_for_roll_strided_kernel<Ty, StridedIndexer,
+                                                        UnpackedStridedIndexer>;
+
+        const Ty *src_tp = reinterpret_cast<const Ty *>(src_p);
+        Ty *dst_tp = reinterpret_cast<Ty *>(dst_p);
+
+        cgh.parallel_for<KernelName>(
+            sycl::range<1>(nelems),
+            StridedCopyForRollFunctor<Ty, StridedIndexer,
+                                      UnpackedStridedIndexer>(
+                shift, nelems, src_tp, dst_tp, src_indexer, dst_indexer));
+    });
+
+    return copy_for_roll_ev;
+}
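+// Editor's usage sketch (hypothetical values, not from the original
+// sources): rolling a C-contiguous 1-D array of nelems = 5 doubles by
+// shift = 2 uses nd = 1 and packed_shapes_and_strides = { 5, 1, 1 }
+// (common_shape, src_strides, dst_strides):
+//
+//   sycl::event ev = copy_for_roll_strided_impl<double>(
+//       q, 2, 5, 1, packed_shapes_and_strides, src_p, 0, dst_p, 0, {});
+//
+// leaving dst = [s3, s4, s0, s1, s2].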
+
+// define function type
+typedef sycl::event (*copy_for_roll_contig_fn_ptr_t)(
+    sycl::queue,
+    size_t,       // shift
+    size_t,       // num_elements
+    const char *, // src_data_ptr
+    py::ssize_t,  // src_offset
+    char *,       // dst_data_ptr
+    py::ssize_t,  // dst_offset
+    const std::vector<sycl::event> &);
+
+/*!
+ * @brief Function to copy content of array with a shift.
+ *
+ * Submits a kernel to perform a copy `dst[unravel_index((i + shift) % nelems,
+ * dst.shape)] = src[unravel_index(i, src.shape)]`.
+ *
+ * @param q The execution queue where kernel is submitted.
+ * @param shift The shift in flat indexing, must be non-negative.
+ * @param nelems The number of elements to copy
+ * @param src_p Typeless USM pointer to the buffer of the source array
+ * @param src_offset Displacement of the start of array src relative to
+ *     src_p in elements
+ * @param dst_p Typeless USM pointer to the buffer of the destination array
+ * @param dst_offset Displacement of the start of array dst relative to
+ *     dst_p in elements
+ * @param depends List of events to wait for before starting computations, if
+ *     any.
+ *
+ * @return Event to wait on to ensure that computation completes.
+ * @ingroup CopyAndCastKernels
+ */
+template <typename Ty>
+sycl::event copy_for_roll_contig_impl(sycl::queue q,
+                                      size_t shift,
+                                      size_t nelems,
+                                      const char *src_p,
+                                      py::ssize_t src_offset,
+                                      char *dst_p,
+                                      py::ssize_t dst_offset,
+                                      const std::vector<sycl::event> &depends)
+{
+    dpctl::tensor::type_utils::validate_type_for_device<Ty>(q);
+
+    sycl::event copy_for_roll_ev = q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        NoOpIndexer src_indexer{};
+        NoOpIndexer dst_indexer{};
+
+        using KernelName = copy_for_roll_contig_kernel<Ty>;
+
+        const Ty *src_tp = reinterpret_cast<const Ty *>(src_p) + src_offset;
+        Ty *dst_tp = reinterpret_cast<Ty *>(dst_p) + dst_offset;
+
+        cgh.parallel_for<KernelName>(
+            sycl::range<1>(nelems),
+            StridedCopyForRollFunctor<Ty, NoOpIndexer, NoOpIndexer>(
+                shift, nelems, src_tp, dst_tp, src_indexer, dst_indexer));
+    });
+
+    return copy_for_roll_ev;
+}
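+// Editor's note (illustrative, not from the original sources): with
+// NoOpIndexer both indexers return the flat id unchanged, so the kernel
+// reduces to
+//   dst_tp[i] = src_tp[(i < shift) ? i + nelems - shift : i - shift],
+// i.e. a flat dst[(i + shift) % nelems] = src[i] roll; the array offsets
+// are folded into the typed pointers up front instead of into the indexers.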
+
+/*!
+ * @brief Factory to get function pointer of type `fnT` for given array data
+ * type `Ty`.
+ * @ingroup CopyAndCastKernels
+ */
+template <typename fnT, typename Ty> struct CopyForRollStridedFactory
+{
+    fnT get()
+    {
+        fnT f = copy_for_roll_strided_impl<Ty>;
+        return f;
+    }
+};
+
+/*!
+ * @brief Factory to get function pointer of type `fnT` for given array data
+ * type `Ty`.
+ * @ingroup CopyAndCastKernels
+ */
+template <typename fnT, typename Ty> struct CopyForRollContigFactory
+{
+    fnT get()
+    {
+        fnT f = copy_for_roll_contig_impl<Ty>;
+        return f;
+    }
+};
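+// Editor's sketch of how such factories are typically consumed, assuming
+// dpctl's usual DispatchVectorBuilder helper (the names below are an
+// assumption, not shown here):
+//
+//   static copy_for_roll_contig_fn_ptr_t
+//       copy_for_roll_contig_dispatch_vector[td_ns::num_types];
+//
+//   td_ns::DispatchVectorBuilder<copy_for_roll_contig_fn_ptr_t,
+//                                CopyForRollContigFactory, td_ns::num_types>
+//       dvb;
+//   dvb.populate_dispatch_vector(copy_for_roll_contig_dispatch_vector);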
+
 } // namespace copy_and_cast
 } // namespace kernels
 } // namespace tensor