Skip to content

Commit b8e9e7a

Browse files
Use specilized kernel for f-arrays and sum by axis=1. Add keepdims support
1 parent af1af29 commit b8e9e7a

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

dpnp/dpnp_iface_mathematical.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,12 +1828,15 @@ def sum(
18281828
elif where is not True:
18291829
pass
18301830
else:
1831-
if axis == (0,) and len(x.shape) == 2 and not keepdims:
1831+
if len(x.shape) == 2 and ((axis==(0,) and x.flags.c_contiguous) or (axis==(1,) and x.flags.f_contiguous)):
18321832
from dpctl.tensor._reduction import _default_reduction_dtype
18331833

18341834
from dpnp.backend.extensions.sycl_ext import _sycl_ext_impl
18351835

1836-
input = dpnp.get_usm_ndarray(x)
1836+
input = x
1837+
if axis == (1,):
1838+
input = input.T
1839+
input = dpnp.get_usm_ndarray(input)
18371840

18381841
queue = input.sycl_queue
18391842
out_dtype = (
@@ -1850,7 +1853,14 @@ def sum(
18501853

18511854
if sum:
18521855
sum(input, output, []).wait()
1853-
return dpnp_array._create_from_usm_ndarray(output)
1856+
result = dpnp_array._create_from_usm_ndarray(output)
1857+
1858+
if keepdims:
1859+
result = result.reshape((1,) + output.shape)
1860+
if (axis == (1,)):
1861+
result = result.T
1862+
1863+
return result
18541864

18551865
y = dpt.sum(
18561866
dpnp.get_usm_ndarray(x), axis=axis, dtype=dtype, keepdims=keepdims

0 commit comments

Comments
 (0)