Skip to content

Commit dc877db

Browse files
Use specialized kernel for f-arrays and sum by axis=1. Add keepdims support
1 parent af1af29 commit dc877db

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

dpnp/dpnp_iface_mathematical.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,12 +1828,18 @@ def sum(
18281828
elif where is not True:
18291829
pass
18301830
else:
1831-
if axis == (0,) and len(x.shape) == 2 and not keepdims:
1831+
if len(x.shape) == 2 and (
1832+
(axis == (0,) and x.flags.c_contiguous)
1833+
or (axis == (1,) and x.flags.f_contiguous)
1834+
):
18321835
from dpctl.tensor._reduction import _default_reduction_dtype
18331836

18341837
from dpnp.backend.extensions.sycl_ext import _sycl_ext_impl
18351838

1836-
input = dpnp.get_usm_ndarray(x)
1839+
input = x
1840+
if axis == (1,):
1841+
input = input.T
1842+
input = dpnp.get_usm_ndarray(input)
18371843

18381844
queue = input.sycl_queue
18391845
out_dtype = (
@@ -1850,7 +1856,14 @@ def sum(
18501856

18511857
if sum:
18521858
sum(input, output, []).wait()
1853-
return dpnp_array._create_from_usm_ndarray(output)
1859+
result = dpnp_array._create_from_usm_ndarray(output)
1860+
1861+
if keepdims:
1862+
result = result.reshape((1,) + output.shape)
1863+
if axis == (1,):
1864+
result = result.T
1865+
1866+
return result
18541867

18551868
y = dpt.sum(
18561869
dpnp.get_usm_ndarray(x), axis=axis, dtype=dtype, keepdims=keepdims

0 commit comments

Comments
 (0)