diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index c8aae0a3b9..30cd3fad42 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -146,9 +146,9 @@ struct ReductionOverGroupWithAtomicFunctor void operator()(sycl::nd_item<1> it) const { - const size_t red_gws_ = it.get_global_range(0) / iter_gws_; - const size_t iter_gid = it.get_global_id(0) / red_gws_; - const size_t reduction_batch_id = get_reduction_batch_id(it); + const size_t iter_gid = it.get_group(0) % iter_gws_; + const size_t reduction_batch_id = it.get_group(0) / iter_gws_; + const size_t reduction_lid = it.get_local_id(0); const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg @@ -204,14 +204,6 @@ struct ReductionOverGroupWithAtomicFunctor } } } - -private: - size_t get_reduction_batch_id(sycl::nd_item<1> const &it) const - { - const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_; - const size_t reduction_batch_id = it.get_group(0) % n_reduction_groups; - return reduction_batch_id; - } }; typedef sycl::event (*sum_reduction_strided_impl_fn_ptr)( @@ -241,6 +233,12 @@ class sum_reduction_seq_strided_krn; template class sum_reduction_seq_contig_krn; +template +class sum_reduction_axis0_over_group_with_atomics_contig_krn; + +template +class sum_reduction_axis1_over_group_with_atomics_contig_krn; + using dpctl::tensor::sycl_utils::choose_workgroup_size; template @@ -344,20 +342,6 @@ sycl::event sum_reduction_over_group_with_atomics_strided_impl( (reduction_nelems + reductions_per_wi * wg - 1) / (reductions_per_wi * wg); - if (reduction_groups > 1) { - const size_t &max_wg = - d.get_info(); - - if (reduction_nelems < preferrered_reductions_per_wi * max_wg) { - wg = max_wg; - reductions_per_wi = - std::max(1, (reduction_nelems + wg - 1) / wg); - reduction_groups = - (reduction_nelems + reductions_per_wi * wg - 1) / - (reductions_per_wi * wg); - } - } - auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; auto localRange = sycl::range<1>{wg}; @@ -395,7 +379,7 @@ typedef sycl::event (*sum_reduction_contig_impl_fn_ptr)( /* @brief Reduce rows in a matrix */ template -sycl::event sum_reduction_over_group_with_atomics_contig_impl( +sycl::event sum_reduction_axis1_over_group_with_atomics_contig_impl( sycl::queue exec_q, size_t iter_nelems, // number of reductions (num. of rows in a matrix // when reducing over rows) @@ -417,7 +401,7 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl( const sycl::device &d = exec_q.get_device(); const auto &sg_sizes = d.get_info(); - size_t wg = choose_workgroup_size<2>(reduction_nelems, sg_sizes); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); if (reduction_nelems < wg) { sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { @@ -463,11 +447,11 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl( RowsIndexerT, NoOpIndexerT>; using ReductionIndexerT = NoOpIndexerT; - RowsIndexerT columns_indexer{ + RowsIndexerT rows_indexer{ 0, static_cast(iter_nelems), static_cast(reduction_nelems)}; NoOpIndexerT result_indexer{}; - InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, + InputOutputIterIndexerT in_out_iter_indexer{rows_indexer, result_indexer}; ReductionIndexerT reduction_indexer{}; @@ -481,27 +465,95 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl( (reduction_nelems + reductions_per_wi * wg - 1) / (reductions_per_wi * wg); - if (reduction_groups > 1) { - const size_t &max_wg = - d.get_info(); - - if (reduction_nelems < preferrered_reductions_per_wi * max_wg) { - wg = max_wg; - reductions_per_wi = - std::max(1, (reduction_nelems + wg - 1) / wg); - reduction_groups = - (reduction_nelems + reductions_per_wi * wg - 1) / - (reductions_per_wi * wg); - } - } + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = + class sum_reduction_axis1_over_group_with_atomics_contig_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupWithAtomicFunctor( + arg_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems, + iter_nelems, reductions_per_wi)); + }); + + return comp_ev; + } +} + +/* @brief Reduce rows in a matrix */ +template +sycl::event sum_reduction_axis0_over_group_with_atomics_contig_impl( + sycl::queue exec_q, + size_t iter_nelems, // number of reductions (num. of cols in a matrix + // when reducing over cols) + size_t reduction_nelems, // size of each reduction (length of cols, i.e. + // number of rows) + const char *arg_cp, + char *res_cp, + py::ssize_t iter_arg_offset, + py::ssize_t iter_res_offset, + py::ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; + + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = resTy{0}; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + { + sycl::event res_init_ev = exec_q.fill( + res_tp, resTy(identity_val), iter_nelems, depends); + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = ColsIndexerT; + + NoOpIndexerT columns_indexer{}; + NoOpIndexerT result_indexer{}; + InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, + result_indexer}; + ReductionIndexerT reduction_indexer{ + 0, /* size */ static_cast(reduction_nelems), + /* step */ static_cast(iter_nelems)}; + + constexpr size_t preferrered_reductions_per_wi = 8; + size_t reductions_per_wi = + (reduction_nelems < preferrered_reductions_per_wi * wg) + ? std::max(1, (reduction_nelems + wg - 1) / wg) + : preferrered_reductions_per_wi; + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; auto localRange = sycl::range<1>{wg}; - using KernelName = class sum_reduction_over_group_with_atomics_krn< - argTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT>; + using KernelName = + class sum_reduction_axis0_over_group_with_atomics_contig_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; cgh.parallel_for( sycl::nd_range<1>(globalRange, localRange), @@ -558,14 +610,13 @@ struct ReductionOverGroupNoAtomicFunctor void operator()(sycl::nd_item<1> it) const { - - const size_t red_gws_ = it.get_global_range(0) / iter_gws_; - const size_t iter_gid = it.get_global_id(0) / red_gws_; - const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_; - const size_t reduction_batch_id = it.get_group(0) % n_reduction_groups; const size_t reduction_lid = it.get_local_id(0); const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg + const size_t iter_gid = it.get_group(0) % iter_gws_; + const size_t reduction_batch_id = it.get_group(0) / iter_gws_; + const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_; + // work-items sums over input with indices // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg // + reduction_lid @@ -1079,7 +1130,25 @@ struct SumOverAxisTempsStridedFactory }; template -struct SumOverAxisAtomicContigFactory +struct SumOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionAtomic< + srcTy, dstTy>::is_defined) + { + return dpctl::tensor::kernels:: + sum_reduction_axis1_over_group_with_atomics_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxis0AtomicContigFactory { fnT get() const { @@ -1087,7 +1156,8 @@ struct SumOverAxisAtomicContigFactory srcTy, dstTy>::is_defined) { return dpctl::tensor::kernels:: - sum_reduction_over_group_with_atomics_contig_impl; + sum_reduction_axis0_over_group_with_atomics_contig_impl; } else { return nullptr; diff --git a/dpctl/tensor/libtensor/source/sum_reductions.cpp b/dpctl/tensor/libtensor/source/sum_reductions.cpp index 3502a81a0e..7628813c6d 100644 --- a/dpctl/tensor/libtensor/source/sum_reductions.cpp +++ b/dpctl/tensor/libtensor/source/sum_reductions.cpp @@ -88,8 +88,11 @@ static sum_reduction_strided_impl_fn_ptr using dpctl::tensor::kernels::sum_reduction_contig_impl_fn_ptr; static sum_reduction_contig_impl_fn_ptr - sum_over_axis_contig_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; + sum_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static sum_reduction_contig_impl_fn_ptr + sum_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; std::pair py_sum_over_axis( dpctl::tensor::usm_ndarray src, @@ -194,8 +197,30 @@ std::pair py_sum_over_axis( if ((is_src_c_contig && is_dst_c_contig) || (is_src_f_contig && dst_nelems == 1)) { - auto fn = sum_over_axis_contig_atomic_dispatch_table[src_typeid] - [dst_typeid]; + auto fn = sum_over_axis1_contig_atomic_dispatch_table[src_typeid] + [dst_typeid]; + if (fn != nullptr) { + size_t iter_nelems = dst_nelems; + + constexpr py::ssize_t zero_offset = 0; + + sycl::event sum_over_axis_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), + zero_offset, // iteration_src_offset + zero_offset, // iteration_dst_offset + zero_offset, // reduction_src_offset + depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {sum_over_axis_contig_ev}); + + return std::make_pair(keep_args_event, sum_over_axis_contig_ev); + } + } + else if (is_src_f_contig & is_dst_c_contig) { + auto fn = sum_over_axis0_contig_atomic_dispatch_table[src_typeid] + [dst_typeid]; if (fn != nullptr) { size_t iter_nelems = dst_nelems; @@ -271,27 +296,58 @@ std::pair py_sum_over_axis( iteration_src_offset, iteration_dst_offset); } - if (supports_atomics && (reduction_nd == 1) && - (simplified_reduction_src_strides[0] == 1) && (iteration_nd == 1) && - ((simplified_iteration_shape[0] == 1) || - ((simplified_iteration_dst_strides[0] == 1) && - (static_cast(simplified_iteration_src_strides[0]) == - reduction_nelems)))) - { - auto fn = - sum_over_axis_contig_atomic_dispatch_table[src_typeid][dst_typeid]; - if (fn != nullptr) { - size_t iter_nelems = dst_nelems; + if (supports_atomics && (reduction_nd == 1) && (iteration_nd == 1)) { + bool mat_reduce_over_axis1 = false; + bool mat_reduce_over_axis0 = false; + bool array_reduce_all_elems = false; + size_t iter_nelems = dst_nelems; + + if (simplified_reduction_src_strides[0] == 1) { + array_reduce_all_elems = (simplified_iteration_shape[0] == 1); + mat_reduce_over_axis1 = + (simplified_iteration_dst_strides[0] == 1) && + (static_cast(simplified_iteration_src_strides[0]) == + reduction_nelems); + } + else if (static_cast(simplified_reduction_src_strides[0]) == + iter_nelems) + { + mat_reduce_over_axis0 = + (simplified_iteration_dst_strides[0] == 1) && + (simplified_iteration_src_strides[0] == 1); + } + + if (mat_reduce_over_axis1 || array_reduce_all_elems) { + auto fn = sum_over_axis1_contig_atomic_dispatch_table[src_typeid] + [dst_typeid]; + if (fn != nullptr) { + sycl::event sum_over_axis1_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_src_offset, + iteration_dst_offset, reduction_src_offset, depends); - sycl::event sum_over_axis_contig_ev = - fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), - dst.get_data(), iteration_src_offset, iteration_dst_offset, - reduction_src_offset, depends); + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {sum_over_axis1_contig_ev}); + + return std::make_pair(keep_args_event, + sum_over_axis1_contig_ev); + } + } + else if (mat_reduce_over_axis0) { + auto fn = sum_over_axis0_contig_atomic_dispatch_table[src_typeid] + [dst_typeid]; + if (fn != nullptr) { + sycl::event sum_over_axis0_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_src_offset, + iteration_dst_offset, reduction_src_offset, depends); - sycl::event keep_args_event = dpctl::utils::keep_args_alive( - exec_q, {src, dst}, {sum_over_axis_contig_ev}); + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {sum_over_axis0_contig_ev}); - return std::make_pair(keep_args_event, sum_over_axis_contig_ev); + return std::make_pair(keep_args_event, + sum_over_axis0_contig_ev); + } } } @@ -451,11 +507,17 @@ void populate_sum_over_axis_dispatch_table(void) dtb2; dtb2.populate_dispatch_table(sum_over_axis_strided_temps_dispatch_table); - using dpctl::tensor::kernels::SumOverAxisAtomicContigFactory; + using dpctl::tensor::kernels::SumOverAxis1AtomicContigFactory; DispatchTableBuilder + SumOverAxis1AtomicContigFactory, num_types> dtb3; - dtb3.populate_dispatch_table(sum_over_axis_contig_atomic_dispatch_table); + dtb3.populate_dispatch_table(sum_over_axis1_contig_atomic_dispatch_table); + + using dpctl::tensor::kernels::SumOverAxis0AtomicContigFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(sum_over_axis0_contig_atomic_dispatch_table); } namespace py = pybind11;