Skip to content

Commit af28d98

Browse files
authored
Improves performance of search reductions for small numbers of elements (#1464)
* Adds a SequentialSearchReduction functor to the search reductions. * Makes search reductions use the correct branch for float16: the constexpr branch logic accounted for floating-point types but not sycl::half, which meant NaNs were not propagated for float16 data.
1 parent 097ecf5 commit af28d98

File tree

1 file changed

+248
-3
lines changed

1 file changed

+248
-3
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

+248-3
Original file line numberDiff line numberDiff line change
@@ -3401,6 +3401,129 @@ struct LogSumExpOverAxis0TempsContigFactory
34013401

34023402
// Argmax and Argmin
34033403

3404+
/* Sequential search reduction */
3405+
3406+
template <typename argT,
3407+
typename outT,
3408+
typename ReductionOp,
3409+
typename IdxReductionOp,
3410+
typename InputOutputIterIndexerT,
3411+
typename InputRedIndexerT>
3412+
struct SequentialSearchReduction
3413+
{
3414+
private:
3415+
const argT *inp_ = nullptr;
3416+
outT *out_ = nullptr;
3417+
ReductionOp reduction_op_;
3418+
argT identity_;
3419+
IdxReductionOp idx_reduction_op_;
3420+
outT idx_identity_;
3421+
InputOutputIterIndexerT inp_out_iter_indexer_;
3422+
InputRedIndexerT inp_reduced_dims_indexer_;
3423+
size_t reduction_max_gid_ = 0;
3424+
3425+
public:
3426+
SequentialSearchReduction(const argT *inp,
3427+
outT *res,
3428+
ReductionOp reduction_op,
3429+
const argT &identity_val,
3430+
IdxReductionOp idx_reduction_op,
3431+
const outT &idx_identity_val,
3432+
InputOutputIterIndexerT arg_res_iter_indexer,
3433+
InputRedIndexerT arg_reduced_dims_indexer,
3434+
size_t reduction_size)
3435+
: inp_(inp), out_(res), reduction_op_(reduction_op),
3436+
identity_(identity_val), idx_reduction_op_(idx_reduction_op),
3437+
idx_identity_(idx_identity_val),
3438+
inp_out_iter_indexer_(arg_res_iter_indexer),
3439+
inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
3440+
reduction_max_gid_(reduction_size)
3441+
{
3442+
}
3443+
3444+
void operator()(sycl::id<1> id) const
3445+
{
3446+
3447+
auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]);
3448+
const py::ssize_t &inp_iter_offset =
3449+
inp_out_iter_offsets_.get_first_offset();
3450+
const py::ssize_t &out_iter_offset =
3451+
inp_out_iter_offsets_.get_second_offset();
3452+
3453+
argT red_val(identity_);
3454+
outT idx_val(idx_identity_);
3455+
for (size_t m = 0; m < reduction_max_gid_; ++m) {
3456+
const py::ssize_t inp_reduction_offset =
3457+
inp_reduced_dims_indexer_(m);
3458+
const py::ssize_t inp_offset =
3459+
inp_iter_offset + inp_reduction_offset;
3460+
3461+
argT val = inp_[inp_offset];
3462+
if (val == red_val) {
3463+
idx_val = idx_reduction_op_(idx_val, static_cast<outT>(m));
3464+
}
3465+
else {
3466+
if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
3467+
using dpctl::tensor::type_utils::is_complex;
3468+
if constexpr (is_complex<argT>::value) {
3469+
using dpctl::tensor::math_utils::less_complex;
3470+
// less_complex always returns false for NaNs, so check
3471+
if (less_complex<argT>(val, red_val) ||
3472+
std::isnan(std::real(val)) ||
3473+
std::isnan(std::imag(val)))
3474+
{
3475+
red_val = val;
3476+
idx_val = static_cast<outT>(m);
3477+
}
3478+
}
3479+
else if constexpr (std::is_floating_point_v<argT> ||
3480+
std::is_same_v<argT, sycl::half>)
3481+
{
3482+
if (val < red_val || std::isnan(val)) {
3483+
red_val = val;
3484+
idx_val = static_cast<outT>(m);
3485+
}
3486+
}
3487+
else {
3488+
if (val < red_val) {
3489+
red_val = val;
3490+
idx_val = static_cast<outT>(m);
3491+
}
3492+
}
3493+
}
3494+
else if constexpr (su_ns::IsMaximum<argT, ReductionOp>::value) {
3495+
using dpctl::tensor::type_utils::is_complex;
3496+
if constexpr (is_complex<argT>::value) {
3497+
using dpctl::tensor::math_utils::greater_complex;
3498+
if (greater_complex<argT>(val, red_val) ||
3499+
std::isnan(std::real(val)) ||
3500+
std::isnan(std::imag(val)))
3501+
{
3502+
red_val = val;
3503+
idx_val = static_cast<outT>(m);
3504+
}
3505+
}
3506+
else if constexpr (std::is_floating_point_v<argT> ||
3507+
std::is_same_v<argT, sycl::half>)
3508+
{
3509+
if (val > red_val || std::isnan(val)) {
3510+
red_val = val;
3511+
idx_val = static_cast<outT>(m);
3512+
}
3513+
}
3514+
else {
3515+
if (val > red_val) {
3516+
red_val = val;
3517+
idx_val = static_cast<outT>(m);
3518+
}
3519+
}
3520+
}
3521+
}
3522+
}
3523+
out_[out_iter_offset] = idx_val;
3524+
}
3525+
};
3526+
34043527
/* = Search reduction using reduce_over_group*/
34053528

34063529
template <typename argT,
@@ -3670,7 +3793,9 @@ struct CustomSearchReduction
36703793
}
36713794
}
36723795
}
3673-
else if constexpr (std::is_floating_point_v<argT>) {
3796+
else if constexpr (std::is_floating_point_v<argT> ||
3797+
std::is_same_v<argT, sycl::half>)
3798+
{
36743799
if (val < local_red_val || std::isnan(val)) {
36753800
local_red_val = val;
36763801
if constexpr (!First) {
@@ -3714,7 +3839,9 @@ struct CustomSearchReduction
37143839
}
37153840
}
37163841
}
3717-
else if constexpr (std::is_floating_point_v<argT>) {
3842+
else if constexpr (std::is_floating_point_v<argT> ||
3843+
std::is_same_v<argT, sycl::half>)
3844+
{
37183845
if (val > local_red_val || std::isnan(val)) {
37193846
local_red_val = val;
37203847
if constexpr (!First) {
@@ -3757,7 +3884,9 @@ struct CustomSearchReduction
37573884
? local_idx
37583885
: idx_identity_;
37593886
}
3760-
else if constexpr (std::is_floating_point_v<argT>) {
3887+
else if constexpr (std::is_floating_point_v<argT> ||
3888+
std::is_same_v<argT, sycl::half>)
3889+
{
37613890
// equality does not hold for NaNs, so check here
37623891
local_idx =
37633892
(red_val_over_wg == local_red_val || std::isnan(local_red_val))
@@ -3799,6 +3928,14 @@ typedef sycl::event (*search_strided_impl_fn_ptr)(
37993928
py::ssize_t,
38003929
const std::vector<sycl::event> &);
38013930

3931+
template <typename T1,
3932+
typename T2,
3933+
typename T3,
3934+
typename T4,
3935+
typename T5,
3936+
typename T6>
3937+
class search_seq_strided_krn;
3938+
38023939
template <typename T1,
38033940
typename T2,
38043941
typename T3,
@@ -3820,6 +3957,14 @@ template <typename T1,
38203957
bool b2>
38213958
class custom_search_over_group_temps_strided_krn;
38223959

3960+
template <typename T1,
3961+
typename T2,
3962+
typename T3,
3963+
typename T4,
3964+
typename T5,
3965+
typename T6>
3966+
class search_seq_contig_krn;
3967+
38233968
template <typename T1,
38243969
typename T2,
38253970
typename T3,
@@ -4019,6 +4164,36 @@ sycl::event search_over_group_temps_strided_impl(
40194164
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
40204165
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
40214166

4167+
if (reduction_nelems < wg) {
4168+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
4169+
cgh.depends_on(depends);
4170+
4171+
using InputOutputIterIndexerT =
4172+
dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
4173+
using ReductionIndexerT =
4174+
dpctl::tensor::offset_utils::StridedIndexer;
4175+
4176+
InputOutputIterIndexerT in_out_iter_indexer{
4177+
iter_nd, iter_arg_offset, iter_res_offset,
4178+
iter_shape_and_strides};
4179+
ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset,
4180+
reduction_shape_stride};
4181+
4182+
cgh.parallel_for<class search_seq_strided_krn<
4183+
argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
4184+
ReductionIndexerT>>(
4185+
sycl::range<1>(iter_nelems),
4186+
SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
4187+
InputOutputIterIndexerT,
4188+
ReductionIndexerT>(
4189+
arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
4190+
idx_identity_val, in_out_iter_indexer, reduction_indexer,
4191+
reduction_nelems));
4192+
});
4193+
4194+
return comp_ev;
4195+
}
4196+
40224197
constexpr size_t preferred_reductions_per_wi = 4;
40234198
// max_max_wg prevents running out of resources on CPU
40244199
size_t max_wg =
@@ -4419,6 +4594,39 @@ sycl::event search_axis1_over_group_temps_contig_impl(
44194594
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
44204595
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
44214596

4597+
if (reduction_nelems < wg) {
4598+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
4599+
cgh.depends_on(depends);
4600+
4601+
using InputIterIndexerT =
4602+
dpctl::tensor::offset_utils::Strided1DIndexer;
4603+
using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
4604+
using InputOutputIterIndexerT =
4605+
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
4606+
InputIterIndexerT, NoOpIndexerT>;
4607+
using ReductionIndexerT = NoOpIndexerT;
4608+
4609+
InputOutputIterIndexerT in_out_iter_indexer{
4610+
InputIterIndexerT{0, static_cast<py::ssize_t>(iter_nelems),
4611+
static_cast<py::ssize_t>(reduction_nelems)},
4612+
NoOpIndexerT{}};
4613+
ReductionIndexerT reduction_indexer{};
4614+
4615+
cgh.parallel_for<class search_seq_contig_krn<
4616+
argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
4617+
ReductionIndexerT>>(
4618+
sycl::range<1>(iter_nelems),
4619+
SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
4620+
InputOutputIterIndexerT,
4621+
ReductionIndexerT>(
4622+
arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
4623+
idx_identity_val, in_out_iter_indexer, reduction_indexer,
4624+
reduction_nelems));
4625+
});
4626+
4627+
return comp_ev;
4628+
}
4629+
44224630
constexpr size_t preferred_reductions_per_wi = 8;
44234631
// max_max_wg prevents running out of resources on CPU
44244632
size_t max_wg =
@@ -4801,6 +5009,43 @@ sycl::event search_axis0_over_group_temps_contig_impl(
48015009
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
48025010
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
48035011

5012+
if (reduction_nelems < wg) {
5013+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
5014+
cgh.depends_on(depends);
5015+
5016+
using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
5017+
using InputOutputIterIndexerT =
5018+
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
5019+
NoOpIndexerT, NoOpIndexerT>;
5020+
using ReductionIndexerT =
5021+
dpctl::tensor::offset_utils::Strided1DIndexer;
5022+
5023+
InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{},
5024+
NoOpIndexerT{}};
5025+
ReductionIndexerT reduction_indexer{
5026+
0, static_cast<py::ssize_t>(reduction_nelems),
5027+
static_cast<py::ssize_t>(iter_nelems)};
5028+
5029+
using KernelName =
5030+
class search_seq_contig_krn<argTy, resTy, ReductionOpT,
5031+
IndexOpT, InputOutputIterIndexerT,
5032+
ReductionIndexerT>;
5033+
5034+
sycl::range<1> iter_range{iter_nelems};
5035+
5036+
cgh.parallel_for<KernelName>(
5037+
iter_range,
5038+
SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
5039+
InputOutputIterIndexerT,
5040+
ReductionIndexerT>(
5041+
arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
5042+
idx_identity_val, in_out_iter_indexer, reduction_indexer,
5043+
reduction_nelems));
5044+
});
5045+
5046+
return comp_ev;
5047+
}
5048+
48045049
constexpr size_t preferred_reductions_per_wi = 8;
48055050
// max_max_wg prevents running out of resources on CPU
48065051
size_t max_wg =

0 commit comments

Comments
 (0)