diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index 2c45f05e2b9a..36e9804618f8 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -1828,12 +1828,18 @@ def sum( elif where is not True: pass else: - if axis == (0,) and len(x.shape) == 2 and not keepdims: + if len(x.shape) == 2 and ( + (axis == (0,) and x.flags.c_contiguous) + or (axis == (1,) and x.flags.f_contiguous) + ): from dpctl.tensor._reduction import _default_reduction_dtype from dpnp.backend.extensions.sycl_ext import _sycl_ext_impl - input = dpnp.get_usm_ndarray(x) + input = x + if axis == (1,): + input = input.T + input = dpnp.get_usm_ndarray(input) queue = input.sycl_queue out_dtype = ( @@ -1850,7 +1856,16 @@ def sum( if sum: sum(input, output, []).wait() - return dpnp_array._create_from_usm_ndarray(output) + result = dpnp_array._create_from_usm_ndarray(output) + + if keepdims: + if axis == (0,): + res_sh = (1,) + output.shape + else: + res_sh = output.shape + (1,) + result = result.reshape(res_sh) + + return result y = dpt.sum( dpnp.get_usm_ndarray(x), axis=axis, dtype=dtype, keepdims=keepdims