diff --git a/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp b/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp deleted file mode 100644 index 583776949e..0000000000 --- a/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp +++ /dev/null @@ -1,732 +0,0 @@ -//=== boolean_reductions.hpp - Implementation of boolean reduction kernels // -// ---*-C++-*--/===// -// -// Data Parallel Control (dpctl) -// -// Copyright 2020-2024 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file defines kernels for dpctl.tensor.any and dpctl.tensor.all -//===----------------------------------------------------------------------===// - -#pragma once -#include - -#include -#include -#include -#include - -#include "dpctl_tensor_types.hpp" -#include "utils/offset_utils.hpp" -#include "utils/sycl_utils.hpp" -#include "utils/type_dispatch_building.hpp" -#include "utils/type_utils.hpp" - -namespace dpctl -{ -namespace tensor -{ -namespace kernels -{ - -template struct boolean_predicate -{ - bool operator()(const T &v) const - { - using dpctl::tensor::type_utils::convert_impl; - return convert_impl(v); - } -}; - -template -struct all_reduce_wg_contig -{ - void operator()(sycl::nd_item<1> &ndit, - outT *out, - const size_t &out_idx, - const inpT *start, - const inpT *end) const - { - PredicateT pred{}; - auto wg = ndit.get_group(); - outT red_val_over_wg = - static_cast(sycl::joint_all_of(wg, start, end, pred)); - - if (wg.leader()) { - sycl::atomic_ref - res_ref(out[out_idx]); - res_ref.fetch_and(red_val_over_wg); - } - } -}; - -template -struct any_reduce_wg_contig -{ - void operator()(sycl::nd_item<1> &ndit, - outT *out, - const size_t &out_idx, - const inpT *start, - const inpT *end) const - { - PredicateT pred{}; - auto wg = ndit.get_group(); - outT red_val_over_wg = - static_cast(sycl::joint_any_of(wg, start, end, pred)); - - if (wg.leader()) { - sycl::atomic_ref - res_ref(out[out_idx]); - res_ref.fetch_or(red_val_over_wg); - } - } -}; - -template struct all_reduce_wg_strided -{ - void operator()(sycl::nd_item<1> &ndit, - T *out, - const size_t &out_idx, - const T &local_val) const - { - auto wg = ndit.get_group(); - T red_val_over_wg = static_cast(sycl::all_of_group(wg, local_val)); - - if (wg.leader()) { - sycl::atomic_ref - res_ref(out[out_idx]); - res_ref.fetch_and(red_val_over_wg); - } - } -}; - -template struct any_reduce_wg_strided -{ - void operator()(sycl::nd_item<1> &ndit, - T *out, - const size_t &out_idx, - const T &local_val) const - { - auto wg = ndit.get_group(); - T red_val_over_wg = static_cast(sycl::any_of_group(wg, local_val)); - - if (wg.leader()) { - sycl::atomic_ref - res_ref(out[out_idx]); - res_ref.fetch_or(red_val_over_wg); - } - } -}; - -template -struct SequentialBooleanReduction -{ -private: - const argT *inp_ = nullptr; - outT *out_ = nullptr; - ReductionOp reduction_op_; - outT identity_; - InputOutputIterIndexerT inp_out_iter_indexer_; - InputRedIndexerT inp_reduced_dims_indexer_; - size_t reduction_max_gid_ = 0; - -public: - SequentialBooleanReduction(const argT *inp, - outT *res, - ReductionOp reduction_op, - const outT &identity_val, - InputOutputIterIndexerT arg_res_iter_indexer, - InputRedIndexerT arg_reduced_dims_indexer, - size_t reduction_size) - : inp_(inp), out_(res), reduction_op_(reduction_op), - identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), - inp_reduced_dims_indexer_(arg_reduced_dims_indexer), - reduction_max_gid_(reduction_size) - { - } - - void operator()(sycl::id<1> id) const - { - - auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]); - const ssize_t &inp_iter_offset = - inp_out_iter_offsets_.get_first_offset(); - const ssize_t &out_iter_offset = - inp_out_iter_offsets_.get_second_offset(); - - outT red_val(identity_); - for (size_t m = 0; m < reduction_max_gid_; ++m) { - ssize_t inp_reduction_offset = - static_cast(inp_reduced_dims_indexer_(m)); - ssize_t inp_offset = inp_iter_offset + inp_reduction_offset; - - // must convert to boolean first to handle nans - using dpctl::tensor::type_utils::convert_impl; - outT val = convert_impl(inp_[inp_offset]); - ReductionOp op = reduction_op_; - - red_val = op(red_val, val); - } - - out_[out_iter_offset] = red_val; - } -}; - -template -struct ContigBooleanReduction -{ -private: - const argT *inp_ = nullptr; - outT *out_ = nullptr; - GroupOp group_op_; - size_t reduction_max_gid_ = 0; - size_t iter_gws_ = 1; - size_t reductions_per_wi = 16; - -public: - ContigBooleanReduction(const argT *inp, - outT *res, - GroupOp group_op, - size_t reduction_size, - size_t iteration_size, - size_t reduction_size_per_wi) - : inp_(inp), out_(res), group_op_(group_op), - reduction_max_gid_(reduction_size), iter_gws_(iteration_size), - reductions_per_wi(reduction_size_per_wi) - { - } - - void operator()(sycl::nd_item<1> it) const - { - const size_t reduction_id = it.get_group(0) % iter_gws_; - const size_t reduction_batch_id = it.get_group(0) / iter_gws_; - const size_t wg_size = it.get_local_range(0); - - const size_t base = reduction_id * reduction_max_gid_; - const size_t start = - base + reduction_batch_id * wg_size * reductions_per_wi; - const size_t end = std::min((start + (reductions_per_wi * wg_size)), - base + reduction_max_gid_); - // reduction and atomic operations are performed - // in group_op_ - group_op_(it, out_, reduction_id, inp_ + start, inp_ + end); - } -}; - -typedef sycl::event (*boolean_reduction_contig_impl_fn_ptr)( - sycl::queue &, - size_t, - size_t, - const char *, - char *, - ssize_t, - ssize_t, - ssize_t, - const std::vector &); - -template -class boolean_reduction_contig_krn; - -template -class boolean_reduction_seq_contig_krn; - -using dpctl::tensor::sycl_utils::choose_workgroup_size; - -template -sycl::event -boolean_reduction_axis1_contig_impl(sycl::queue &exec_q, - size_t iter_nelems, - size_t reduction_nelems, - const char *arg_cp, - char *res_cp, - ssize_t iter_arg_offset, - ssize_t iter_res_offset, - ssize_t red_arg_offset, - const std::vector &depends) -{ - const argTy *arg_tp = reinterpret_cast(arg_cp) + - iter_arg_offset + red_arg_offset; - resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; - - constexpr resTy identity_val = sycl::known_identity::value; - - const sycl::device &d = exec_q.get_device(); - const auto &sg_sizes = d.get_info(); - size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - sycl::event red_ev; - if (reduction_nelems < wg) { - red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using InputIterIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - InputIterIndexerT, NoOpIndexerT>; - using ReductionIndexerT = NoOpIndexerT; - - InputOutputIterIndexerT in_out_iter_indexer{ - InputIterIndexerT{0, static_cast(iter_nelems), - static_cast(reduction_nelems)}, - NoOpIndexerT{}}; - ReductionIndexerT reduction_indexer{}; - - cgh.parallel_for>( - sycl::range<1>(iter_nelems), - SequentialBooleanReduction( - arg_tp, res_tp, RedOpT(), identity_val, in_out_iter_indexer, - reduction_indexer, reduction_nelems)); - }); - } - else { - sycl::event init_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - cgh.fill(res_tp, resTy(identity_val), iter_nelems); - }); - red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(init_ev); - - constexpr std::uint8_t dim = 1; - - constexpr size_t preferred_reductions_per_wi = 4; - size_t reductions_per_wi = - (reduction_nelems < preferred_reductions_per_wi * wg) - ? ((reduction_nelems + wg - 1) / wg) - : preferred_reductions_per_wi; - - size_t reduction_groups = - (reduction_nelems + reductions_per_wi * wg - 1) / - (reductions_per_wi * wg); - - auto gws = sycl::range{iter_nelems * reduction_groups * wg}; - auto lws = sycl::range{wg}; - - cgh.parallel_for< - class boolean_reduction_contig_krn>( - sycl::nd_range(gws, lws), - ContigBooleanReduction( - arg_tp, res_tp, GroupOpT(), reduction_nelems, iter_nelems, - reductions_per_wi)); - }); - } - return red_ev; -} - -template struct AllAxis1ContigFactory -{ - fnT get() const - { - using resTy = std::int32_t; - using RedOpT = sycl::logical_and; - using GroupOpT = - all_reduce_wg_contig>; - - return dpctl::tensor::kernels::boolean_reduction_axis1_contig_impl< - srcTy, resTy, RedOpT, GroupOpT>; - } -}; - -template struct AnyAxis1ContigFactory -{ - fnT get() const - { - using resTy = std::int32_t; - using RedOpT = sycl::logical_or; - using GroupOpT = - any_reduce_wg_contig>; - - return dpctl::tensor::kernels::boolean_reduction_axis1_contig_impl< - srcTy, resTy, RedOpT, GroupOpT>; - } -}; - -template -struct StridedBooleanReduction -{ -private: - const argT *inp_ = nullptr; - outT *out_ = nullptr; - ReductionOp reduction_op_; - GroupOp group_op_; - outT identity_; - InputOutputIterIndexerT inp_out_iter_indexer_; - InputRedIndexerT inp_reduced_dims_indexer_; - size_t reduction_max_gid_ = 0; - size_t iter_gws_ = 1; - size_t reductions_per_wi = 16; - -public: - StridedBooleanReduction(const argT *inp, - outT *res, - ReductionOp reduction_op, - GroupOp group_op, - const outT &identity_val, - InputOutputIterIndexerT arg_res_iter_indexer, - InputRedIndexerT arg_reduced_dims_indexer, - size_t reduction_size, - size_t iteration_size, - size_t reduction_size_per_wi) - : inp_(inp), out_(res), reduction_op_(reduction_op), - group_op_(group_op), identity_(identity_val), - inp_out_iter_indexer_(arg_res_iter_indexer), - inp_reduced_dims_indexer_(arg_reduced_dims_indexer), - reduction_max_gid_(reduction_size), iter_gws_(iteration_size), - reductions_per_wi(reduction_size_per_wi) - { - } - - void operator()(sycl::nd_item<1> it) const - { - const size_t reduction_id = it.get_group(0) % iter_gws_; - const size_t reduction_batch_id = it.get_group(0) / iter_gws_; - - const size_t reduction_lid = it.get_local_id(0); - const size_t wg_size = it.get_local_range(0); - - auto inp_out_iter_offsets_ = inp_out_iter_indexer_(reduction_id); - const ssize_t &inp_iter_offset = - inp_out_iter_offsets_.get_first_offset(); - const ssize_t &out_iter_offset = - inp_out_iter_offsets_.get_second_offset(); - - outT local_red_val(identity_); - size_t arg_reduce_gid0 = - reduction_lid + reduction_batch_id * wg_size * reductions_per_wi; - size_t arg_reduce_gid_max = std::min( - reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg_size); - for (size_t arg_reduce_gid = arg_reduce_gid0; - arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg_size) - { - ssize_t inp_reduction_offset = - static_cast(inp_reduced_dims_indexer_(arg_reduce_gid)); - ssize_t inp_offset = inp_iter_offset + inp_reduction_offset; - - // must convert to boolean first to handle nans - using dpctl::tensor::type_utils::convert_impl; - bool val = convert_impl(inp_[inp_offset]); - ReductionOp op = reduction_op_; - - local_red_val = op(local_red_val, static_cast(val)); - } - // reduction and atomic operations are performed - // in group_op_ - group_op_(it, out_, out_iter_offset, local_red_val); - } -}; - -template -class boolean_reduction_axis0_contig_krn; - -template -sycl::event -boolean_reduction_axis0_contig_impl(sycl::queue &exec_q, - size_t iter_nelems, - size_t reduction_nelems, - const char *arg_cp, - char *res_cp, - ssize_t iter_arg_offset, - ssize_t iter_res_offset, - ssize_t red_arg_offset, - const std::vector &depends) -{ - const argTy *arg_tp = reinterpret_cast(arg_cp) + - iter_arg_offset + red_arg_offset; - resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; - - constexpr resTy identity_val = sycl::known_identity::value; - - const sycl::device &d = exec_q.get_device(); - const auto &sg_sizes = d.get_info(); - size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - { - sycl::event init_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - cgh.fill(res_tp, resTy(identity_val), iter_nelems); - }); - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(init_ev); - - constexpr std::uint8_t dim = 1; - - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - NoOpIndexerT, NoOpIndexerT>; - using ReductionIndexerT = ColsIndexerT; - - NoOpIndexerT columns_indexer{}; - NoOpIndexerT result_indexer{}; - InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, - result_indexer}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - constexpr size_t preferred_reductions_per_wi = 4; - size_t reductions_per_wi = - (reduction_nelems < preferred_reductions_per_wi * wg) - ? ((reduction_nelems + wg - 1) / wg) - : preferred_reductions_per_wi; - - size_t reduction_groups = - (reduction_nelems + reductions_per_wi * wg - 1) / - (reductions_per_wi * wg); - - auto gws = sycl::range{iter_nelems * reduction_groups * wg}; - auto lws = sycl::range{wg}; - - cgh.parallel_for>( - sycl::nd_range(gws, lws), - StridedBooleanReduction( - arg_tp, res_tp, RedOpT(), GroupOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, reduction_nelems, - iter_nelems, reductions_per_wi)); - }); - return red_ev; - } -} - -template struct AllAxis0ContigFactory -{ - fnT get() const - { - using resTy = std::int32_t; - using RedOpT = sycl::logical_and; - using GroupOpT = all_reduce_wg_strided; - - return dpctl::tensor::kernels::boolean_reduction_axis0_contig_impl< - srcTy, resTy, RedOpT, GroupOpT>; - } -}; - -template struct AnyAxis0ContigFactory -{ - fnT get() const - { - using resTy = std::int32_t; - using RedOpT = sycl::logical_or; - using GroupOpT = any_reduce_wg_strided; - - return dpctl::tensor::kernels::boolean_reduction_axis0_contig_impl< - srcTy, resTy, RedOpT, GroupOpT>; - } -}; - -template -class boolean_reduction_strided_krn; - -template -class boolean_reduction_seq_strided_krn; - -typedef sycl::event (*boolean_reduction_strided_impl_fn_ptr)( - sycl::queue &, - size_t, - size_t, - const char *, - char *, - int, - const ssize_t *, - ssize_t, - ssize_t, - int, - const ssize_t *, - ssize_t, - const std::vector &); - -template -sycl::event -boolean_reduction_strided_impl(sycl::queue &exec_q, - size_t iter_nelems, - size_t reduction_nelems, - const char *arg_cp, - char *res_cp, - int iter_nd, - const ssize_t *iter_shape_and_strides, - ssize_t iter_arg_offset, - ssize_t iter_res_offset, - int red_nd, - const ssize_t *reduction_shape_stride, - ssize_t reduction_arg_offset, - const std::vector &depends) -{ - const argTy *arg_tp = reinterpret_cast(arg_cp); - resTy *res_tp = reinterpret_cast(res_cp); - - constexpr resTy identity_val = sycl::known_identity::value; - - const sycl::device &d = exec_q.get_device(); - const auto &sg_sizes = d.get_info(); - size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - sycl::event red_ev; - if (reduction_nelems < wg) { - red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - - InputOutputIterIndexerT in_out_iter_indexer{ - iter_nd, iter_arg_offset, iter_res_offset, - iter_shape_and_strides}; - ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, - reduction_shape_stride}; - - cgh.parallel_for>( - sycl::range<1>(iter_nelems), - SequentialBooleanReduction( - arg_tp, res_tp, RedOpT(), identity_val, in_out_iter_indexer, - reduction_indexer, reduction_nelems)); - }); - } - else { - sycl::event init_ev = exec_q.submit([&](sycl::handler &cgh) { - using IndexerT = - dpctl::tensor::offset_utils::UnpackedStridedIndexer; - - const ssize_t *const &res_shape = iter_shape_and_strides; - const ssize_t *const &res_strides = - iter_shape_and_strides + 2 * iter_nd; - IndexerT res_indexer(iter_nd, iter_res_offset, res_shape, - res_strides); - - cgh.depends_on(depends); - - cgh.parallel_for(sycl::range<1>(iter_nelems), [=](sycl::id<1> id) { - auto res_offset = res_indexer(id[0]); - res_tp[res_offset] = identity_val; - }); - }); - red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(init_ev); - - constexpr std::uint8_t dim = 1; - - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - - InputOutputIterIndexerT in_out_iter_indexer{ - iter_nd, iter_arg_offset, iter_res_offset, - iter_shape_and_strides}; - ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, - reduction_shape_stride}; - - constexpr size_t preferred_reductions_per_wi = 4; - size_t reductions_per_wi = - (reduction_nelems < preferred_reductions_per_wi * wg) - ? ((reduction_nelems + wg - 1) / wg) - : preferred_reductions_per_wi; - - size_t reduction_groups = - (reduction_nelems + reductions_per_wi * wg - 1) / - (reductions_per_wi * wg); - - auto gws = sycl::range{iter_nelems * reduction_groups * wg}; - auto lws = sycl::range{wg}; - - cgh.parallel_for>( - sycl::nd_range(gws, lws), - StridedBooleanReduction( - arg_tp, res_tp, RedOpT(), GroupOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, reduction_nelems, - iter_nelems, reductions_per_wi)); - }); - } - return red_ev; -} - -template struct AllStridedFactory -{ - fnT get() const - { - using resTy = std::int32_t; - using RedOpT = sycl::logical_and; - using GroupOpT = all_reduce_wg_strided; - - return dpctl::tensor::kernels::boolean_reduction_strided_impl< - srcTy, resTy, RedOpT, GroupOpT>; - } -}; - -template struct AnyStridedFactory -{ - fnT get() const - { - using resTy = std::int32_t; - using RedOpT = sycl::logical_or; - using GroupOpT = any_reduce_wg_strided; - - return dpctl::tensor::kernels::boolean_reduction_strided_impl< - srcTy, resTy, RedOpT, GroupOpT>; - } -}; - -} // namespace kernels -} // namespace tensor -} // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp index 6adbe07936..7ccfb3ec72 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp @@ -267,9 +267,42 @@ struct DotProductCustomFunctor } }; +template < + typename lhsTy, + typename rhsTy, + typename resTy, + typename BatchIndexerT, + typename RedIndexerT, + template + class kernel_name_token> +sycl::event sequential_dot_product(sycl::queue &exec_q, + const lhsTy *lhs, + const rhsTy *rhs, + resTy *res, + size_t batches, + size_t reduction_nelems, + const BatchIndexerT &batch_indexer, + const RedIndexerT &reduction_indexer, + const std::vector &depends) +{ + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.parallel_for< + kernel_name_token>( + sycl::range<1>(batches), + SequentialDotProduct(lhs, rhs, res, batch_indexer, + reduction_indexer, + reduction_nelems)); + }); + + return dot_ev; +} + template {wg}; auto ndRange = sycl::nd_range<1>(globalRange, localRange); - if constexpr (can_use_reduce_over_group::value) { + if constexpr (can_use_reduce_over_group::value) { using KernelName = - class kernel_name_token; cgh.parallel_for( - ndRange, DotProductFunctor( lhs, rhs, res, ReductionOpT(), batch_indexer, reduction_indexer, reduction_nelems, batches, reductions_per_wi)); } else { - using SlmT = sycl::local_accessor; + using SlmT = sycl::local_accessor; SlmT local_memory = SlmT(localRange, cgh); using KernelName = class custom_reduction_wrapper>; + lhsTy, rhsTy, resTy, ReductionOpT, BatchIndexerT, RedIndexerT>>; cgh.parallel_for( ndRange, - DotProductCustomFunctor( lhs, rhs, res, ReductionOpT(), batch_indexer, reduction_indexer, local_memory, reduction_nelems, batches, @@ -389,31 +422,24 @@ sycl::event dot_product_impl(sycl::queue &exec_q, size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); if (reduction_nelems < wg) { - sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using InputOutputBatchIndexerT = - dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; - const InputOutputBatchIndexerT in_out_batch_indexer{ - batch_nd, batch_lhs_offset, batch_rhs_offset, batch_res_offset, - batch_shape_and_strides}; - const ReductionIndexerT reduction_indexer{ - red_nd, reduction_lhs_offset, reduction_rhs_offset, - reduction_shape_stride}; + const InputOutputBatchIndexerT inp_out_batch_indexer{ + batch_nd, batch_lhs_offset, batch_rhs_offset, batch_res_offset, + batch_shape_and_strides}; + const ReductionIndexerT reduction_indexer{red_nd, reduction_lhs_offset, + reduction_rhs_offset, + reduction_shape_stride}; - cgh.parallel_for>( - sycl::range<1>(batches), - SequentialDotProduct( - lhs_tp, rhs_tp, res_tp, in_out_batch_indexer, - reduction_indexer, reduction_nelems)); - }); + sycl::event dot_ev = + sequential_dot_product( + exec_q, lhs_tp, rhs_tp, res_tp, batches, reduction_nelems, + inp_out_batch_indexer, reduction_indexer, depends); return dot_ev; } @@ -515,37 +541,30 @@ dot_product_contig_impl(sycl::queue &exec_q, size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); if (reduction_nelems < wg) { - sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using InputBatchIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputBatchIndexerT = - dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< - InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; - using ReductionIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - NoOpIndexerT, NoOpIndexerT>; + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; - const InputBatchIndexerT inp_batch_indexer{ - 0, static_cast(reduction_nelems), - static_cast(batches)}; - const InputOutputBatchIndexerT inp_out_batch_indexer{ - inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; - constexpr ReductionIndexerT reduction_indexer{NoOpIndexerT{}, - NoOpIndexerT{}}; + const InputBatchIndexerT inp_batch_indexer{ + 0, static_cast(reduction_nelems), + static_cast(batches)}; + const InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + constexpr ReductionIndexerT reduction_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; - cgh.parallel_for>( - sycl::range<1>(batches), - SequentialDotProduct( - lhs_tp, rhs_tp, res_tp, inp_out_batch_indexer, - reduction_indexer, reduction_nelems)); - }); + sycl::event dot_ev = + sequential_dot_product( + exec_q, lhs_tp, rhs_tp, res_tp, batches, reduction_nelems, + inp_out_batch_indexer, reduction_indexer, depends); return dot_ev; } @@ -795,7 +814,7 @@ struct DotProductNoAtomicCustomFunctor template {wg}; auto ndRange = sycl::nd_range<1>(globalRange, localRange); - if constexpr (can_use_reduce_over_group::value) { + if constexpr (can_use_reduce_over_group::value) { using KernelName = - class kernel_name_token; cgh.parallel_for( ndRange, - DotProductNoAtomicFunctor( lhs, rhs, res, ReductionOpT(), batch_indexer, reduction_indexer, reduction_nelems, batches, reductions_per_wi)); } else { - using SlmT = sycl::local_accessor; + using SlmT = sycl::local_accessor; SlmT local_memory = SlmT(localRange, cgh); using KernelName = class custom_reduction_wrapper>; + lhsTy, rhsTy, resTy, ReductionOpT, BatchIndexerT, RedIndexerT>>; cgh.parallel_for( ndRange, - DotProductNoAtomicCustomFunctor( lhs, rhs, res, ReductionOpT(), batch_indexer, @@ -898,40 +917,31 @@ sycl::event dot_product_tree_impl(sycl::queue &exec_q, size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); if (reduction_nelems < wg) { - sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using InputOutputBatchIndexerT = - dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; - const InputOutputBatchIndexerT in_out_batch_indexer{ - batch_nd, batch_lhs_offset, batch_rhs_offset, batch_res_offset, - batch_shape_and_strides}; - const ReductionIndexerT reduction_indexer{ - red_nd, reduction_lhs_offset, reduction_rhs_offset, - reduction_shape_stride}; + const InputOutputBatchIndexerT inp_out_batch_indexer{ + batch_nd, batch_lhs_offset, batch_rhs_offset, batch_res_offset, + batch_shape_and_strides}; + const ReductionIndexerT reduction_indexer{red_nd, reduction_lhs_offset, + reduction_rhs_offset, + reduction_shape_stride}; - cgh.parallel_for>( - sycl::range<1>(batches), - SequentialDotProduct( - lhs_tp, rhs_tp, res_tp, in_out_batch_indexer, - reduction_indexer, reduction_nelems)); - }); + sycl::event dot_ev = + sequential_dot_product( + exec_q, lhs_tp, rhs_tp, res_tp, batches, reduction_nelems, + inp_out_batch_indexer, reduction_indexer, depends); return dot_ev; } constexpr size_t preferred_reductions_per_wi = 8; // prevents running out of resources on CPU - size_t max_wg = - std::min(size_t(2048), - d.get_info() / 2); + size_t max_wg = reduction_detail::get_work_group_size(d); using ReductionOpT = typename std::conditional, sycl::logical_or, @@ -1153,46 +1163,37 @@ dot_product_contig_tree_impl(sycl::queue &exec_q, size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); if (reduction_nelems < wg) { - sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using InputBatchIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputBatchIndexerT = - dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< - InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; - using ReductionIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - NoOpIndexerT, NoOpIndexerT>; + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; - const InputBatchIndexerT inp_batch_indexer{ - 0, static_cast(reduction_nelems), - static_cast(batches)}; - const InputOutputBatchIndexerT inp_out_batch_indexer{ - inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; - constexpr ReductionIndexerT reduction_indexer{NoOpIndexerT{}, - NoOpIndexerT{}}; + const InputBatchIndexerT inp_batch_indexer{ + 0, static_cast(reduction_nelems), + static_cast(batches)}; + const InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + constexpr ReductionIndexerT reduction_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; - cgh.parallel_for>( - sycl::range<1>(batches), - SequentialDotProduct( - lhs_tp, rhs_tp, res_tp, inp_out_batch_indexer, - reduction_indexer, reduction_nelems)); - }); + sycl::event dot_ev = + sequential_dot_product( + exec_q, lhs_tp, rhs_tp, res_tp, batches, reduction_nelems, + inp_out_batch_indexer, reduction_indexer, depends); return dot_ev; } constexpr size_t preferred_reductions_per_wi = 8; // prevents running out of resources on CPU - size_t max_wg = - std::min(size_t(2048), - d.get_info() / 2); + size_t max_wg = reduction_detail::get_work_group_size(d); using ReductionOpT = typename std::conditional, sycl::logical_or, diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index bb730d398d..1616b41080 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -2584,11 +2584,7 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q, (reduction_nelems + preferred_reductions_per_wi * wg - 1) / (preferred_reductions_per_wi * wg); - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / 2); + size_t max_wg = reduction_detail::get_work_group_size(dev); if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { resTy *tmp = sycl::malloc_device( @@ -2666,17 +2662,20 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q, using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::StridedIndexer; using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + const OuterInnerDimsIndexerT lhs_indexer( inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); const OuterInnerDimsIndexerT rhs_indexer( inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); constexpr TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::Strided1DIndexer; using dpctl::tensor::offset_utils::StridedIndexer; using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; using dpctl::tensor::offset_utils::UnpackedStridedIndexer; using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; + const StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, batch_shape_strides); const UnpackedStridedIndexer rhs_batch_indexer( @@ -2969,11 +2968,7 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, (reduction_nelems + preferred_reductions_per_wi * wg - 1) / (preferred_reductions_per_wi * wg); - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / 2); + size_t max_wg = reduction_detail::get_work_group_size(dev); if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { resTy *tmp = sycl::malloc_device( @@ -3172,11 +3167,7 @@ gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q, (reduction_nelems + preferred_reductions_per_wi * wg - 1) / (preferred_reductions_per_wi * wg); - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / 2); + size_t max_wg = reduction_detail::get_work_group_size(dev); if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { resTy *tmp = sycl::malloc_device( @@ -3558,11 +3549,7 @@ sycl::event gemm_tree_k_impl(sycl::queue &exec_q, (reduction_nelems + preferred_reductions_per_wi * wg - 1) / (preferred_reductions_per_wi * wg); - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / 2); + size_t max_wg = reduction_detail::get_work_group_size(dev); if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { resTy *tmp = sycl::malloc_device( @@ -3728,11 +3715,7 @@ sycl::event gemm_tree_nm_impl(sycl::queue &exec_q, (reduction_nelems + preferred_reductions_per_wi * wg - 1) / (preferred_reductions_per_wi * wg); - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / 2); + size_t max_wg = reduction_detail::get_work_group_size(dev); if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { resTy *tmp = sycl::malloc_device( @@ -3982,11 +3965,7 @@ sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q, (reduction_nelems + preferred_reductions_per_wi * wg - 1) / (preferred_reductions_per_wi * wg); - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / 2); + size_t max_wg = reduction_detail::get_work_group_size(dev); if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { resTy *tmp = sycl::malloc_device( @@ -4139,11 +4118,7 @@ sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q, (reduction_nelems + preferred_reductions_per_wi * wg - 1) / (preferred_reductions_per_wi * wg); - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / 2); + size_t max_wg = reduction_detail::get_work_group_size(dev); if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { resTy *tmp = sycl::malloc_device( diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index 341e4739fb..f1e0dff15d 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -48,6 +48,18 @@ namespace tensor namespace kernels { +namespace reduction_detail +{ + +inline size_t get_work_group_size(const sycl::device &d) +{ + // prevents running out of resources on CPU + return std::min( + 2048, d.get_info() / 2); +} + +} // namespace reduction_detail + template struct needs_workaround { static constexpr bool value = @@ -111,7 +123,15 @@ struct SequentialReduction const ssize_t inp_offset = inp_iter_offset + inp_reduction_offset; using dpctl::tensor::type_utils::convert_impl; - outT val = convert_impl(inp_[inp_offset]); + outT val; + if constexpr (su_ns::IsLogicalAnd::value || + su_ns::IsLogicalOr::value) + { + val = convert_impl(inp_[inp_offset]); + } + else { + val = convert_impl(inp_[inp_offset]); + } red_val = reduction_op_(red_val, val); } @@ -194,15 +214,35 @@ struct ReductionOverGroupWithAtomicFunctor auto inp_offset = inp_iter_offset + inp_reduction_offset; using dpctl::tensor::type_utils::convert_impl; - outT val = convert_impl(inp_[inp_offset]); + outT val; + if constexpr (su_ns::IsLogicalAnd::value || + su_ns::IsLogicalOr::value) + { + // handle nans + val = convert_impl(inp_[inp_offset]); + } + else { + val = convert_impl(inp_[inp_offset]); + } local_red_val = reduction_op_(local_red_val, val); } auto work_group = it.get_group(); // This only works if reduction_op_ is from small set of operators - outT red_val_over_wg = sycl::reduce_over_group( - work_group, local_red_val, identity_, reduction_op_); + outT red_val_over_wg; + if constexpr (su_ns::IsLogicalAnd::value) { + red_val_over_wg = static_cast( + sycl::all_of_group(work_group, local_red_val)); + } + else if constexpr (su_ns::IsLogicalOr::value) { + red_val_over_wg = static_cast( + sycl::any_of_group(work_group, local_red_val)); + } + else { + red_val_over_wg = sycl::reduce_over_group(work_group, local_red_val, + identity_, reduction_op_); + } if (work_group.leader()) { sycl::atomic_ref::value) { res_ref += red_val_over_wg; } - else if constexpr (std::is_same_v>) - { + else if constexpr (su_ns::IsMaximum::value) { res_ref.fetch_max(red_val_over_wg); } - else if constexpr (std::is_same_v>) - { + else if constexpr (su_ns::IsMinimum::value) { res_ref.fetch_min(red_val_over_wg); } + else if constexpr (su_ns::IsLogicalAnd::value) { + res_ref.fetch_and(red_val_over_wg); + } + else if constexpr (su_ns::IsLogicalOr::value) { + res_ref.fetch_or(red_val_over_wg); + } else { outT read_val = res_ref.load(); outT new_val{}; @@ -304,7 +348,16 @@ struct CustomReductionOverGroupWithAtomicFunctor auto inp_offset = inp_iter_offset + inp_reduction_offset; using dpctl::tensor::type_utils::convert_impl; - outT val = convert_impl(inp_[inp_offset]); + outT val; + if constexpr (su_ns::IsLogicalAnd::value || + su_ns::IsLogicalOr::value) + { + // handle nans + val = convert_impl(inp_[inp_offset]); + } + else { + val = convert_impl(inp_[inp_offset]); + } local_red_val = reduction_op_(local_red_val, val); } @@ -318,28 +371,280 @@ struct CustomReductionOverGroupWithAtomicFunctor sycl::memory_scope::device, sycl::access::address_space::global_space> res_ref(out_[out_iter_offset]); - outT read_val = res_ref.load(); - outT new_val{}; - do { - new_val = reduction_op_(read_val, red_val_over_wg); - } while (!res_ref.compare_exchange_strong(read_val, new_val)); + // retain these checks in case a reduce_over_group work-around is + // needed + if constexpr (su_ns::IsSyclPlus::value) { + res_ref += red_val_over_wg; + } + else if constexpr (su_ns::IsSyclMaximum::value) { + res_ref.fetch_max(red_val_over_wg); + } + else if constexpr (su_ns::IsSyclMinimum::value) { + res_ref.fetch_min(red_val_over_wg); + } + else if constexpr (su_ns::IsSyclLogicalAnd::value) { + res_ref.fetch_and(red_val_over_wg); + } + else if constexpr (su_ns::IsSyclLogicalOr::value) + { + res_ref.fetch_or(red_val_over_wg); + } + else { + outT read_val = res_ref.load(); + outT new_val{}; + do { + new_val = reduction_op_(read_val, red_val_over_wg); + } while (!res_ref.compare_exchange_strong(read_val, new_val)); + } } } }; -template -class reduction_over_group_with_atomics_krn; +template +struct ReductionOverGroupNoAtomicFunctor +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + const ReductionOp reduction_op_; + const outT identity_; + const InputOutputIterIndexerT inp_out_iter_indexer_; + const InputRedIndexerT inp_reduced_dims_indexer_; + size_t reduction_max_gid_ = 0; + size_t iter_gws_ = 1; + size_t reductions_per_wi = 16; -template class custom_reduction_wrapper; +public: + ReductionOverGroupNoAtomicFunctor( + const argT *data, + outT *res, + const ReductionOp &reduction_op, + const outT &identity_val, + const InputOutputIterIndexerT &arg_res_iter_indexer, + const InputRedIndexerT &arg_reduced_dims_indexer, + size_t reduction_size, + size_t iteration_size, + size_t reduction_size_per_wi) + : inp_(data), out_(res), reduction_op_(reduction_op), + identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), iter_gws_(iteration_size), + reductions_per_wi(reduction_size_per_wi) + { + } -template -class reduction_over_group_with_atomics_init_krn; + void operator()(sycl::nd_item<1> it) const + { + const size_t reduction_lid = it.get_local_id(0); + const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg -template -class reduction_seq_strided_krn; + const size_t iter_gid = it.get_group(0) % iter_gws_; + const size_t reduction_batch_id = it.get_group(0) / iter_gws_; + const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_; -template -class reduction_seq_contig_krn; + // work-items operates over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + const auto &inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + outT local_red_val(identity_); + size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + for (size_t m = 0; m < reductions_per_wi; ++m) { + size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; + + if (arg_reduce_gid < reduction_max_gid_) { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + using dpctl::tensor::type_utils::convert_impl; + outT val; + if constexpr (su_ns::IsLogicalAnd::value || + su_ns::IsLogicalOr::value) + { + // handle nans + val = convert_impl(inp_[inp_offset]); + } + else { + val = convert_impl(inp_[inp_offset]); + } + + local_red_val = reduction_op_(local_red_val, val); + } + } + + auto work_group = it.get_group(); + // This only works if reduction_op_ is from small set of operators + outT red_val_over_wg; + if constexpr (su_ns::IsLogicalAnd::value) { + red_val_over_wg = sycl::all_of_group(work_group, local_red_val); + } + else if constexpr (su_ns::IsLogicalOr::value) { + red_val_over_wg = sycl::any_of_group(work_group, local_red_val); + } + else { + red_val_over_wg = sycl::reduce_over_group(work_group, local_red_val, + identity_, reduction_op_); + } + + if (work_group.leader()) { + // each group writes to a different memory location + out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = + red_val_over_wg; + } + } +}; + +/* = Reduction, using custom_reduce_over_group and not using atomic_ref*/ + +template +struct CustomReductionOverGroupNoAtomicFunctor +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + const ReductionOp reduction_op_; + outT identity_; + const InputOutputIterIndexerT inp_out_iter_indexer_; + const InputRedIndexerT inp_reduced_dims_indexer_; + SlmT local_mem_; + size_t reduction_max_gid_ = 0; + size_t iter_gws_ = 1; + size_t reductions_per_wi = 16; + +public: + CustomReductionOverGroupNoAtomicFunctor( + const argT *data, + outT *res, + const ReductionOp &reduction_op, + const outT &identity_val, + const InputOutputIterIndexerT &arg_res_iter_indexer, + const InputRedIndexerT &arg_reduced_dims_indexer, + SlmT local_mem, + size_t reduction_size, + size_t iteration_size, + size_t reduction_size_per_wi) + : inp_(data), out_(res), reduction_op_(reduction_op), + identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + local_mem_(local_mem), reduction_max_gid_(reduction_size), + iter_gws_(iteration_size), reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t reduction_lid = it.get_local_id(0); + const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg + + const size_t iter_gid = it.get_group(0) % iter_gws_; + const size_t reduction_batch_id = it.get_group(0) / iter_gws_; + const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_; + + // work-items operates over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + auto inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + outT local_red_val(identity_); + size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + for (size_t m = 0; m < reductions_per_wi; ++m) { + size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; + + if (arg_reduce_gid < reduction_max_gid_) { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + using dpctl::tensor::type_utils::convert_impl; + outT val; + if constexpr (std::is_same_v> || + std::is_same_v>) + { + // handle nans + val = convert_impl(inp_[inp_offset]); + } + else { + val = convert_impl(inp_[inp_offset]); + } + + local_red_val = reduction_op_(local_red_val, val); + } + } + + auto work_group = it.get_group(); + // This only works if reduction_op_ is from small set of operators + outT red_val_over_wg = su_ns::custom_reduce_over_group( + work_group, local_mem_, local_red_val, reduction_op_); + + if (work_group.leader()) { + // each group writes to a different memory location + out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = + red_val_over_wg; + } + } +}; + +template < + typename argTy, + typename resTy, + typename ReductionOpT, + typename InputOutputIterIndexerT, + typename ReductionIndexerT, + template + class kernel_name_token> +sycl::event +sequential_reduction(sycl::queue &exec_q, + const argTy *arg, + resTy *res, + resTy identity_val, + size_t iter_nelems, + size_t reduction_nelems, + const InputOutputIterIndexerT &in_out_iter_indexer, + const ReductionIndexerT &reduction_indexer, + const std::vector &depends) +{ + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using KernelName = + class kernel_name_token; + + cgh.parallel_for( + sycl::range<1>(iter_nelems), + SequentialReduction( + arg, res, ReductionOpT(), identity_val, in_out_iter_indexer, + reduction_indexer, reduction_nelems)); + }); + + return red_ev; +} + +template class custom_reduction_wrapper; template < typename argTy, @@ -406,9 +711,18 @@ submit_atomic_reduction(sycl::queue &exec_q, return red_ev; } -typedef sycl::event (*reduction_strided_impl_fn_ptr)( - sycl::queue &, - size_t, +template +class reduction_over_group_with_atomics_init_krn; + +template +class reduction_seq_krn; + +template +class reduction_over_group_with_atomics_krn; + +typedef sycl::event (*reduction_strided_impl_fn_ptr)( + sycl::queue &, + size_t, size_t, const char *, char *, @@ -460,18 +774,13 @@ sycl::event reduction_over_group_with_atomics_strided_impl( const ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, reduction_shape_stride}; - sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - cgh.parallel_for>( - sycl::range<1>(iter_nelems), - SequentialReduction( - arg_tp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, reduction_nelems)); - }); + sycl::event comp_ev = + sequential_reduction( + exec_q, arg_tp, res_tp, identity_val, iter_nelems, + reduction_nelems, in_out_iter_indexer, reduction_indexer, + depends); return comp_ev; } @@ -580,18 +889,13 @@ sycl::event reduction_axis1_over_group_with_atomics_contig_impl( NoOpIndexerT{}}; constexpr ReductionIndexerT reduction_indexer{}; - sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - cgh.parallel_for>( - sycl::range<1>(iter_nelems), - SequentialReduction( - arg_tp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, reduction_nelems)); - }); + sycl::event comp_ev = + sequential_reduction( + exec_q, arg_tp, res_tp, identity_val, iter_nelems, + reduction_nelems, in_out_iter_indexer, reduction_indexer, + depends); return comp_ev; } @@ -673,23 +977,13 @@ sycl::event reduction_axis0_over_group_with_atomics_contig_impl( 0, static_cast(reduction_nelems), static_cast(iter_nelems)}; - sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using KernelName = - class reduction_seq_contig_krn; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for( - iter_range, - SequentialReduction( - arg_tp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, reduction_nelems)); - }); + sycl::event comp_ev = + sequential_reduction( + exec_q, arg_tp, res_tp, identity_val, iter_nelems, + reduction_nelems, in_out_iter_indexer, reduction_indexer, + depends); return comp_ev; } @@ -736,204 +1030,6 @@ sycl::event reduction_axis0_over_group_with_atomics_contig_impl( /* = Reduction, using sycl::reduce_over_group, but not using atomic_ref = */ -template -struct ReductionOverGroupNoAtomicFunctor -{ -private: - const argT *inp_ = nullptr; - outT *out_ = nullptr; - const ReductionOp reduction_op_; - const outT identity_; - const InputOutputIterIndexerT inp_out_iter_indexer_; - const InputRedIndexerT inp_reduced_dims_indexer_; - size_t reduction_max_gid_ = 0; - size_t iter_gws_ = 1; - size_t reductions_per_wi = 16; - -public: - ReductionOverGroupNoAtomicFunctor( - const argT *data, - outT *res, - const ReductionOp &reduction_op, - const outT &identity_val, - const InputOutputIterIndexerT &arg_res_iter_indexer, - const InputRedIndexerT &arg_reduced_dims_indexer, - size_t reduction_size, - size_t iteration_size, - size_t reduction_size_per_wi) - : inp_(data), out_(res), reduction_op_(reduction_op), - identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), - inp_reduced_dims_indexer_(arg_reduced_dims_indexer), - reduction_max_gid_(reduction_size), iter_gws_(iteration_size), - reductions_per_wi(reduction_size_per_wi) - { - } - - void operator()(sycl::nd_item<1> it) const - { - const size_t reduction_lid = it.get_local_id(0); - const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg - - const size_t iter_gid = it.get_group(0) % iter_gws_; - const size_t reduction_batch_id = it.get_group(0) / iter_gws_; - const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_; - - // work-items operates over input with indices - // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg - // + reduction_lid - // for 0 <= m < reductions_per_wi - - const auto &inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); - const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); - const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); - - outT local_red_val(identity_); - size_t arg_reduce_gid0 = - reduction_lid + reduction_batch_id * wg * reductions_per_wi; - for (size_t m = 0; m < reductions_per_wi; ++m) { - size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; - - if (arg_reduce_gid < reduction_max_gid_) { - auto inp_reduction_offset = - inp_reduced_dims_indexer_(arg_reduce_gid); - auto inp_offset = inp_iter_offset + inp_reduction_offset; - - using dpctl::tensor::type_utils::convert_impl; - outT val = convert_impl(inp_[inp_offset]); - - local_red_val = reduction_op_(local_red_val, val); - } - } - - auto work_group = it.get_group(); - // This only works if reduction_op_ is from small set of operators - outT red_val_over_wg = sycl::reduce_over_group( - work_group, local_red_val, identity_, reduction_op_); - - if (work_group.leader()) { - // each group writes to a different memory location - out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = - red_val_over_wg; - } - } -}; - -/* = Reduction, using custom_reduce_over_group and not using atomic_ref*/ - -template -struct CustomReductionOverGroupNoAtomicFunctor -{ -private: - const argT *inp_ = nullptr; - outT *out_ = nullptr; - const ReductionOp reduction_op_; - outT identity_; - const InputOutputIterIndexerT inp_out_iter_indexer_; - const InputRedIndexerT inp_reduced_dims_indexer_; - SlmT local_mem_; - size_t reduction_max_gid_ = 0; - size_t iter_gws_ = 1; - size_t reductions_per_wi = 16; - -public: - CustomReductionOverGroupNoAtomicFunctor( - const argT *data, - outT *res, - const ReductionOp &reduction_op, - const outT &identity_val, - const InputOutputIterIndexerT &arg_res_iter_indexer, - const InputRedIndexerT &arg_reduced_dims_indexer, - SlmT local_mem, - size_t reduction_size, - size_t iteration_size, - size_t reduction_size_per_wi) - : inp_(data), out_(res), reduction_op_(reduction_op), - identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), - inp_reduced_dims_indexer_(arg_reduced_dims_indexer), - local_mem_(local_mem), reduction_max_gid_(reduction_size), - iter_gws_(iteration_size), reductions_per_wi(reduction_size_per_wi) - { - } - - void operator()(sycl::nd_item<1> it) const - { - const size_t reduction_lid = it.get_local_id(0); - const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg - - const size_t iter_gid = it.get_group(0) % iter_gws_; - const size_t reduction_batch_id = it.get_group(0) / iter_gws_; - const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_; - - // work-items operates over input with indices - // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg - // + reduction_lid - // for 0 <= m < reductions_per_wi - - auto inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); - const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); - const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); - - outT local_red_val(identity_); - size_t arg_reduce_gid0 = - reduction_lid + reduction_batch_id * wg * reductions_per_wi; - for (size_t m = 0; m < reductions_per_wi; ++m) { - size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; - - if (arg_reduce_gid < reduction_max_gid_) { - auto inp_reduction_offset = - inp_reduced_dims_indexer_(arg_reduce_gid); - auto inp_offset = inp_iter_offset + inp_reduction_offset; - - using dpctl::tensor::type_utils::convert_impl; - outT val = convert_impl(inp_[inp_offset]); - - local_red_val = reduction_op_(local_red_val, val); - } - } - - auto work_group = it.get_group(); - // This only works if reduction_op_ is from small set of operators - outT red_val_over_wg = su_ns::custom_reduce_over_group( - work_group, local_mem_, local_red_val, reduction_op_); - - if (work_group.leader()) { - // each group writes to a different memory location - out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = - red_val_over_wg; - } - } -}; - -typedef sycl::event (*reduction_strided_impl_fn_ptr)( - sycl::queue &, - size_t, - size_t, - const char *, - char *, - int, - const ssize_t *, - ssize_t, - ssize_t, - int, - const ssize_t *, - ssize_t, - const std::vector &); - -template -class reduction_over_group_temps_krn; - -template -class reduction_over_group_temps_empty_krn; - template < typename argTy, typename resTy, @@ -998,27 +1094,38 @@ submit_no_atomic_reduction(sycl::queue &exec_q, return red_ev; } -namespace detail -{ -inline size_t get_work_group_size(const sycl::device &d) -{ - // prevents running out of resources on CPU - return std::min( - 2048, d.get_info() / 2); -} -} // namespace detail +template +class reduction_over_group_temps_krn; -template -sycl::event reduction_over_group_temps_strided_impl( - sycl::queue &exec_q, - size_t iter_nelems, // number of reductions (num. of rows in a matrix - // when reducing over rows) - size_t reduction_nelems, // size of each reduction (length of rows, i.e. - // number of columns) - const char *arg_cp, - char *res_cp, - int iter_nd, - const ssize_t *iter_shape_and_strides, +typedef sycl::event (*reduction_strided_impl_fn_ptr)( + sycl::queue &, + size_t, + size_t, + const char *, + char *, + int, + const ssize_t *, + ssize_t, + ssize_t, + int, + const ssize_t *, + ssize_t, + const std::vector &); + +template +class reduction_over_group_temps_empty_krn; + +template +sycl::event reduction_over_group_temps_strided_impl( + sycl::queue &exec_q, + size_t iter_nelems, // number of reductions (num. of rows in a matrix + // when reducing over rows) + size_t reduction_nelems, // size of each reduction (length of rows, i.e. + // number of columns) + const char *arg_cp, + char *res_cp, + int iter_nd, + const ssize_t *iter_shape_and_strides, ssize_t iter_arg_offset, ssize_t iter_res_offset, int red_nd, @@ -1061,36 +1168,29 @@ sycl::event reduction_over_group_temps_strided_impl( size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); if (reduction_nelems < wg) { - sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = dpctl::tensor::offset_utils::StridedIndexer; - const InputOutputIterIndexerT in_out_iter_indexer{ - iter_nd, iter_arg_offset, iter_res_offset, - iter_shape_and_strides}; - const ReductionIndexerT reduction_indexer{ - red_nd, reduction_arg_offset, reduction_shape_stride}; + const InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, iter_shape_and_strides}; + const ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; - cgh.parallel_for>( - sycl::range<1>(iter_nelems), - SequentialReduction( - arg_tp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, reduction_nelems)); - }); + sycl::event comp_ev = + sequential_reduction( + exec_q, arg_tp, res_tp, identity_val, iter_nelems, + reduction_nelems, in_out_iter_indexer, reduction_indexer, + depends); return comp_ev; } constexpr size_t preferred_reductions_per_wi = 8; // prevents running out of resources on CPU - size_t max_wg = detail::get_work_group_size(d); + size_t max_wg = reduction_detail::get_work_group_size(d); size_t reductions_per_wi(preferred_reductions_per_wi); if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { @@ -1316,39 +1416,33 @@ sycl::event reduction_axis1_over_group_temps_contig_impl( size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); if (reduction_nelems < wg) { - sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using InputIterIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - InputIterIndexerT, NoOpIndexerT>; - using ReductionIndexerT = NoOpIndexerT; + using InputIterIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; - const InputOutputIterIndexerT in_out_iter_indexer{ - InputIterIndexerT{0, static_cast(iter_nelems), - static_cast(reduction_nelems)}, - NoOpIndexerT{}}; - constexpr ReductionIndexerT reduction_indexer{}; + const InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{0, static_cast(iter_nelems), + static_cast(reduction_nelems)}, + NoOpIndexerT{}}; + constexpr ReductionIndexerT reduction_indexer{}; - cgh.parallel_for>( - sycl::range<1>(iter_nelems), - SequentialReduction( - arg_tp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, reduction_nelems)); - }); + sycl::event comp_ev = + sequential_reduction( + exec_q, arg_tp, res_tp, identity_val, iter_nelems, + reduction_nelems, in_out_iter_indexer, reduction_indexer, + depends); return comp_ev; } constexpr size_t preferred_reductions_per_wi = 8; // prevents running out of resources on CPU - size_t max_wg = detail::get_work_group_size(d); + size_t max_wg = reduction_detail::get_work_group_size(d); size_t reductions_per_wi(preferred_reductions_per_wi); if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { @@ -1564,43 +1658,32 @@ sycl::event reduction_axis0_over_group_temps_contig_impl( size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); if (reduction_nelems < wg) { - sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - NoOpIndexerT, NoOpIndexerT>; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - const InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, - NoOpIndexerT{}}; - const ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - using KernelName = - class reduction_seq_contig_krn; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; - sycl::range<1> iter_range{iter_nelems}; + const InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + const ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; - cgh.parallel_for( - iter_range, - SequentialReduction( - arg_tp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, reduction_nelems)); - }); + sycl::event comp_ev = + sequential_reduction( + exec_q, arg_tp, res_tp, identity_val, iter_nelems, + reduction_nelems, in_out_iter_indexer, reduction_indexer, + depends); return comp_ev; } constexpr size_t preferred_reductions_per_wi = 8; // prevents running out of resources on CPU - size_t max_wg = detail::get_work_group_size(d); + size_t max_wg = reduction_detail::get_work_group_size(d); size_t reductions_per_wi(preferred_reductions_per_wi); if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { @@ -1657,1327 +1740,134 @@ sycl::event reduction_axis0_over_group_temps_contig_impl( resTy *partially_reduced_tmp = sycl::malloc_device( iter_nelems * (reduction_groups + second_iter_reduction_groups_), exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error("Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_groups * iter_nelems; - } - - sycl::event first_reduction_ev; - { - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - NoOpIndexerT, NoOpIndexerT>; - using ReductionIndexerT = ColsIndexerT; - - constexpr NoOpIndexerT columns_indexer{}; - constexpr NoOpIndexerT noop_tmp_indexer{}; - const InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, - noop_tmp_indexer}; - const ReductionIndexerT reduction_indexer{ - 0, /* size */ static_cast(reduction_nelems), - /* step */ static_cast(iter_nelems)}; - - first_reduction_ev = submit_no_atomic_reduction< - argTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT, reduction_over_group_temps_krn>( - exec_q, arg_tp, partially_reduced_tmp, identity_val, wg, - iter_nelems, reduction_nelems, preferred_reductions_per_wi, - reduction_groups, in_out_iter_indexer, reduction_indexer, - depends); - } - - size_t remaining_reduction_nelems = reduction_groups; - - resTy *temp_arg = partially_reduced_tmp; - resTy *temp2_arg = partially_reduced_tmp2; - sycl::event dependent_ev = first_reduction_ev; - - while (remaining_reduction_nelems > - preferred_reductions_per_wi * max_wg) { - size_t reduction_groups_ = (remaining_reduction_nelems + - preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups_ > 1); - - // keep reducing - using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; - using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - InputIndexerT, ResIndexerT>; - using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - - const InputIndexerT inp_indexer{ - 0, static_cast(iter_nelems), - static_cast(reduction_groups_)}; - constexpr ResIndexerT res_iter_indexer{}; - - const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, - res_iter_indexer}; - constexpr ReductionIndexerT reduction_indexer{}; - - sycl::event partial_reduction_ev = submit_no_atomic_reduction< - resTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT, reduction_over_group_temps_krn>( - exec_q, temp_arg, temp2_arg, identity_val, wg, iter_nelems, - remaining_reduction_nelems, preferred_reductions_per_wi, - reduction_groups_, in_out_iter_indexer, reduction_indexer, - {dependent_ev}); - - remaining_reduction_nelems = reduction_groups_; - std::swap(temp_arg, temp2_arg); - dependent_ev = std::move(partial_reduction_ev); - } - - // final reduction to res - using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; - using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - InputIndexerT, ResIndexerT>; - using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - - const InputIndexerT inp_indexer{ - 0, static_cast(iter_nelems), - static_cast(remaining_reduction_nelems)}; - constexpr ResIndexerT res_iter_indexer{}; - - const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, - res_iter_indexer}; - constexpr ReductionIndexerT reduction_indexer{}; - - wg = max_wg; - reductions_per_wi = - std::max(1, (remaining_reduction_nelems + wg - 1) / wg); - - reduction_groups = - (remaining_reduction_nelems + reductions_per_wi * wg - 1) / - (reductions_per_wi * wg); - assert(reduction_groups == 1); - - sycl::event final_reduction_ev = submit_no_atomic_reduction< - resTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT, reduction_over_group_temps_krn>( - exec_q, temp_arg, res_tp, identity_val, wg, iter_nelems, - remaining_reduction_nelems, reductions_per_wi, reduction_groups, - in_out_iter_indexer, reduction_indexer, {dependent_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(final_reduction_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - // FIXME: do not return host-task event - // Instead collect all host-tasks to a list - - return cleanup_host_task_event; - } -} - -/* @brief Types supported by comparison-reduction code based on atomic_ref */ -template -struct TypePairSupportDataForCompReductionAtomic -{ - - /* value is true if a kernel for must be instantiated, false - * otherwise */ - // disjunction is C++17 feature, supported by DPC++ - static constexpr bool is_defined = std::disjunction< - // input int32 - td_ns::TypePairDefinedEntry, - // input uint32 - td_ns::TypePairDefinedEntry, - // input int64 - td_ns::TypePairDefinedEntry, - // input uint64 - td_ns::TypePairDefinedEntry, - // input float - td_ns::TypePairDefinedEntry, - // input double - td_ns::TypePairDefinedEntry, - // fall-through - td_ns::NotDefinedEntry>::is_defined; -}; - -template -struct TypePairSupportDataForCompReductionTemps -{ - - // disjunction is C++17 feature, supported by DPC++ - static constexpr bool is_defined = std::disjunction< - // input bool - td_ns::TypePairDefinedEntry, - // input int8_t - td_ns::TypePairDefinedEntry, - // input uint8_t - td_ns::TypePairDefinedEntry, - - // input int16_t - td_ns::TypePairDefinedEntry, - // input uint16_t - td_ns::TypePairDefinedEntry, - - // input int32_t - td_ns::TypePairDefinedEntry, - // input uint32_t - td_ns::TypePairDefinedEntry, - - // input int64_t - td_ns::TypePairDefinedEntry, - - // input uint32_t - td_ns::TypePairDefinedEntry, - - // input half - td_ns::TypePairDefinedEntry, - - // input float - td_ns::TypePairDefinedEntry, - - // input double - td_ns::TypePairDefinedEntry, - - // input std::complex - td_ns::TypePairDefinedEntry, - outTy, - std::complex>, - - td_ns::TypePairDefinedEntry, - outTy, - std::complex>, - - // fall-through - td_ns::NotDefinedEntry>::is_defined; -}; - -template -struct MaxOverAxisAtomicStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionAtomic< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_floating_point::value) { - using ReductionOpT = su_ns::Maximum; - return dpctl::tensor::kernels:: - reduction_over_group_with_atomics_strided_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - using ReductionOpT = sycl::maximum; - return dpctl::tensor::kernels:: - reduction_over_group_with_atomics_strided_impl< - srcTy, dstTy, ReductionOpT>; - } - } - else { - return nullptr; - } - } -}; - -template -struct MaxOverAxisTempsStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionTemps< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_integral_v && - !std::is_same_v) { - using ReductionOpT = sycl::maximum; - return dpctl::tensor::kernels:: - reduction_over_group_temps_strided_impl; - } - else { - using ReductionOpT = su_ns::Maximum; - return dpctl::tensor::kernels:: - reduction_over_group_temps_strided_impl; - } - } - else { - return nullptr; - } - } -}; - -template -struct MaxOverAxis1AtomicContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionAtomic< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_floating_point::value) { - using ReductionOpT = su_ns::Maximum; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - using ReductionOpT = sycl::maximum; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - } - else { - return nullptr; - } - } -}; - -template -struct MaxOverAxis0AtomicContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionAtomic< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_floating_point::value) { - using ReductionOpT = su_ns::Maximum; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - using ReductionOpT = sycl::maximum; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - } - else { - return nullptr; - } - } -}; - -template -struct MaxOverAxis1TempsContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionTemps< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_integral_v && - !std::is_same_v) { - using ReductionOpT = sycl::maximum; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_temps_contig_impl; - } - else { - using ReductionOpT = su_ns::Maximum; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_temps_contig_impl; - } - } - else { - return nullptr; - } - } -}; - -template -struct MaxOverAxis0TempsContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionTemps< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_integral_v && - !std::is_same_v) { - using ReductionOpT = sycl::maximum; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_temps_contig_impl; - } - else { - using ReductionOpT = su_ns::Maximum; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_temps_contig_impl; - } - } - else { - return nullptr; - } - } -}; - -template -struct MinOverAxisAtomicStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionAtomic< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_floating_point::value) { - using ReductionOpT = su_ns::Minimum; - return dpctl::tensor::kernels:: - reduction_over_group_with_atomics_strided_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - using ReductionOpT = sycl::minimum; - return dpctl::tensor::kernels:: - reduction_over_group_with_atomics_strided_impl< - srcTy, dstTy, ReductionOpT>; - } - } - else { - return nullptr; - } - } -}; - -template -struct MinOverAxisTempsStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionTemps< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_integral_v && - !std::is_same_v) { - using ReductionOpT = sycl::minimum; - return dpctl::tensor::kernels:: - reduction_over_group_temps_strided_impl; - } - else { - using ReductionOpT = su_ns::Minimum; - return dpctl::tensor::kernels:: - reduction_over_group_temps_strided_impl; - } - } - else { - return nullptr; - } - } -}; - -template -struct MinOverAxis1AtomicContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionAtomic< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_floating_point::value) { - using ReductionOpT = su_ns::Minimum; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - using ReductionOpT = sycl::minimum; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - } - else { - return nullptr; - } - } -}; - -template -struct MinOverAxis0AtomicContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionAtomic< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_floating_point::value) { - using ReductionOpT = su_ns::Minimum; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - using ReductionOpT = sycl::minimum; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - } - else { - return nullptr; - } - } -}; - -template -struct MinOverAxis1TempsContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionTemps< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_integral_v && - !std::is_same_v) { - using ReductionOpT = sycl::minimum; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_temps_contig_impl; - } - else { - using ReductionOpT = su_ns::Minimum; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_temps_contig_impl; - } - } - else { - return nullptr; - } - } -}; - -template -struct MinOverAxis0TempsContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionTemps< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_integral_v && - !std::is_same_v) { - using ReductionOpT = sycl::minimum; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_temps_contig_impl; - } - else { - using ReductionOpT = su_ns::Minimum; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_temps_contig_impl; - } - } - else { - return nullptr; - } - } -}; - -// Sum - -/* @brief Types supported by plus-reduction code based on atomic_ref */ -template -struct TypePairSupportDataForSumReductionAtomic -{ - - /* value if true a kernel for must be instantiated, false - * otherwise */ - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input int8 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input uint8 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input int16 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input uint16 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input int32 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input uint32 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input int64 - td_ns::TypePairDefinedEntry, - // input uint64 - td_ns::TypePairDefinedEntry, - // fall-through - td_ns::NotDefinedEntry>::is_defined; -}; - -template -struct TypePairSupportDataForSumReductionTemps -{ - - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input int8_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input uint8_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input int16_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input uint16_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input int32_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input uint32_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input int64_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input uint64_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input half - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns:: - TypePairDefinedEntry>, - td_ns::TypePairDefinedEntry>, - - // input float - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry>, - td_ns::TypePairDefinedEntry>, - - // input double - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry>, - - // input std::complex - td_ns::TypePairDefinedEntry, - outTy, - std::complex>, - td_ns::TypePairDefinedEntry, - outTy, - std::complex>, - - td_ns::TypePairDefinedEntry, - outTy, - std::complex>, - - // fall-throug - td_ns::NotDefinedEntry>::is_defined; -}; - -template -struct SumOverAxisAtomicStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForSumReductionAtomic< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = sycl::plus; - return dpctl::tensor::kernels:: - reduction_over_group_with_atomics_strided_impl; - } - else { - return nullptr; - } - } -}; - -template -struct SumOverAxisTempsStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForSumReductionTemps< - srcTy, dstTy>::is_defined) { - using ReductionOpT = sycl::plus; - return dpctl::tensor::kernels:: - reduction_over_group_temps_strided_impl; - } - else { - return nullptr; - } - } -}; - -template -struct SumOverAxis1AtomicContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForSumReductionAtomic< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = sycl::plus; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - return nullptr; - } - } -}; - -template -struct SumOverAxis0AtomicContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForSumReductionAtomic< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = sycl::plus; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - return nullptr; - } - } -}; - -template -struct SumOverAxis1TempsContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForSumReductionTemps< - srcTy, dstTy>::is_defined) { - using ReductionOpT = sycl::plus; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_temps_contig_impl; - } - else { - return nullptr; - } - } -}; - -template -struct SumOverAxis0TempsContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForSumReductionTemps< - srcTy, dstTy>::is_defined) { - using ReductionOpT = sycl::plus; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_temps_contig_impl; - } - else { - return nullptr; - } - } -}; - -// Product - -/* @brief Types supported by plus-reduction code based on atomic_ref */ -template -struct TypePairSupportDataForProductReductionAtomic -{ - - /* value if true a kernel for must be instantiated, false - * otherwise */ - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input int8 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input uint8 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input int16 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input uint16 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input int32 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input uint32 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input int64 - td_ns::TypePairDefinedEntry, - // input uint64 - td_ns::TypePairDefinedEntry, - // fall-through - td_ns::NotDefinedEntry>::is_defined; -}; - -template -struct TypePairSupportDataForProductReductionTemps -{ - - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input int8_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input uint8_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input int16_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input uint16_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input int32_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input uint32_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input int64_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input uint32_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input half - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns:: - TypePairDefinedEntry>, - td_ns::TypePairDefinedEntry>, - - // input float - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry>, - td_ns::TypePairDefinedEntry>, - - // input double - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry>, - - // input std::complex - td_ns::TypePairDefinedEntry, - outTy, - std::complex>, - td_ns::TypePairDefinedEntry, - outTy, - std::complex>, - - td_ns::TypePairDefinedEntry, - outTy, - std::complex>, - - // fall-throug - td_ns::NotDefinedEntry>::is_defined; -}; - -template -struct ProductOverAxisAtomicStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForProductReductionAtomic< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = sycl::multiplies; - return dpctl::tensor::kernels:: - reduction_over_group_with_atomics_strided_impl; - } - else { - return nullptr; - } - } -}; - -template -struct ProductOverAxisTempsStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForProductReductionTemps< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = sycl::multiplies; - return dpctl::tensor::kernels:: - reduction_over_group_temps_strided_impl; - } - else { - return nullptr; - } - } -}; - -template -struct ProductOverAxis1AtomicContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForProductReductionAtomic< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = sycl::multiplies; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - return nullptr; - } - } -}; - -template -struct ProductOverAxis0AtomicContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForProductReductionAtomic< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = sycl::multiplies; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - return nullptr; - } - } -}; - -template -struct ProductOverAxis1TempsContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForProductReductionTemps< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = sycl::multiplies; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_temps_contig_impl; + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); } else { - return nullptr; + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * iter_nelems; } - } -}; -template -struct ProductOverAxis0TempsContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForProductReductionTemps< - srcTy, dstTy>::is_defined) + sycl::event first_reduction_ev; { - using ReductionOpT = sycl::multiplies; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_temps_contig_impl; - } - else { - return nullptr; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = ColsIndexerT; + + constexpr NoOpIndexerT columns_indexer{}; + constexpr NoOpIndexerT noop_tmp_indexer{}; + const InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, + noop_tmp_indexer}; + const ReductionIndexerT reduction_indexer{ + 0, /* size */ static_cast(reduction_nelems), + /* step */ static_cast(iter_nelems)}; + + first_reduction_ev = submit_no_atomic_reduction< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, arg_tp, partially_reduced_tmp, identity_val, wg, + iter_nelems, reduction_nelems, preferred_reductions_per_wi, + reduction_groups, in_out_iter_indexer, reduction_indexer, + depends); } - } -}; -template -struct TypePairSupportDataForHypotReductionTemps -{ + size_t remaining_reduction_nelems = reduction_groups; - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input int8_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input uint8_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input int16_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input uint16_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input int32_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input uint32_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input int64_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input uint64_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input half - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input float - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input double - td_ns::TypePairDefinedEntry, - - // fall-through - td_ns::NotDefinedEntry>::is_defined; -}; + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + sycl::event dependent_ev = first_reduction_ev; -template -struct HypotOverAxisTempsStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForHypotReductionTemps< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = su_ns::Hypot; - return dpctl::tensor::kernels:: - reduction_over_group_temps_strided_impl; - } - else { - return nullptr; - } - } -}; + while (remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + size_t reduction_groups_ = (remaining_reduction_nelems + + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); -template -struct HypotOverAxis1TempsContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForHypotReductionTemps< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = su_ns::Hypot; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_temps_contig_impl; - } - else { - return nullptr; - } - } -}; + // keep reducing + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; -template -struct HypotOverAxis0TempsContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForHypotReductionTemps< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = su_ns::Hypot; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_temps_contig_impl; - } - else { - return nullptr; - } - } -}; + const InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; + constexpr ResIndexerT res_iter_indexer{}; -template -struct TypePairSupportDataForLogSumExpReductionTemps -{ + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + constexpr ReductionIndexerT reduction_indexer{}; - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input int8_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input uint8_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input int16_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input uint16_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input int32_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input uint32_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input int64_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input uint64_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input half - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input float - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - - // input double - td_ns::TypePairDefinedEntry, - - // fall-through - td_ns::NotDefinedEntry>::is_defined; -}; + sycl::event partial_reduction_ev = submit_no_atomic_reduction< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, temp_arg, temp2_arg, identity_val, wg, iter_nelems, + remaining_reduction_nelems, preferred_reductions_per_wi, + reduction_groups_, in_out_iter_indexer, reduction_indexer, + {dependent_ev}); -template -struct LogSumExpOverAxisTempsStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForLogSumExpReductionTemps< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = su_ns::LogSumExp; - return dpctl::tensor::kernels:: - reduction_over_group_temps_strided_impl; - } - else { - return nullptr; + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); } - } -}; -template -struct LogSumExpOverAxis1TempsContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForLogSumExpReductionTemps< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = su_ns::LogSumExp; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_temps_contig_impl; - } - else { - return nullptr; - } - } -}; + // final reduction to res + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; -template -struct LogSumExpOverAxis0TempsContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForLogSumExpReductionTemps< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = su_ns::LogSumExp; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_temps_contig_impl; - } - else { - return nullptr; - } + const InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(remaining_reduction_nelems)}; + constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + constexpr ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); + + reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event final_reduction_ev = submit_no_atomic_reduction< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, temp_arg, res_tp, identity_val, wg, iter_nelems, + remaining_reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, {dependent_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(final_reduction_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + // FIXME: do not return host-task event + // Instead collect all host-tasks to a list + + return cleanup_host_task_event; } -}; +} // Argmax and Argmin @@ -3697,7 +2587,7 @@ sycl::event search_over_group_temps_strided_impl( constexpr size_t preferred_reductions_per_wi = 4; // prevents running out of resources on CPU - size_t max_wg = detail::get_work_group_size(d); + size_t max_wg = reduction_detail::get_work_group_size(d); size_t reductions_per_wi(preferred_reductions_per_wi); if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { @@ -3993,7 +2883,7 @@ sycl::event search_axis1_over_group_temps_contig_impl( constexpr size_t preferred_reductions_per_wi = 8; // prevents running out of resources on CPU - size_t max_wg = detail::get_work_group_size(d); + size_t max_wg = reduction_detail::get_work_group_size(d); size_t reductions_per_wi(preferred_reductions_per_wi); if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { @@ -4276,7 +3166,7 @@ sycl::event search_axis0_over_group_temps_contig_impl( constexpr size_t preferred_reductions_per_wi = 8; // prevents running out of resources on CPU - size_t max_wg = detail::get_work_group_size(d); + size_t max_wg = reduction_detail::get_work_group_size(d); size_t reductions_per_wi(preferred_reductions_per_wi); if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { @@ -4489,265 +3379,6 @@ sycl::event search_axis0_over_group_temps_contig_impl( } } -template -struct TypePairSupportDataForSearchReductionTemps -{ - - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool - td_ns::TypePairDefinedEntry, - // input int8_t - td_ns::TypePairDefinedEntry, - - // input uint8_t - td_ns::TypePairDefinedEntry, - - // input int16_t - td_ns::TypePairDefinedEntry, - - // input uint16_t - td_ns::TypePairDefinedEntry, - - // input int32_t - td_ns::TypePairDefinedEntry, - // input uint32_t - td_ns::TypePairDefinedEntry, - - // input int64_t - td_ns::TypePairDefinedEntry, - - // input uint32_t - td_ns::TypePairDefinedEntry, - - // input half - td_ns::TypePairDefinedEntry, - - // input float - td_ns::TypePairDefinedEntry, - - // input double - td_ns::TypePairDefinedEntry, - - // input std::complex - td_ns::TypePairDefinedEntry, - outTy, - std::int64_t>, - - td_ns::TypePairDefinedEntry, - outTy, - std::int64_t>, - - // fall-through - td_ns::NotDefinedEntry>::is_defined; -}; - -template -struct ArgmaxOverAxisTempsStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForSearchReductionTemps< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_integral_v && - !std::is_same_v) { - // op for values - using ReductionOpT = sycl::maximum; - // op for indices - using IndexOpT = sycl::minimum; - return dpctl::tensor::kernels:: - search_over_group_temps_strided_impl< - srcTy, dstTy, ReductionOpT, IndexOpT>; - } - else { - // op for values - using ReductionOpT = su_ns::Maximum; - // op for indices - using IndexOpT = sycl::minimum; - return dpctl::tensor::kernels:: - search_over_group_temps_strided_impl< - srcTy, dstTy, ReductionOpT, IndexOpT>; - } - } - else { - return nullptr; - } - } -}; - -template -struct ArgmaxOverAxis1TempsContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForSearchReductionTemps< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_integral_v && - !std::is_same_v) { - // op for values - using ReductionOpT = sycl::maximum; - // op for indices - using IndexOpT = sycl::minimum; - return dpctl::tensor::kernels:: - search_axis1_over_group_temps_contig_impl< - srcTy, dstTy, ReductionOpT, IndexOpT>; - } - else { - // op for values - using ReductionOpT = su_ns::Maximum; - // op for indices - using IndexOpT = sycl::minimum; - return dpctl::tensor::kernels:: - search_axis1_over_group_temps_contig_impl< - srcTy, dstTy, ReductionOpT, IndexOpT>; - } - } - else { - return nullptr; - } - } -}; - -template -struct ArgmaxOverAxis0TempsContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForSearchReductionTemps< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_integral_v && - !std::is_same_v) { - // op for values - using ReductionOpT = sycl::maximum; - // op for indices - using IndexOpT = sycl::minimum; - return dpctl::tensor::kernels:: - search_axis0_over_group_temps_contig_impl< - srcTy, dstTy, ReductionOpT, IndexOpT>; - } - else { - // op for values - using ReductionOpT = su_ns::Maximum; - // op for indices - using IndexOpT = sycl::minimum; - return dpctl::tensor::kernels:: - search_axis0_over_group_temps_contig_impl< - srcTy, dstTy, ReductionOpT, IndexOpT>; - } - } - else { - return nullptr; - } - } -}; - -template -struct ArgminOverAxisTempsStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForSearchReductionTemps< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_integral_v && - !std::is_same_v) { - // op for values - using ReductionOpT = sycl::minimum; - // op for indices - using IndexOpT = sycl::minimum; - return dpctl::tensor::kernels:: - search_over_group_temps_strided_impl< - srcTy, dstTy, ReductionOpT, IndexOpT>; - } - else { - // op for values - using ReductionOpT = su_ns::Minimum; - // op for indices - using IndexOpT = sycl::minimum; - return dpctl::tensor::kernels:: - search_over_group_temps_strided_impl< - srcTy, dstTy, ReductionOpT, IndexOpT>; - } - } - else { - return nullptr; - } - } -}; - -template -struct ArgminOverAxis1TempsContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForSearchReductionTemps< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_integral_v && - !std::is_same_v) { - // op for values - using ReductionOpT = sycl::minimum; - // op for indices - using IndexOpT = sycl::minimum; - return dpctl::tensor::kernels:: - search_axis1_over_group_temps_contig_impl< - srcTy, dstTy, ReductionOpT, IndexOpT>; - } - else { - // op for values - using ReductionOpT = su_ns::Minimum; - // op for indices - using IndexOpT = sycl::minimum; - return dpctl::tensor::kernels:: - search_axis1_over_group_temps_contig_impl< - srcTy, dstTy, ReductionOpT, IndexOpT>; - } - } - else { - return nullptr; - } - } -}; - -template -struct ArgminOverAxis0TempsContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForSearchReductionTemps< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_integral_v && - !std::is_same_v) { - // op for values - using ReductionOpT = sycl::minimum; - // op for indices - using IndexOpT = sycl::minimum; - return dpctl::tensor::kernels:: - search_axis0_over_group_temps_contig_impl< - srcTy, dstTy, ReductionOpT, IndexOpT>; - } - else { - // op for values - using ReductionOpT = su_ns::Minimum; - // op for indices - using IndexOpT = sycl::minimum; - return dpctl::tensor::kernels:: - search_axis0_over_group_temps_contig_impl< - srcTy, dstTy, ReductionOpT, IndexOpT>; - } - } - else { - return nullptr; - } - } -}; - } // namespace kernels } // namespace tensor } // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp index d2032fdb65..6301b3b9bc 100644 --- a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp @@ -214,6 +214,9 @@ template using IsMaximum = std::bool_constant> || std::is_same_v>>; +template +using IsSyclMaximum = std::bool_constant>>; + template struct GetIdentity::value>> { @@ -244,6 +247,9 @@ template using IsMinimum = std::bool_constant> || std::is_same_v>>; +template +using IsSyclMinimum = std::bool_constant>>; + template struct GetIdentity::value>> { @@ -273,6 +279,10 @@ struct GetIdentity using IsPlus = std::bool_constant> || std::is_same_v>>; + +template +using IsSyclPlus = std::bool_constant>>; + // Multiplies template @@ -280,6 +290,10 @@ using IsMultiplies = std::bool_constant> || std::is_same_v>>; +template +using IsSyclMultiplies = + std::bool_constant>>; + template struct GetIdentity::value>> { @@ -326,6 +340,40 @@ struct GetIdentity::value>> static constexpr T value = 0; }; +// Logical_And + +template +using IsLogicalAnd = + std::bool_constant> || + std::is_same_v>>; + +template +using IsSyclLogicalAnd = + std::bool_constant>>; + +template +struct GetIdentity::value>> +{ + static constexpr T value = static_cast(1); +}; + +// Logical_Or + +template +using IsLogicalOr = + std::bool_constant> || + std::is_same_v>>; + +template +using IsSyclLogicalOr = + std::bool_constant>>; + +template +struct GetIdentity::value>> +{ + static constexpr T value = static_cast(0); +}; + // Identity template struct Identity diff --git a/dpctl/tensor/libtensor/source/reductions/all.cpp b/dpctl/tensor/libtensor/source/reductions/all.cpp index dbb37276f1..9c40ccfbb4 100644 --- a/dpctl/tensor/libtensor/source/reductions/all.cpp +++ b/dpctl/tensor/libtensor/source/reductions/all.cpp @@ -29,7 +29,8 @@ #include #include -#include "kernels/boolean_reductions.hpp" +#include "kernels/reductions.hpp" +#include "reduction_atomic_support.hpp" #include "reduction_over_axis.hpp" #include "utils/type_dispatch.hpp" @@ -47,46 +48,79 @@ namespace td_ns = dpctl::tensor::type_dispatch; namespace impl { -using dpctl::tensor::kernels::boolean_reduction_strided_impl_fn_ptr; -static boolean_reduction_strided_impl_fn_ptr +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr all_reduction_strided_dispatch_vector[td_ns::num_types]; -using dpctl::tensor::kernels::boolean_reduction_contig_impl_fn_ptr; -static boolean_reduction_contig_impl_fn_ptr +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr all_reduction_axis1_contig_dispatch_vector[td_ns::num_types]; - -static boolean_reduction_contig_impl_fn_ptr +static reduction_contig_impl_fn_ptr all_reduction_axis0_contig_dispatch_vector[td_ns::num_types]; +template struct AllStridedFactory +{ + fnT get() const + { + using dstTy = std::int32_t; + using ReductionOpT = sycl::logical_and; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl; + } +}; + +template struct AllAxis1ContigFactory +{ + fnT get() const + { + using dstTy = std::int32_t; + using ReductionOpT = sycl::logical_and; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl; + } +}; + +template struct AllAxis0ContigFactory +{ + fnT get() const + { + using dstTy = std::int32_t; + using ReductionOpT = sycl::logical_and; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl; + } +}; + void populate_all_dispatch_vectors(void) { using td_ns::DispatchVectorBuilder; - using dpctl::tensor::kernels::boolean_reduction_strided_impl_fn_ptr; - - using dpctl::tensor::kernels::AllStridedFactory; - DispatchVectorBuilder + DispatchVectorBuilder all_dvb1; all_dvb1.populate_dispatch_vector(all_reduction_strided_dispatch_vector); - using dpctl::tensor::kernels::boolean_reduction_contig_impl_fn_ptr; - - using dpctl::tensor::kernels::AllAxis1ContigFactory; - DispatchVectorBuilder + DispatchVectorBuilder all_dvb2; all_dvb2.populate_dispatch_vector( all_reduction_axis1_contig_dispatch_vector); - using dpctl::tensor::kernels::AllAxis0ContigFactory; - DispatchVectorBuilder + DispatchVectorBuilder all_dvb3; all_dvb3.populate_dispatch_vector( all_reduction_axis0_contig_dispatch_vector); }; +using atomic_support::atomic_support_fn_ptr_t; +using atomic_support::check_atomic_support; +static atomic_support_fn_ptr_t all_atomic_support = + check_atomic_support; + } // namespace impl void init_all(py::module_ m) @@ -99,6 +133,8 @@ void init_all(py::module_ m) using impl::all_reduction_axis1_contig_dispatch_vector; using impl::all_reduction_strided_dispatch_vector; + using impl::all_atomic_support; + auto all_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, const arrayT &dst, sycl::queue &exec_q, const event_vecT &depends = {}) { @@ -106,7 +142,7 @@ void init_all(py::module_ m) src, trailing_dims_to_reduce, dst, exec_q, depends, all_reduction_axis1_contig_dispatch_vector, all_reduction_axis0_contig_dispatch_vector, - all_reduction_strided_dispatch_vector); + all_reduction_strided_dispatch_vector, all_atomic_support); }; m.def("_all", all_pyapi, "", py::arg("src"), py::arg("trailing_dims_to_reduce"), py::arg("dst"), diff --git a/dpctl/tensor/libtensor/source/reductions/any.cpp b/dpctl/tensor/libtensor/source/reductions/any.cpp index c191bb1f6b..3a7cce9e2c 100644 --- a/dpctl/tensor/libtensor/source/reductions/any.cpp +++ b/dpctl/tensor/libtensor/source/reductions/any.cpp @@ -29,7 +29,8 @@ #include #include -#include "kernels/boolean_reductions.hpp" +#include "kernels/reductions.hpp" +#include "reduction_atomic_support.hpp" #include "reduction_over_axis.hpp" #include "utils/type_dispatch.hpp" @@ -46,46 +47,80 @@ namespace td_ns = dpctl::tensor::type_dispatch; namespace impl { -using dpctl::tensor::kernels::boolean_reduction_strided_impl_fn_ptr; -static boolean_reduction_strided_impl_fn_ptr + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr any_reduction_strided_dispatch_vector[td_ns::num_types]; -using dpctl::tensor::kernels::boolean_reduction_contig_impl_fn_ptr; -static boolean_reduction_contig_impl_fn_ptr +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr any_reduction_axis1_contig_dispatch_vector[td_ns::num_types]; - -static boolean_reduction_contig_impl_fn_ptr +static reduction_contig_impl_fn_ptr any_reduction_axis0_contig_dispatch_vector[td_ns::num_types]; +template struct AnyStridedFactory +{ + fnT get() const + { + using dstTy = std::int32_t; + using ReductionOpT = sycl::logical_or; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl; + } +}; + +template struct AnyAxis1ContigFactory +{ + fnT get() const + { + using dstTy = std::int32_t; + using ReductionOpT = sycl::logical_or; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl; + } +}; + +template struct AnyAxis0ContigFactory +{ + fnT get() const + { + using dstTy = std::int32_t; + using ReductionOpT = sycl::logical_or; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl; + } +}; + void populate_any_dispatch_vectors(void) { using td_ns::DispatchVectorBuilder; - using dpctl::tensor::kernels::boolean_reduction_strided_impl_fn_ptr; - - using dpctl::tensor::kernels::AnyStridedFactory; - DispatchVectorBuilder + DispatchVectorBuilder any_dvb1; any_dvb1.populate_dispatch_vector(any_reduction_strided_dispatch_vector); - using dpctl::tensor::kernels::boolean_reduction_contig_impl_fn_ptr; - - using dpctl::tensor::kernels::AnyAxis1ContigFactory; - DispatchVectorBuilder + DispatchVectorBuilder any_dvb2; any_dvb2.populate_dispatch_vector( any_reduction_axis1_contig_dispatch_vector); - using dpctl::tensor::kernels::AnyAxis0ContigFactory; - DispatchVectorBuilder + DispatchVectorBuilder any_dvb3; any_dvb3.populate_dispatch_vector( any_reduction_axis0_contig_dispatch_vector); }; +using atomic_support::atomic_support_fn_ptr_t; +using atomic_support::check_atomic_support; +static atomic_support_fn_ptr_t any_atomic_support = + check_atomic_support; + } // namespace impl void init_any(py::module_ m) @@ -98,6 +133,8 @@ void init_any(py::module_ m) using impl::any_reduction_axis1_contig_dispatch_vector; using impl::any_reduction_strided_dispatch_vector; + using impl::any_atomic_support; + auto any_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, const arrayT &dst, sycl::queue &exec_q, const event_vecT &depends = {}) { @@ -105,7 +142,7 @@ void init_any(py::module_ m) src, trailing_dims_to_reduce, dst, exec_q, depends, any_reduction_axis1_contig_dispatch_vector, any_reduction_axis0_contig_dispatch_vector, - any_reduction_strided_dispatch_vector); + any_reduction_strided_dispatch_vector, any_atomic_support); }; m.def("_any", any_pyapi, "", py::arg("src"), py::arg("trailing_dims_to_reduce"), py::arg("dst"), diff --git a/dpctl/tensor/libtensor/source/reductions/argmax.cpp b/dpctl/tensor/libtensor/source/reductions/argmax.cpp index bdf5deb33b..3331423ddc 100644 --- a/dpctl/tensor/libtensor/source/reductions/argmax.cpp +++ b/dpctl/tensor/libtensor/source/reductions/argmax.cpp @@ -27,11 +27,12 @@ #include #include #include +#include #include #include "kernels/reductions.hpp" #include "reduction_over_axis.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" namespace py = pybind11; @@ -61,24 +62,178 @@ static search_contig_impl_fn_ptr argmax_over_axis0_contig_temps_dispatch_table[td_ns::num_types] [td_ns::num_types]; +template +struct TypePairSupportForArgmaxReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + // input int8_t + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::int64_t>, + + td_ns::TypePairDefinedEntry, + outTy, + std::int64_t>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct ArgmaxOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportForArgmaxReductionTemps::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct ArgmaxOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportForArgmaxReductionTemps::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis1_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis1_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct ArgmaxOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportForArgmaxReductionTemps::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis0_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis0_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + void populate_argmax_over_axis_dispatch_tables(void) { using dpctl::tensor::kernels::search_strided_impl_fn_ptr; using td_ns::DispatchTableBuilder; - using dpctl::tensor::kernels::ArgmaxOverAxisTempsStridedFactory; DispatchTableBuilder dtb1; dtb1.populate_dispatch_table(argmax_over_axis_strided_temps_dispatch_table); - using dpctl::tensor::kernels::ArgmaxOverAxis1TempsContigFactory; DispatchTableBuilder dtb2; dtb2.populate_dispatch_table(argmax_over_axis1_contig_temps_dispatch_table); - using dpctl::tensor::kernels::ArgmaxOverAxis0TempsContigFactory; DispatchTableBuilder dtb3; diff --git a/dpctl/tensor/libtensor/source/reductions/argmin.cpp b/dpctl/tensor/libtensor/source/reductions/argmin.cpp index f620dc307c..582a96247c 100644 --- a/dpctl/tensor/libtensor/source/reductions/argmin.cpp +++ b/dpctl/tensor/libtensor/source/reductions/argmin.cpp @@ -27,11 +27,12 @@ #include #include #include +#include #include #include "kernels/reductions.hpp" #include "reduction_over_axis.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" namespace py = pybind11; @@ -61,24 +62,178 @@ static search_contig_impl_fn_ptr argmin_over_axis0_contig_temps_dispatch_table[td_ns::num_types] [td_ns::num_types]; +template +struct TypePairSupportForArgminReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + // input int8_t + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::int64_t>, + + td_ns::TypePairDefinedEntry, + outTy, + std::int64_t>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct ArgminOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportForArgminReductionTemps::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct ArgminOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportForArgminReductionTemps::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis1_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis1_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct ArgminOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportForArgminReductionTemps::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis0_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis0_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + void populate_argmin_over_axis_dispatch_tables(void) { using dpctl::tensor::kernels::search_strided_impl_fn_ptr; using td_ns::DispatchTableBuilder; - using dpctl::tensor::kernels::ArgminOverAxisTempsStridedFactory; DispatchTableBuilder dtb1; dtb1.populate_dispatch_table(argmin_over_axis_strided_temps_dispatch_table); - using dpctl::tensor::kernels::ArgminOverAxis1TempsContigFactory; DispatchTableBuilder dtb2; dtb2.populate_dispatch_table(argmin_over_axis1_contig_temps_dispatch_table); - using dpctl::tensor::kernels::ArgminOverAxis0TempsContigFactory; DispatchTableBuilder dtb3; diff --git a/dpctl/tensor/libtensor/source/reductions/logsumexp.cpp b/dpctl/tensor/libtensor/source/reductions/logsumexp.cpp index 3c9b0452fb..d2bb6e3877 100644 --- a/dpctl/tensor/libtensor/source/reductions/logsumexp.cpp +++ b/dpctl/tensor/libtensor/source/reductions/logsumexp.cpp @@ -27,11 +27,12 @@ #include #include #include +#include #include #include "kernels/reductions.hpp" #include "reduction_over_axis.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" namespace py = pybind11; @@ -60,27 +61,142 @@ static reduction_contig_impl_fn_ptr logsumexp_over_axis0_contig_temps_dispatch_table[td_ns::num_types] [td_ns::num_types]; +template +struct TypePairSupportDataForLogSumExpReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct LogSumExpOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForLogSumExpReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::LogSumExp; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct LogSumExpOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForLogSumExpReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::LogSumExp; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct LogSumExpOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForLogSumExpReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::LogSumExp; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + void populate_logsumexp_over_axis_dispatch_tables(void) { using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; using namespace td_ns; - using dpctl::tensor::kernels::LogSumExpOverAxisTempsStridedFactory; DispatchTableBuilder dtb1; dtb1.populate_dispatch_table( logsumexp_over_axis_strided_temps_dispatch_table); - using dpctl::tensor::kernels::LogSumExpOverAxis1TempsContigFactory; DispatchTableBuilder dtb2; dtb2.populate_dispatch_table( logsumexp_over_axis1_contig_temps_dispatch_table); - using dpctl::tensor::kernels::LogSumExpOverAxis0TempsContigFactory; DispatchTableBuilder dtb3; diff --git a/dpctl/tensor/libtensor/source/reductions/max.cpp b/dpctl/tensor/libtensor/source/reductions/max.cpp index 22d5232d32..cfd6daf06e 100644 --- a/dpctl/tensor/libtensor/source/reductions/max.cpp +++ b/dpctl/tensor/libtensor/source/reductions/max.cpp @@ -27,10 +27,11 @@ #include #include #include +#include #include #include "kernels/reductions.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "reduction_atomic_support.hpp" #include "reduction_over_axis.hpp" @@ -71,41 +72,275 @@ static reduction_contig_impl_fn_ptr max_over_axis0_contig_temps_dispatch_table[td_ns::num_types] [td_ns::num_types]; +/* @brief Types supported by max reduction code based on atomic_ref */ +template +struct TypePairSupportDataForMaxReductionAtomic +{ + + /* value is true if a kernel for must be instantiated, false + * otherwise */ + // disjunction is C++17 feature, supported by DPC++ + static constexpr bool is_defined = std::disjunction< + // input int32 + td_ns::TypePairDefinedEntry, + // input uint32 + td_ns::TypePairDefinedEntry, + // input int64 + td_ns::TypePairDefinedEntry, + // input uint64 + td_ns::TypePairDefinedEntry, + // input float + td_ns::TypePairDefinedEntry, + // input double + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct TypePairSupportDataForMaxReductionTemps +{ + + // disjunction is C++17 feature, supported by DPC++ + static constexpr bool is_defined = std::disjunction< + // input bool + td_ns::TypePairDefinedEntry, + // input int8_t + td_ns::TypePairDefinedEntry, + // input uint8_t + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + // input uint16_t + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct MaxOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMaxReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMaxReductionTemps< + srcTy, dstTy>::is_defined) { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMaxReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMaxReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMaxReductionTemps< + srcTy, dstTy>::is_defined) { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMaxReductionTemps< + srcTy, dstTy>::is_defined) { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + } + else { + return nullptr; + } + } +}; + void populate_max_over_axis_dispatch_tables(void) { using td_ns::DispatchTableBuilder; - using dpctl::tensor::kernels::MaxOverAxisAtomicStridedFactory; DispatchTableBuilder dtb1; dtb1.populate_dispatch_table(max_over_axis_strided_atomic_dispatch_table); - using dpctl::tensor::kernels::MaxOverAxisTempsStridedFactory; DispatchTableBuilder dtb2; dtb2.populate_dispatch_table(max_over_axis_strided_temps_dispatch_table); - using dpctl::tensor::kernels::MaxOverAxis1AtomicContigFactory; DispatchTableBuilder dtb3; dtb3.populate_dispatch_table(max_over_axis1_contig_atomic_dispatch_table); - using dpctl::tensor::kernels::MaxOverAxis0AtomicContigFactory; DispatchTableBuilder dtb4; dtb4.populate_dispatch_table(max_over_axis0_contig_atomic_dispatch_table); - using dpctl::tensor::kernels::MaxOverAxis1TempsContigFactory; DispatchTableBuilder dtb5; dtb5.populate_dispatch_table(max_over_axis1_contig_temps_dispatch_table); - using dpctl::tensor::kernels::MaxOverAxis0TempsContigFactory; DispatchTableBuilder dtb6; diff --git a/dpctl/tensor/libtensor/source/reductions/min.cpp b/dpctl/tensor/libtensor/source/reductions/min.cpp index cf5a5db414..3c2293f10c 100644 --- a/dpctl/tensor/libtensor/source/reductions/min.cpp +++ b/dpctl/tensor/libtensor/source/reductions/min.cpp @@ -27,10 +27,11 @@ #include #include #include +#include #include #include "kernels/reductions.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "reduction_atomic_support.hpp" #include "reduction_over_axis.hpp" @@ -71,43 +72,277 @@ static reduction_contig_impl_fn_ptr min_over_axis0_contig_temps_dispatch_table[td_ns::num_types] [td_ns::num_types]; +/* @brief Types supported by min reduction code based on atomic_ref */ +template +struct TypePairSupportDataForMinReductionAtomic +{ + + /* value is true if a kernel for must be instantiated, false + * otherwise */ + // disjunction is C++17 feature, supported by DPC++ + static constexpr bool is_defined = std::disjunction< + // input int32 + td_ns::TypePairDefinedEntry, + // input uint32 + td_ns::TypePairDefinedEntry, + // input int64 + td_ns::TypePairDefinedEntry, + // input uint64 + td_ns::TypePairDefinedEntry, + // input float + td_ns::TypePairDefinedEntry, + // input double + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct TypePairSupportDataForMinReductionTemps +{ + + // disjunction is C++17 feature, supported by DPC++ + static constexpr bool is_defined = std::disjunction< + // input bool + td_ns::TypePairDefinedEntry, + // input int8_t + td_ns::TypePairDefinedEntry, + // input uint8_t + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + // input uint16_t + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct MinOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMinReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMinReductionTemps< + srcTy, dstTy>::is_defined) { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMinReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMinReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMinReductionTemps< + srcTy, dstTy>::is_defined) { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMinReductionTemps< + srcTy, dstTy>::is_defined) { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + } + else { + return nullptr; + } + } +}; + void populate_min_over_axis_dispatch_tables(void) { using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; using td_ns::DispatchTableBuilder; - using dpctl::tensor::kernels::MinOverAxisAtomicStridedFactory; DispatchTableBuilder dtb1; dtb1.populate_dispatch_table(min_over_axis_strided_atomic_dispatch_table); - using dpctl::tensor::kernels::MinOverAxisTempsStridedFactory; DispatchTableBuilder dtb2; dtb2.populate_dispatch_table(min_over_axis_strided_temps_dispatch_table); - using dpctl::tensor::kernels::MinOverAxis1AtomicContigFactory; DispatchTableBuilder dtb3; dtb3.populate_dispatch_table(min_over_axis1_contig_atomic_dispatch_table); - using dpctl::tensor::kernels::MinOverAxis0AtomicContigFactory; DispatchTableBuilder dtb4; dtb4.populate_dispatch_table(min_over_axis0_contig_atomic_dispatch_table); - using dpctl::tensor::kernels::MinOverAxis1TempsContigFactory; DispatchTableBuilder dtb5; dtb5.populate_dispatch_table(min_over_axis1_contig_temps_dispatch_table); - using dpctl::tensor::kernels::MinOverAxis0TempsContigFactory; DispatchTableBuilder dtb6; diff --git a/dpctl/tensor/libtensor/source/reductions/prod.cpp b/dpctl/tensor/libtensor/source/reductions/prod.cpp index 1f52982d6d..9b8df53a01 100644 --- a/dpctl/tensor/libtensor/source/reductions/prod.cpp +++ b/dpctl/tensor/libtensor/source/reductions/prod.cpp @@ -27,10 +27,11 @@ #include #include #include +#include #include #include "kernels/reductions.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "reduction_atomic_support.hpp" #include "reduction_over_axis.hpp" @@ -71,43 +72,311 @@ static reduction_contig_impl_fn_ptr prod_over_axis0_contig_temps_dispatch_table[td_ns::num_types] [td_ns::num_types]; +/* @brief Types supported by plus-reduction code based on atomic_ref */ +template +struct TypePairSupportDataForProductReductionAtomic +{ + + /* value if true a kernel for must be instantiated, false + * otherwise */ + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int64 + td_ns::TypePairDefinedEntry, + // input uint64 + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct TypePairSupportDataForProductReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns:: + TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input double + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-throug + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct ProductOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + void populate_prod_over_axis_dispatch_tables(void) { using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; using namespace td_ns; - using dpctl::tensor::kernels::ProductOverAxisAtomicStridedFactory; DispatchTableBuilder dtb1; dtb1.populate_dispatch_table(prod_over_axis_strided_atomic_dispatch_table); - using dpctl::tensor::kernels::ProductOverAxisTempsStridedFactory; DispatchTableBuilder dtb2; dtb2.populate_dispatch_table(prod_over_axis_strided_temps_dispatch_table); - using dpctl::tensor::kernels::ProductOverAxis1AtomicContigFactory; DispatchTableBuilder dtb3; dtb3.populate_dispatch_table(prod_over_axis1_contig_atomic_dispatch_table); - using dpctl::tensor::kernels::ProductOverAxis0AtomicContigFactory; DispatchTableBuilder dtb4; dtb4.populate_dispatch_table(prod_over_axis0_contig_atomic_dispatch_table); - using dpctl::tensor::kernels::ProductOverAxis1TempsContigFactory; DispatchTableBuilder dtb5; dtb5.populate_dispatch_table(prod_over_axis1_contig_temps_dispatch_table); - using dpctl::tensor::kernels::ProductOverAxis0TempsContigFactory; DispatchTableBuilder dtb6; diff --git a/dpctl/tensor/libtensor/source/reductions/reduce_hypot.cpp b/dpctl/tensor/libtensor/source/reductions/reduce_hypot.cpp index fbd90beb60..ce655126a6 100644 --- a/dpctl/tensor/libtensor/source/reductions/reduce_hypot.cpp +++ b/dpctl/tensor/libtensor/source/reductions/reduce_hypot.cpp @@ -27,11 +27,12 @@ #include #include #include +#include #include #include "kernels/reductions.hpp" #include "reduction_over_axis.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" namespace py = pybind11; @@ -60,25 +61,140 @@ static reduction_contig_impl_fn_ptr hypot_over_axis0_contig_temps_dispatch_table[td_ns::num_types] [td_ns::num_types]; +template +struct TypePairSupportDataForHypotReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct HypotOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForHypotReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::Hypot; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct HypotOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForHypotReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::Hypot; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct HypotOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForHypotReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::Hypot; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + void populate_hypot_over_axis_dispatch_tables(void) { using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; using namespace td_ns; - using dpctl::tensor::kernels::HypotOverAxisTempsStridedFactory; DispatchTableBuilder dtb1; dtb1.populate_dispatch_table(hypot_over_axis_strided_temps_dispatch_table); - using dpctl::tensor::kernels::HypotOverAxis1TempsContigFactory; DispatchTableBuilder dtb2; dtb2.populate_dispatch_table(hypot_over_axis1_contig_temps_dispatch_table); - using dpctl::tensor::kernels::HypotOverAxis0TempsContigFactory; DispatchTableBuilder dtb3; diff --git a/dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp b/dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp index 726d9d24c9..f9c61db5a8 100644 --- a/dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp +++ b/dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp @@ -37,7 +37,6 @@ #include #include -#include "kernels/boolean_reductions.hpp" #include "kernels/reductions.hpp" #include "simplify_iteration_space.hpp" #include "utils/memory_overlap.hpp" @@ -1074,7 +1073,9 @@ std::pair py_search_over_axis( /*! @brief Template implementing Python API for boolean reductions over an axis */ -template +template std::pair py_boolean_reduction(const dpctl::tensor::usm_ndarray &src, int trailing_dims_to_reduce, @@ -1083,7 +1084,8 @@ py_boolean_reduction(const dpctl::tensor::usm_ndarray &src, const std::vector &depends, const contig_dispatchT &axis1_contig_dispatch_vector, const contig_dispatchT &axis0_contig_dispatch_vector, - const strided_dispatchT &strided_dispatch_vector) + const strided_dispatchT &strided_dispatch_vector, + const atomic_support_fnT check_atomic_support) { int src_nd = src.get_ndim(); int iter_nd = src_nd - trailing_dims_to_reduce; @@ -1148,6 +1150,16 @@ py_boolean_reduction(const dpctl::tensor::usm_ndarray &src, "Unexpected data type of destination array, expecting 'int32'"); } + void *data_ptr = dst.get_data(); + const auto &ctx = exec_q.get_context(); + auto usm_type = sycl::get_pointer_type(data_ptr, ctx); + + bool supports_atomics = check_atomic_support(exec_q, usm_type); + if (!supports_atomics) { + throw py::value_error( + "This reduction is not supported for this device and usm_type."); + } + bool is_src_c_contig = src.is_c_contiguous(); bool is_src_f_contig = src.is_f_contiguous(); bool is_dst_c_contig = dst.is_c_contiguous(); diff --git a/dpctl/tensor/libtensor/source/reductions/sum.cpp b/dpctl/tensor/libtensor/source/reductions/sum.cpp index 79adf5eaed..6184f400fa 100644 --- a/dpctl/tensor/libtensor/source/reductions/sum.cpp +++ b/dpctl/tensor/libtensor/source/reductions/sum.cpp @@ -27,10 +27,11 @@ #include #include #include +#include #include #include "kernels/reductions.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "reduction_atomic_support.hpp" #include "reduction_over_axis.hpp" @@ -71,43 +72,308 @@ static reduction_contig_impl_fn_ptr sum_over_axis0_contig_temps_dispatch_table[td_ns::num_types] [td_ns::num_types]; +/* @brief Types supported by plus-reduction code based on atomic_ref */ +template +struct TypePairSupportDataForSumReductionAtomic +{ + + /* value if true a kernel for must be instantiated, false + * otherwise */ + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int64 + td_ns::TypePairDefinedEntry, + // input uint64 + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct TypePairSupportDataForSumReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns:: + TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input double + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-throug + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct SumOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::plus; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionTemps< + srcTy, dstTy>::is_defined) { + using ReductionOpT = sycl::plus; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::plus; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::plus; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionTemps< + srcTy, dstTy>::is_defined) { + using ReductionOpT = sycl::plus; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionTemps< + srcTy, dstTy>::is_defined) { + using ReductionOpT = sycl::plus; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + void populate_sum_over_axis_dispatch_tables(void) { using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; using namespace td_ns; - using dpctl::tensor::kernels::SumOverAxisAtomicStridedFactory; DispatchTableBuilder dtb1; dtb1.populate_dispatch_table(sum_over_axis_strided_atomic_dispatch_table); - using dpctl::tensor::kernels::SumOverAxisTempsStridedFactory; DispatchTableBuilder dtb2; dtb2.populate_dispatch_table(sum_over_axis_strided_temps_dispatch_table); - using dpctl::tensor::kernels::SumOverAxis1AtomicContigFactory; DispatchTableBuilder dtb3; dtb3.populate_dispatch_table(sum_over_axis1_contig_atomic_dispatch_table); - using dpctl::tensor::kernels::SumOverAxis0AtomicContigFactory; DispatchTableBuilder dtb4; dtb4.populate_dispatch_table(sum_over_axis0_contig_atomic_dispatch_table); - using dpctl::tensor::kernels::SumOverAxis1TempsContigFactory; DispatchTableBuilder dtb5; dtb5.populate_dispatch_table(sum_over_axis1_contig_temps_dispatch_table); - using dpctl::tensor::kernels::SumOverAxis0TempsContigFactory; DispatchTableBuilder dtb6;