@@ -3401,6 +3401,129 @@ struct LogSumExpOverAxis0TempsContigFactory
3401
3401
3402
3402
// Argmax and Argmin
3403
3403
3404
+ /* Sequential search reduction */
3405
+
3406
+ template <typename argT,
3407
+ typename outT,
3408
+ typename ReductionOp,
3409
+ typename IdxReductionOp,
3410
+ typename InputOutputIterIndexerT,
3411
+ typename InputRedIndexerT>
3412
+ struct SequentialSearchReduction
3413
+ {
3414
+ private:
3415
+ const argT *inp_ = nullptr ;
3416
+ outT *out_ = nullptr ;
3417
+ ReductionOp reduction_op_;
3418
+ argT identity_;
3419
+ IdxReductionOp idx_reduction_op_;
3420
+ outT idx_identity_;
3421
+ InputOutputIterIndexerT inp_out_iter_indexer_;
3422
+ InputRedIndexerT inp_reduced_dims_indexer_;
3423
+ size_t reduction_max_gid_ = 0 ;
3424
+
3425
+ public:
3426
+ SequentialSearchReduction (const argT *inp,
3427
+ outT *res,
3428
+ ReductionOp reduction_op,
3429
+ const argT &identity_val,
3430
+ IdxReductionOp idx_reduction_op,
3431
+ const outT &idx_identity_val,
3432
+ InputOutputIterIndexerT arg_res_iter_indexer,
3433
+ InputRedIndexerT arg_reduced_dims_indexer,
3434
+ size_t reduction_size)
3435
+ : inp_(inp), out_(res), reduction_op_(reduction_op),
3436
+ identity_ (identity_val), idx_reduction_op_(idx_reduction_op),
3437
+ idx_identity_(idx_identity_val),
3438
+ inp_out_iter_indexer_(arg_res_iter_indexer),
3439
+ inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
3440
+ reduction_max_gid_(reduction_size)
3441
+ {
3442
+ }
3443
+
3444
+ void operator ()(sycl::id<1 > id) const
3445
+ {
3446
+
3447
+ auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_ (id[0 ]);
3448
+ const py::ssize_t &inp_iter_offset =
3449
+ inp_out_iter_offsets_.get_first_offset ();
3450
+ const py::ssize_t &out_iter_offset =
3451
+ inp_out_iter_offsets_.get_second_offset ();
3452
+
3453
+ argT red_val (identity_);
3454
+ outT idx_val (idx_identity_);
3455
+ for (size_t m = 0 ; m < reduction_max_gid_; ++m) {
3456
+ const py::ssize_t inp_reduction_offset =
3457
+ inp_reduced_dims_indexer_ (m);
3458
+ const py::ssize_t inp_offset =
3459
+ inp_iter_offset + inp_reduction_offset;
3460
+
3461
+ argT val = inp_[inp_offset];
3462
+ if (val == red_val) {
3463
+ idx_val = idx_reduction_op_ (idx_val, static_cast <outT>(m));
3464
+ }
3465
+ else {
3466
+ if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
3467
+ using dpctl::tensor::type_utils::is_complex;
3468
+ if constexpr (is_complex<argT>::value) {
3469
+ using dpctl::tensor::math_utils::less_complex;
3470
+ // less_complex always returns false for NaNs, so check
3471
+ if (less_complex<argT>(val, red_val) ||
3472
+ std::isnan (std::real (val)) ||
3473
+ std::isnan (std::imag (val)))
3474
+ {
3475
+ red_val = val;
3476
+ idx_val = static_cast <outT>(m);
3477
+ }
3478
+ }
3479
+ else if constexpr (std::is_floating_point_v<argT> ||
3480
+ std::is_same_v<argT, sycl::half>)
3481
+ {
3482
+ if (val < red_val || std::isnan (val)) {
3483
+ red_val = val;
3484
+ idx_val = static_cast <outT>(m);
3485
+ }
3486
+ }
3487
+ else {
3488
+ if (val < red_val) {
3489
+ red_val = val;
3490
+ idx_val = static_cast <outT>(m);
3491
+ }
3492
+ }
3493
+ }
3494
+ else if constexpr (su_ns::IsMaximum<argT, ReductionOp>::value) {
3495
+ using dpctl::tensor::type_utils::is_complex;
3496
+ if constexpr (is_complex<argT>::value) {
3497
+ using dpctl::tensor::math_utils::greater_complex;
3498
+ if (greater_complex<argT>(val, red_val) ||
3499
+ std::isnan (std::real (val)) ||
3500
+ std::isnan (std::imag (val)))
3501
+ {
3502
+ red_val = val;
3503
+ idx_val = static_cast <outT>(m);
3504
+ }
3505
+ }
3506
+ else if constexpr (std::is_floating_point_v<argT> ||
3507
+ std::is_same_v<argT, sycl::half>)
3508
+ {
3509
+ if (val > red_val || std::isnan (val)) {
3510
+ red_val = val;
3511
+ idx_val = static_cast <outT>(m);
3512
+ }
3513
+ }
3514
+ else {
3515
+ if (val > red_val) {
3516
+ red_val = val;
3517
+ idx_val = static_cast <outT>(m);
3518
+ }
3519
+ }
3520
+ }
3521
+ }
3522
+ }
3523
+ out_[out_iter_offset] = idx_val;
3524
+ }
3525
+ };
3526
+
3404
3527
/* = Search reduction using reduce_over_group*/
3405
3528
3406
3529
template <typename argT,
@@ -3670,7 +3793,9 @@ struct CustomSearchReduction
3670
3793
}
3671
3794
}
3672
3795
}
3673
- else if constexpr (std::is_floating_point_v<argT>) {
3796
+ else if constexpr (std::is_floating_point_v<argT> ||
3797
+ std::is_same_v<argT, sycl::half>)
3798
+ {
3674
3799
if (val < local_red_val || std::isnan (val)) {
3675
3800
local_red_val = val;
3676
3801
if constexpr (!First) {
@@ -3714,7 +3839,9 @@ struct CustomSearchReduction
3714
3839
}
3715
3840
}
3716
3841
}
3717
- else if constexpr (std::is_floating_point_v<argT>) {
3842
+ else if constexpr (std::is_floating_point_v<argT> ||
3843
+ std::is_same_v<argT, sycl::half>)
3844
+ {
3718
3845
if (val > local_red_val || std::isnan (val)) {
3719
3846
local_red_val = val;
3720
3847
if constexpr (!First) {
@@ -3757,7 +3884,9 @@ struct CustomSearchReduction
3757
3884
? local_idx
3758
3885
: idx_identity_;
3759
3886
}
3760
- else if constexpr (std::is_floating_point_v<argT>) {
3887
+ else if constexpr (std::is_floating_point_v<argT> ||
3888
+ std::is_same_v<argT, sycl::half>)
3889
+ {
3761
3890
// equality does not hold for NaNs, so check here
3762
3891
local_idx =
3763
3892
(red_val_over_wg == local_red_val || std::isnan (local_red_val))
@@ -3799,6 +3928,14 @@ typedef sycl::event (*search_strided_impl_fn_ptr)(
3799
3928
py::ssize_t ,
3800
3929
const std::vector<sycl::event> &);
3801
3930
3931
// Kernel name type for the sequential (single work-item per output)
// strided search reduction.
template <typename T1,
          typename T2,
          typename T3,
          typename T4,
          typename T5,
          typename T6>
class search_seq_strided_krn;

3802
3939
template <typename T1,
3803
3940
typename T2,
3804
3941
typename T3,
@@ -3820,6 +3957,14 @@ template <typename T1,
3820
3957
bool b2>
3821
3958
class custom_search_over_group_temps_strided_krn ;
3822
3959
3960
// Kernel name type for the sequential (single work-item per output)
// contiguous search reduction.
template <typename T1,
          typename T2,
          typename T3,
          typename T4,
          typename T5,
          typename T6>
class search_seq_contig_krn;

3823
3968
template <typename T1,
3824
3969
typename T2,
3825
3970
typename T3,
@@ -4019,6 +4164,36 @@ sycl::event search_over_group_temps_strided_impl(
4019
4164
const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
4020
4165
size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
4021
4166
4167
+ if (reduction_nelems < wg) {
4168
+ sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
4169
+ cgh.depends_on (depends);
4170
+
4171
+ using InputOutputIterIndexerT =
4172
+ dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
4173
+ using ReductionIndexerT =
4174
+ dpctl::tensor::offset_utils::StridedIndexer;
4175
+
4176
+ InputOutputIterIndexerT in_out_iter_indexer{
4177
+ iter_nd, iter_arg_offset, iter_res_offset,
4178
+ iter_shape_and_strides};
4179
+ ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset,
4180
+ reduction_shape_stride};
4181
+
4182
+ cgh.parallel_for <class search_seq_strided_krn <
4183
+ argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
4184
+ ReductionIndexerT>>(
4185
+ sycl::range<1 >(iter_nelems),
4186
+ SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
4187
+ InputOutputIterIndexerT,
4188
+ ReductionIndexerT>(
4189
+ arg_tp, res_tp, ReductionOpT (), identity_val, IndexOpT (),
4190
+ idx_identity_val, in_out_iter_indexer, reduction_indexer,
4191
+ reduction_nelems));
4192
+ });
4193
+
4194
+ return comp_ev;
4195
+ }
4196
+
4022
4197
constexpr size_t preferred_reductions_per_wi = 4 ;
4023
4198
// max_max_wg prevents running out of resources on CPU
4024
4199
size_t max_wg =
@@ -4419,6 +4594,39 @@ sycl::event search_axis1_over_group_temps_contig_impl(
4419
4594
const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
4420
4595
size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
4421
4596
4597
+ if (reduction_nelems < wg) {
4598
+ sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
4599
+ cgh.depends_on (depends);
4600
+
4601
+ using InputIterIndexerT =
4602
+ dpctl::tensor::offset_utils::Strided1DIndexer;
4603
+ using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
4604
+ using InputOutputIterIndexerT =
4605
+ dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
4606
+ InputIterIndexerT, NoOpIndexerT>;
4607
+ using ReductionIndexerT = NoOpIndexerT;
4608
+
4609
+ InputOutputIterIndexerT in_out_iter_indexer{
4610
+ InputIterIndexerT{0 , static_cast <py::ssize_t >(iter_nelems),
4611
+ static_cast <py::ssize_t >(reduction_nelems)},
4612
+ NoOpIndexerT{}};
4613
+ ReductionIndexerT reduction_indexer{};
4614
+
4615
+ cgh.parallel_for <class search_seq_contig_krn <
4616
+ argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
4617
+ ReductionIndexerT>>(
4618
+ sycl::range<1 >(iter_nelems),
4619
+ SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
4620
+ InputOutputIterIndexerT,
4621
+ ReductionIndexerT>(
4622
+ arg_tp, res_tp, ReductionOpT (), identity_val, IndexOpT (),
4623
+ idx_identity_val, in_out_iter_indexer, reduction_indexer,
4624
+ reduction_nelems));
4625
+ });
4626
+
4627
+ return comp_ev;
4628
+ }
4629
+
4422
4630
constexpr size_t preferred_reductions_per_wi = 8 ;
4423
4631
// max_max_wg prevents running out of resources on CPU
4424
4632
size_t max_wg =
@@ -4801,6 +5009,43 @@ sycl::event search_axis0_over_group_temps_contig_impl(
4801
5009
const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
4802
5010
size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
4803
5011
5012
+ if (reduction_nelems < wg) {
5013
+ sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
5014
+ cgh.depends_on (depends);
5015
+
5016
+ using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
5017
+ using InputOutputIterIndexerT =
5018
+ dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
5019
+ NoOpIndexerT, NoOpIndexerT>;
5020
+ using ReductionIndexerT =
5021
+ dpctl::tensor::offset_utils::Strided1DIndexer;
5022
+
5023
+ InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{},
5024
+ NoOpIndexerT{}};
5025
+ ReductionIndexerT reduction_indexer{
5026
+ 0 , static_cast <py::ssize_t >(reduction_nelems),
5027
+ static_cast <py::ssize_t >(iter_nelems)};
5028
+
5029
+ using KernelName =
5030
+ class search_seq_contig_krn <argTy, resTy, ReductionOpT,
5031
+ IndexOpT, InputOutputIterIndexerT,
5032
+ ReductionIndexerT>;
5033
+
5034
+ sycl::range<1 > iter_range{iter_nelems};
5035
+
5036
+ cgh.parallel_for <KernelName>(
5037
+ iter_range,
5038
+ SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
5039
+ InputOutputIterIndexerT,
5040
+ ReductionIndexerT>(
5041
+ arg_tp, res_tp, ReductionOpT (), identity_val, IndexOpT (),
5042
+ idx_identity_val, in_out_iter_indexer, reduction_indexer,
5043
+ reduction_nelems));
5044
+ });
5045
+
5046
+ return comp_ev;
5047
+ }
5048
+
4804
5049
constexpr size_t preferred_reductions_per_wi = 8 ;
4805
5050
// max_max_wg prevents running out of resources on CPU
4806
5051
size_t max_wg =
0 commit comments