Skip to content

Fix reduction contig impl offset handling #1458

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 27, 2023
53 changes: 33 additions & 20 deletions dpctl/tensor/libtensor/include/kernels/reductions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,18 @@ namespace tensor
namespace kernels
{

// Compile-time trait over a (reduction operator, element type) pair.
// `value` is true only when ReductionOpT is sycl::multiplies<T> and T is
// std::int64_t or std::uint64_t; it is false for every other combination.
// NOTE(review): the name implies these combinations need a work-around
// code path; the reason is not stated in this view -- confirm with the
// PR discussion before relying on it.
template <typename ReductionOpT, typename T> struct needs_workaround
{
static constexpr bool value =
std::is_same_v<ReductionOpT, sycl::multiplies<T>> &&
(std::is_same_v<T, std::int64_t> || std::is_same_v<T, std::uint64_t>);
};

// Compile-time trait over a (reduction operator, element type) pair that
// exposes a boolean `value`.
// NOTE(review): this view is a rendered diff, so the removed and the added
// condition lines appear together. Going by the hunk line counts, the two
// `!std::is_same_v<...>` lines are the removed form and the
// `!needs_workaround<ReductionOpT, T>::value` line is the added one, i.e.
// the resulting condition should read
//   sycl::has_known_identity<ReductionOpT, T>::value &&
//   !needs_workaround<ReductionOpT, T>::value
// -- confirm against the merged file.
template <typename ReductionOpT, typename T> struct can_use_reduce_over_group
{
static constexpr bool value =
sycl::has_known_identity<ReductionOpT, T>::value &&
!std::is_same_v<T, std::int64_t> && !std::is_same_v<T, std::uint64_t> &&
!std::is_same_v<ReductionOpT, sycl::multiplies<T>>;
!needs_workaround<ReductionOpT, T>::value;
};

template <typename argT,
Expand Down Expand Up @@ -1088,7 +1094,7 @@ sycl::event reduction_over_group_temps_strided_impl(
// max_max_wg prevents running out of resources on CPU
constexpr size_t max_max_wg = 2048;
size_t max_wg = std::min(
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>());
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>() / 2);

size_t reductions_per_wi(preferrered_reductions_per_wi);
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
Expand Down Expand Up @@ -1339,7 +1345,7 @@ sycl::event reduction_over_group_temps_strided_impl(
static_cast<py::ssize_t>(remaining_reduction_nelems)};
ResIndexerT res_iter_indexer{iter_nd, iter_res_offset,
/* shape */ iter_shape_and_strides,
/*s trides */ iter_shape_and_strides +
/* strides */ iter_shape_and_strides +
2 * iter_nd};

InputOutputIterIndexerT in_out_iter_indexer{inp_indexer,
Expand Down Expand Up @@ -1424,8 +1430,9 @@ sycl::event reduction_axis1_over_group_temps_contig_impl(
py::ssize_t reduction_arg_offset,
const std::vector<sycl::event> &depends)
{
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp);
resTy *res_tp = reinterpret_cast<resTy *>(res_cp);
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
iter_arg_offset + reduction_arg_offset;
resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;

constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;

Expand All @@ -1437,7 +1444,7 @@ sycl::event reduction_axis1_over_group_temps_contig_impl(
// max_max_wg prevents running out of resources on CPU
constexpr size_t max_max_wg = 2048;
size_t max_wg = std::min(
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>());
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>() / 2);

size_t reductions_per_wi(preferrered_reductions_per_wi);
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
Expand Down Expand Up @@ -1767,8 +1774,9 @@ sycl::event reduction_axis0_over_group_temps_contig_impl(
py::ssize_t reduction_arg_offset,
const std::vector<sycl::event> &depends)
{
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp);
resTy *res_tp = reinterpret_cast<resTy *>(res_cp);
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
iter_arg_offset + reduction_arg_offset;
resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;

constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;

Expand All @@ -1780,7 +1788,7 @@ sycl::event reduction_axis0_over_group_temps_contig_impl(
// max_max_wg prevents running out of resources on CPU
constexpr size_t max_max_wg = 2048;
size_t max_wg = std::min(
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>());
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>() / 2);

size_t reductions_per_wi(preferrered_reductions_per_wi);
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
Expand Down Expand Up @@ -3875,8 +3883,9 @@ sycl::event search_over_group_temps_strided_impl(

constexpr size_t preferrered_reductions_per_wi = 4;
// max_max_wg prevents running out of resources on CPU
size_t max_wg = std::min(
size_t(2048), d.get_info<sycl::info::device::max_work_group_size>());
size_t max_wg =
std::min(size_t(2048),
d.get_info<sycl::info::device::max_work_group_size>() / 2);

size_t reductions_per_wi(preferrered_reductions_per_wi);
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
Expand Down Expand Up @@ -4258,8 +4267,9 @@ sycl::event search_axis1_over_group_temps_contig_impl(
py::ssize_t reduction_arg_offset,
const std::vector<sycl::event> &depends)
{
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp);
resTy *res_tp = reinterpret_cast<resTy *>(res_cp);
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
iter_arg_offset + reduction_arg_offset;
resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;

constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
Expand All @@ -4270,8 +4280,9 @@ sycl::event search_axis1_over_group_temps_contig_impl(

constexpr size_t preferrered_reductions_per_wi = 8;
// max_max_wg prevents running out of resources on CPU
size_t max_wg = std::min(
size_t(2048), d.get_info<sycl::info::device::max_work_group_size>());
size_t max_wg =
std::min(size_t(2048),
d.get_info<sycl::info::device::max_work_group_size>() / 2);

size_t reductions_per_wi(preferrered_reductions_per_wi);
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
Expand Down Expand Up @@ -4635,8 +4646,9 @@ sycl::event search_axis0_over_group_temps_contig_impl(
py::ssize_t reduction_arg_offset,
const std::vector<sycl::event> &depends)
{
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp);
resTy *res_tp = reinterpret_cast<resTy *>(res_cp);
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
iter_arg_offset + reduction_arg_offset;
resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;

constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
Expand All @@ -4647,8 +4659,9 @@ sycl::event search_axis0_over_group_temps_contig_impl(

constexpr size_t preferrered_reductions_per_wi = 8;
// max_max_wg prevents running out of resources on CPU
size_t max_wg = std::min(
size_t(2048), d.get_info<sycl::info::device::max_work_group_size>());
size_t max_wg =
std::min(size_t(2048),
d.get_info<sycl::info::device::max_work_group_size>() / 2);

size_t reductions_per_wi(preferrered_reductions_per_wi);
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
Expand Down
9 changes: 9 additions & 0 deletions dpctl/tests/test_tensor_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(arg_dtype, q)

# test reduction for C-contiguous input
m = dpt.ones(100, dtype=arg_dtype)
r = dpt.sum(m)

Expand All @@ -55,12 +56,20 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype):
assert r.dtype.kind == "f"
elif m.dtype.kind == "c":
assert r.dtype.kind == "c"

assert dpt.all(r == 100)

# test reduction for strided input
m = dpt.ones(200, dtype=arg_dtype)[:1:-2]
r = dpt.sum(m)
assert dpt.all(r == 99)

# test reduction for strided input which can be simplified
# to contiguous computation
m = dpt.ones(100, dtype=arg_dtype)
r = dpt.sum(dpt.flip(m))
assert dpt.all(r == 100)


@pytest.mark.parametrize("arg_dtype", _all_dtypes)
@pytest.mark.parametrize("out_dtype", _all_dtypes[1:])
Expand Down
11 changes: 11 additions & 0 deletions dpctl/tests/test_usm_ndarray_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,17 @@ def test_search_reduction_kernels(arg_dtype):
m = dpt.argmax(x)
assert m == idx

# test case of strided input mapping to contig
# implementation
m = dpt.argmax(dpt.flip(x))
assert m == x.size - 1 - idx

# test case of strided implementation
y = dpt.ones(2 * x.size, dtype=arg_dtype, sycl_queue=q)
y[::2] = x
m = dpt.argmax(y)
assert m == 2 * idx

x = dpt.reshape(x, (24, 1025))

x[idx_tup[0], :] = 3
Expand Down