Skip to content

Commit 02e7714

Browse files
authored
Fix axis0 calls in reduction Python binding (#1459)
* max and min now use MinMaxAtomicSupportFactory These functions were using ArithmeticAtomicSupportFactory, which disables atomics for floating point types * Resolves #1455 This issue was caused by a typo where when the `axis0` kernels for tree and atomic reductions would be called, the `axis1` kernel would be called instead * Adds tests for #1455 resolution
1 parent d82f3a9 commit 02e7714

File tree

4 files changed

+74
-6
lines changed

4 files changed

+74
-6
lines changed

dpctl/tensor/libtensor/source/reductions/reduction_atomic_support.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,12 @@ template <typename fnT, typename T> struct MinMaxAtomicSupportFactory
117117
};
118118

119119
template <typename fnT, typename T>
120-
struct MaxAtomicSupportFactory : public ArithmeticAtomicSupportFactory<fnT, T>
120+
struct MaxAtomicSupportFactory : public MinMaxAtomicSupportFactory<fnT, T>
121121
{
122122
};
123123

124124
template <typename fnT, typename T>
125-
struct MinAtomicSupportFactory : public ArithmeticAtomicSupportFactory<fnT, T>
125+
struct MinAtomicSupportFactory : public MinMaxAtomicSupportFactory<fnT, T>
126126
{
127127
};
128128

dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp

+3-4
Original file line numberDiff line numberDiff line change
@@ -417,10 +417,10 @@ std::pair<sycl::event, sycl::event> py_reduction_over_axis(
417417
typename std::remove_all_extents<contig_fnT>::type;
418418
contig_fn_ptr_T fn;
419419
if (supports_atomics) {
420-
fn = axis1_atomic_dispatch_table[src_typeid][dst_typeid];
420+
fn = axis0_atomic_dispatch_table[src_typeid][dst_typeid];
421421
}
422422
else {
423-
fn = axis1_temps_dispatch_table[src_typeid][dst_typeid];
423+
fn = axis0_temps_dispatch_table[src_typeid][dst_typeid];
424424
}
425425
if (fn != nullptr) {
426426
sycl::event reduction_over_axis0_contig_ev =
@@ -727,7 +727,7 @@ std::pair<sycl::event, sycl::event> py_tree_reduction_over_axis(
727727
}
728728
}
729729
else if (mat_reduce_over_axis0) {
730-
auto fn = axis1_temps_dispatch_table[src_typeid][dst_typeid];
730+
auto fn = axis0_temps_dispatch_table[src_typeid][dst_typeid];
731731
if (fn != nullptr) {
732732
sycl::event reduction_over_axis0_contig_ev =
733733
fn(exec_q, iter_nelems, reduction_nelems, src.get_data(),
@@ -929,7 +929,6 @@ std::pair<sycl::event, sycl::event> py_search_over_axis(
929929
}
930930

931931
using dpctl::tensor::py_internal::simplify_iteration_space;
932-
using dpctl::tensor::py_internal::simplify_iteration_space_1;
933932

934933
auto const &src_shape_vecs = src.get_shape_vector();
935934
auto const &src_strides_vecs = src.get_strides_vector();

dpctl/tests/test_tensor_sum.py

+30
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,36 @@ def test_axis0_bug():
212212
assert dpt.all(s == expected)
213213

214214

215+
def test_sum_axis1_axis0():
216+
"""See gh-1455"""
217+
get_queue_or_skip()
218+
219+
# The atomic case is checked in `test_usm_ndarray_reductions`
220+
# This test checks the tree reduction path for correctness
221+
x = dpt.reshape(dpt.arange(3 * 4 * 5, dtype="f4"), (3, 4, 5))
222+
223+
m = dpt.sum(x, axis=0)
224+
expected = dpt.asarray(
225+
[
226+
[60, 63, 66, 69, 72],
227+
[75, 78, 81, 84, 87],
228+
[90, 93, 96, 99, 102],
229+
[105, 108, 111, 114, 117],
230+
],
231+
dtype="f4",
232+
)
233+
tol = dpt.finfo(m.dtype).resolution
234+
assert dpt.allclose(m, expected, atol=tol, rtol=tol)
235+
236+
x = dpt.flip(x, axis=2)
237+
m = dpt.sum(x, axis=2)
238+
expected = dpt.asarray(
239+
[[10, 35, 60, 85], [110, 135, 160, 185], [210, 235, 260, 285]],
240+
dtype="f4",
241+
)
242+
assert dpt.allclose(m, expected, atol=tol, rtol=tol)
243+
244+
215245
def _any_complex(dtypes):
216246
return any(dpt.isdtype(dpt.dtype(dt), "complex floating") for dt in dtypes)
217247

dpctl/tests/test_usm_ndarray_reductions.py

+39
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,20 @@ def test_max_min_axis():
6161
assert dpt.all(m == x[:, 0, 0, :, 0])
6262

6363

64+
def test_max_axis1_axis0():
65+
"""See gh-1455"""
66+
get_queue_or_skip()
67+
68+
x = dpt.reshape(dpt.arange(3 * 4 * 5), (3, 4, 5))
69+
70+
m = dpt.max(x, axis=0)
71+
assert dpt.all(m == x[-1, :, :])
72+
73+
x = dpt.flip(x, axis=2)
74+
m = dpt.max(x, axis=2)
75+
assert dpt.all(m == x[:, :, 0])
76+
77+
6478
def test_reduction_keepdims():
6579
get_queue_or_skip()
6680

@@ -440,3 +454,28 @@ def test_hypot_complex():
440454
x = dpt.zeros(1, dtype="c8")
441455
with pytest.raises(TypeError):
442456
dpt.reduce_hypot(x)
457+
458+
459+
def test_tree_reduction_axis1_axis0():
460+
"""See gh-1455"""
461+
get_queue_or_skip()
462+
463+
x = dpt.reshape(dpt.arange(3 * 4 * 5, dtype="f4"), (3, 4, 5))
464+
465+
m = dpt.logsumexp(x, axis=0)
466+
tol = dpt.finfo(m.dtype).resolution
467+
assert_allclose(
468+
dpt.asnumpy(m),
469+
np.logaddexp.reduce(dpt.asnumpy(x), axis=0, dtype=m.dtype),
470+
rtol=tol,
471+
atol=tol,
472+
)
473+
474+
x = dpt.flip(x, axis=2)
475+
m = dpt.logsumexp(x, axis=2)
476+
assert_allclose(
477+
dpt.asnumpy(m),
478+
np.logaddexp.reduce(dpt.asnumpy(x), axis=2, dtype=m.dtype),
479+
rtol=tol,
480+
atol=tol,
481+
)

0 commit comments

Comments
 (0)