@@ -50,12 +50,18 @@ namespace tensor
50
50
namespace kernels
51
51
{
52
52
53
+ template <typename ReductionOpT, typename T> struct needs_workaround
54
+ {
55
+ static constexpr bool value =
56
+ std::is_same_v<ReductionOpT, sycl::multiplies<T>> &&
57
+ (std::is_same_v<T, std::int64_t > || std::is_same_v<T, std::uint64_t >);
58
+ };
59
+
53
60
template <typename ReductionOpT, typename T> struct can_use_reduce_over_group
54
61
{
55
62
static constexpr bool value =
56
63
sycl::has_known_identity<ReductionOpT, T>::value &&
57
- !std::is_same_v<T, std::int64_t > && !std::is_same_v<T, std::uint64_t > &&
58
- !std::is_same_v<ReductionOpT, sycl::multiplies<T>>;
64
+ !needs_workaround<ReductionOpT, T>::value;
59
65
};
60
66
61
67
template <typename argT,
@@ -1088,7 +1094,7 @@ sycl::event reduction_over_group_temps_strided_impl(
1088
1094
// max_max_wg prevents running out of resources on CPU
1089
1095
constexpr size_t max_max_wg = 2048 ;
1090
1096
size_t max_wg = std::min (
1091
- max_max_wg, d.get_info <sycl::info::device::max_work_group_size>());
1097
+ max_max_wg, d.get_info <sycl::info::device::max_work_group_size>() / 2 );
1092
1098
1093
1099
size_t reductions_per_wi (preferrered_reductions_per_wi);
1094
1100
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -1339,7 +1345,7 @@ sycl::event reduction_over_group_temps_strided_impl(
1339
1345
static_cast <py::ssize_t >(remaining_reduction_nelems)};
1340
1346
ResIndexerT res_iter_indexer{iter_nd, iter_res_offset,
1341
1347
/* shape */ iter_shape_and_strides,
1342
- /* s trides */ iter_shape_and_strides +
1348
+ /* strides */ iter_shape_and_strides +
1343
1349
2 * iter_nd};
1344
1350
1345
1351
InputOutputIterIndexerT in_out_iter_indexer{inp_indexer,
@@ -1424,8 +1430,9 @@ sycl::event reduction_axis1_over_group_temps_contig_impl(
1424
1430
py::ssize_t reduction_arg_offset,
1425
1431
const std::vector<sycl::event> &depends)
1426
1432
{
1427
- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp);
1428
- resTy *res_tp = reinterpret_cast <resTy *>(res_cp);
1433
+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
1434
+ iter_arg_offset + reduction_arg_offset;
1435
+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
1429
1436
1430
1437
constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;
1431
1438
@@ -1437,7 +1444,7 @@ sycl::event reduction_axis1_over_group_temps_contig_impl(
1437
1444
// max_max_wg prevents running out of resources on CPU
1438
1445
constexpr size_t max_max_wg = 2048 ;
1439
1446
size_t max_wg = std::min (
1440
- max_max_wg, d.get_info <sycl::info::device::max_work_group_size>());
1447
+ max_max_wg, d.get_info <sycl::info::device::max_work_group_size>() / 2 );
1441
1448
1442
1449
size_t reductions_per_wi (preferrered_reductions_per_wi);
1443
1450
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -1767,8 +1774,9 @@ sycl::event reduction_axis0_over_group_temps_contig_impl(
1767
1774
py::ssize_t reduction_arg_offset,
1768
1775
const std::vector<sycl::event> &depends)
1769
1776
{
1770
- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp);
1771
- resTy *res_tp = reinterpret_cast <resTy *>(res_cp);
1777
+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
1778
+ iter_arg_offset + reduction_arg_offset;
1779
+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
1772
1780
1773
1781
constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;
1774
1782
@@ -1780,7 +1788,7 @@ sycl::event reduction_axis0_over_group_temps_contig_impl(
1780
1788
// max_max_wg prevents running out of resources on CPU
1781
1789
constexpr size_t max_max_wg = 2048 ;
1782
1790
size_t max_wg = std::min (
1783
- max_max_wg, d.get_info <sycl::info::device::max_work_group_size>());
1791
+ max_max_wg, d.get_info <sycl::info::device::max_work_group_size>() / 2 );
1784
1792
1785
1793
size_t reductions_per_wi (preferrered_reductions_per_wi);
1786
1794
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -3875,8 +3883,9 @@ sycl::event search_over_group_temps_strided_impl(
3875
3883
3876
3884
constexpr size_t preferrered_reductions_per_wi = 4 ;
3877
3885
// max_max_wg prevents running out of resources on CPU
3878
- size_t max_wg = std::min (
3879
- size_t (2048 ), d.get_info <sycl::info::device::max_work_group_size>());
3886
+ size_t max_wg =
3887
+ std::min (size_t (2048 ),
3888
+ d.get_info <sycl::info::device::max_work_group_size>() / 2 );
3880
3889
3881
3890
size_t reductions_per_wi (preferrered_reductions_per_wi);
3882
3891
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -4258,8 +4267,9 @@ sycl::event search_axis1_over_group_temps_contig_impl(
4258
4267
py::ssize_t reduction_arg_offset,
4259
4268
const std::vector<sycl::event> &depends)
4260
4269
{
4261
- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp);
4262
- resTy *res_tp = reinterpret_cast <resTy *>(res_cp);
4270
+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
4271
+ iter_arg_offset + reduction_arg_offset;
4272
+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
4263
4273
4264
4274
constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
4265
4275
constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
@@ -4270,8 +4280,9 @@ sycl::event search_axis1_over_group_temps_contig_impl(
4270
4280
4271
4281
constexpr size_t preferrered_reductions_per_wi = 8 ;
4272
4282
// max_max_wg prevents running out of resources on CPU
4273
- size_t max_wg = std::min (
4274
- size_t (2048 ), d.get_info <sycl::info::device::max_work_group_size>());
4283
+ size_t max_wg =
4284
+ std::min (size_t (2048 ),
4285
+ d.get_info <sycl::info::device::max_work_group_size>() / 2 );
4275
4286
4276
4287
size_t reductions_per_wi (preferrered_reductions_per_wi);
4277
4288
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -4635,8 +4646,9 @@ sycl::event search_axis0_over_group_temps_contig_impl(
4635
4646
py::ssize_t reduction_arg_offset,
4636
4647
const std::vector<sycl::event> &depends)
4637
4648
{
4638
- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp);
4639
- resTy *res_tp = reinterpret_cast <resTy *>(res_cp);
4649
+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
4650
+ iter_arg_offset + reduction_arg_offset;
4651
+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
4640
4652
4641
4653
constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
4642
4654
constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
@@ -4647,8 +4659,9 @@ sycl::event search_axis0_over_group_temps_contig_impl(
4647
4659
4648
4660
constexpr size_t preferrered_reductions_per_wi = 8 ;
4649
4661
// max_max_wg prevents running out of resources on CPU
4650
- size_t max_wg = std::min (
4651
- size_t (2048 ), d.get_info <sycl::info::device::max_work_group_size>());
4662
+ size_t max_wg =
4663
+ std::min (size_t (2048 ),
4664
+ d.get_info <sycl::info::device::max_work_group_size>() / 2 );
4652
4665
4653
4666
size_t reductions_per_wi (preferrered_reductions_per_wi);
4654
4667
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
0 commit comments