Skip to content

Commit 70e5aa8

Browse files
Use specilized kernel for f-arrays and sum by axis=1. Add keepdims support (#1489)
Use specilized kernel for f-arrays and sum by axis=1. Add keepdims support
1 parent 6a46e0e commit 70e5aa8

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

dpnp/dpnp_iface_mathematical.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,12 +1828,18 @@ def sum(
18281828
elif where is not True:
18291829
pass
18301830
else:
1831-
if axis == (0,) and len(x.shape) == 2 and not keepdims:
1831+
if len(x.shape) == 2 and (
1832+
(axis == (0,) and x.flags.c_contiguous)
1833+
or (axis == (1,) and x.flags.f_contiguous)
1834+
):
18321835
from dpctl.tensor._reduction import _default_reduction_dtype
18331836

18341837
from dpnp.backend.extensions.sycl_ext import _sycl_ext_impl
18351838

1836-
input = dpnp.get_usm_ndarray(x)
1839+
input = x
1840+
if axis == (1,):
1841+
input = input.T
1842+
input = dpnp.get_usm_ndarray(input)
18371843

18381844
queue = input.sycl_queue
18391845
out_dtype = (
@@ -1850,7 +1856,16 @@ def sum(
18501856

18511857
if sum:
18521858
sum(input, output, []).wait()
1853-
return dpnp_array._create_from_usm_ndarray(output)
1859+
result = dpnp_array._create_from_usm_ndarray(output)
1860+
1861+
if keepdims:
1862+
if axis == (0,):
1863+
res_sh = (1,) + output.shape
1864+
else:
1865+
res_sh = output.shape + (1,)
1866+
result = result.reshape(res_sh)
1867+
1868+
return result
18541869

18551870
y = dpt.sum(
18561871
dpnp.get_usm_ndarray(x), axis=axis, dtype=dtype, keepdims=keepdims

0 commit comments

Comments
 (0)