Skip to content

Commit d82f3a9

Browse files
Merge pull request #1458 from IntelPython/fix-reduction-contig_impl-offset-handling
Fix reduction contig impl offset handling
2 parents: 03fd737 + bfba152; commit d82f3a9

File tree

3 files changed

+53
-20
lines changed

3 files changed

+53
-20
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

+33 -20
Original file line number | Diff line number | Diff line change
@@ -50,12 +50,18 @@ namespace tensor
5050
namespace kernels
5151
{
5252

53+
template <typename ReductionOpT, typename T> struct needs_workaround
54+
{
55+
static constexpr bool value =
56+
std::is_same_v<ReductionOpT, sycl::multiplies<T>> &&
57+
(std::is_same_v<T, std::int64_t> || std::is_same_v<T, std::uint64_t>);
58+
};
59+
5360
template <typename ReductionOpT, typename T> struct can_use_reduce_over_group
5461
{
5562
static constexpr bool value =
5663
sycl::has_known_identity<ReductionOpT, T>::value &&
57-
!std::is_same_v<T, std::int64_t> && !std::is_same_v<T, std::uint64_t> &&
58-
!std::is_same_v<ReductionOpT, sycl::multiplies<T>>;
64+
!needs_workaround<ReductionOpT, T>::value;
5965
};
6066

6167
template <typename argT,
@@ -1088,7 +1094,7 @@ sycl::event reduction_over_group_temps_strided_impl(
10881094
// max_max_wg prevents running out of resources on CPU
10891095
constexpr size_t max_max_wg = 2048;
10901096
size_t max_wg = std::min(
1091-
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>());
1097+
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>() / 2);
10921098

10931099
size_t reductions_per_wi(preferrered_reductions_per_wi);
10941100
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -1339,7 +1345,7 @@ sycl::event reduction_over_group_temps_strided_impl(
13391345
static_cast<py::ssize_t>(remaining_reduction_nelems)};
13401346
ResIndexerT res_iter_indexer{iter_nd, iter_res_offset,
13411347
/* shape */ iter_shape_and_strides,
1342-
/*s trides */ iter_shape_and_strides +
1348+
/* strides */ iter_shape_and_strides +
13431349
2 * iter_nd};
13441350

13451351
InputOutputIterIndexerT in_out_iter_indexer{inp_indexer,
@@ -1424,8 +1430,9 @@ sycl::event reduction_axis1_over_group_temps_contig_impl(
14241430
py::ssize_t reduction_arg_offset,
14251431
const std::vector<sycl::event> &depends)
14261432
{
1427-
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp);
1428-
resTy *res_tp = reinterpret_cast<resTy *>(res_cp);
1433+
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
1434+
iter_arg_offset + reduction_arg_offset;
1435+
resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;
14291436

14301437
constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;
14311438

@@ -1437,7 +1444,7 @@ sycl::event reduction_axis1_over_group_temps_contig_impl(
14371444
// max_max_wg prevents running out of resources on CPU
14381445
constexpr size_t max_max_wg = 2048;
14391446
size_t max_wg = std::min(
1440-
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>());
1447+
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>() / 2);
14411448

14421449
size_t reductions_per_wi(preferrered_reductions_per_wi);
14431450
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -1767,8 +1774,9 @@ sycl::event reduction_axis0_over_group_temps_contig_impl(
17671774
py::ssize_t reduction_arg_offset,
17681775
const std::vector<sycl::event> &depends)
17691776
{
1770-
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp);
1771-
resTy *res_tp = reinterpret_cast<resTy *>(res_cp);
1777+
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
1778+
iter_arg_offset + reduction_arg_offset;
1779+
resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;
17721780

17731781
constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;
17741782

@@ -1780,7 +1788,7 @@ sycl::event reduction_axis0_over_group_temps_contig_impl(
17801788
// max_max_wg prevents running out of resources on CPU
17811789
constexpr size_t max_max_wg = 2048;
17821790
size_t max_wg = std::min(
1783-
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>());
1791+
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>() / 2);
17841792

17851793
size_t reductions_per_wi(preferrered_reductions_per_wi);
17861794
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -3875,8 +3883,9 @@ sycl::event search_over_group_temps_strided_impl(
38753883

38763884
constexpr size_t preferrered_reductions_per_wi = 4;
38773885
// max_max_wg prevents running out of resources on CPU
3878-
size_t max_wg = std::min(
3879-
size_t(2048), d.get_info<sycl::info::device::max_work_group_size>());
3886+
size_t max_wg =
3887+
std::min(size_t(2048),
3888+
d.get_info<sycl::info::device::max_work_group_size>() / 2);
38803889

38813890
size_t reductions_per_wi(preferrered_reductions_per_wi);
38823891
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -4258,8 +4267,9 @@ sycl::event search_axis1_over_group_temps_contig_impl(
42584267
py::ssize_t reduction_arg_offset,
42594268
const std::vector<sycl::event> &depends)
42604269
{
4261-
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp);
4262-
resTy *res_tp = reinterpret_cast<resTy *>(res_cp);
4270+
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
4271+
iter_arg_offset + reduction_arg_offset;
4272+
resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;
42634273

42644274
constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
42654275
constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
@@ -4270,8 +4280,9 @@ sycl::event search_axis1_over_group_temps_contig_impl(
42704280

42714281
constexpr size_t preferrered_reductions_per_wi = 8;
42724282
// max_max_wg prevents running out of resources on CPU
4273-
size_t max_wg = std::min(
4274-
size_t(2048), d.get_info<sycl::info::device::max_work_group_size>());
4283+
size_t max_wg =
4284+
std::min(size_t(2048),
4285+
d.get_info<sycl::info::device::max_work_group_size>() / 2);
42754286

42764287
size_t reductions_per_wi(preferrered_reductions_per_wi);
42774288
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -4635,8 +4646,9 @@ sycl::event search_axis0_over_group_temps_contig_impl(
46354646
py::ssize_t reduction_arg_offset,
46364647
const std::vector<sycl::event> &depends)
46374648
{
4638-
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp);
4639-
resTy *res_tp = reinterpret_cast<resTy *>(res_cp);
4649+
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
4650+
iter_arg_offset + reduction_arg_offset;
4651+
resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;
46404652

46414653
constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
46424654
constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
@@ -4647,8 +4659,9 @@ sycl::event search_axis0_over_group_temps_contig_impl(
46474659

46484660
constexpr size_t preferrered_reductions_per_wi = 8;
46494661
// max_max_wg prevents running out of resources on CPU
4650-
size_t max_wg = std::min(
4651-
size_t(2048), d.get_info<sycl::info::device::max_work_group_size>());
4662+
size_t max_wg =
4663+
std::min(size_t(2048),
4664+
d.get_info<sycl::info::device::max_work_group_size>() / 2);
46524665

46534666
size_t reductions_per_wi(preferrered_reductions_per_wi);
46544667
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {

dpctl/tests/test_tensor_sum.py

+9
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype):
4343
q = get_queue_or_skip()
4444
skip_if_dtype_not_supported(arg_dtype, q)
4545

46+
# test reduction for C-contiguous input
4647
m = dpt.ones(100, dtype=arg_dtype)
4748
r = dpt.sum(m)
4849

@@ -55,12 +56,20 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype):
5556
assert r.dtype.kind == "f"
5657
elif m.dtype.kind == "c":
5758
assert r.dtype.kind == "c"
59+
5860
assert dpt.all(r == 100)
5961

62+
# test reduction for strided input
6063
m = dpt.ones(200, dtype=arg_dtype)[:1:-2]
6164
r = dpt.sum(m)
6265
assert dpt.all(r == 99)
6366

67+
# test reduction for strided input which can be simplified
68+
# to contiguous computation
69+
m = dpt.ones(100, dtype=arg_dtype)
70+
r = dpt.sum(dpt.flip(m))
71+
assert dpt.all(r == 100)
72+
6473

6574
@pytest.mark.parametrize("arg_dtype", _all_dtypes)
6675
@pytest.mark.parametrize("out_dtype", _all_dtypes[1:])

dpctl/tests/test_usm_ndarray_reductions.py

+11
Original file line number | Diff line number | Diff line change
@@ -169,6 +169,17 @@ def test_search_reduction_kernels(arg_dtype):
169169
m = dpt.argmax(x)
170170
assert m == idx
171171

172+
# test case of strided input mapping to contig
173+
# implementation
174+
m = dpt.argmax(dpt.flip(x))
175+
assert m == x.size - 1 - idx
176+
177+
# test case of strided implementation
178+
y = dpt.ones(2 * x.size, dtype=arg_dtype, sycl_queue=q)
179+
y[::2] = x
180+
m = dpt.argmax(y)
181+
assert m == 2 * idx
182+
172183
x = dpt.reshape(x, (24, 1025))
173184

174185
x[idx_tup[0], :] = 3

0 commit comments

Comments
 (0)