Skip to content

Support more use cases for 'out' parameter #1341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dpnp/backend/kernels/dpnp_krnl_bitwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,8 @@ static void func_map_init_bitwise_1arg_1type(func_map_t& fmap)
const size_t output_id = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */ \
{ \
const shape_elem_type* result_strides_data = &dev_strides_data[0]; \
const shape_elem_type* input1_strides_data = &dev_strides_data[1]; \
const shape_elem_type* input2_strides_data = &dev_strides_data[2]; \
const shape_elem_type* input1_strides_data = &dev_strides_data[result_ndim]; \
const shape_elem_type* input2_strides_data = &dev_strides_data[2 * result_ndim]; \
\
size_t input1_id = 0; \
size_t input2_id = 0; \
Expand Down
8 changes: 4 additions & 4 deletions dpnp/backend/kernels/dpnp_krnl_elemwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@
size_t output_id = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */ \
{ \
const shape_elem_type* result_strides_data = &dev_strides_data[0]; \
const shape_elem_type* input1_strides_data = &dev_strides_data[1]; \
const shape_elem_type* input1_strides_data = &dev_strides_data[result_ndim]; \
\
size_t input_id = 0; \
for (size_t i = 0; i < input1_ndim; ++i) \
Expand Down Expand Up @@ -635,7 +635,7 @@ static void func_map_init_elemwise_1arg_2type(func_map_t& fmap)
size_t output_id = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */ \
{ \
const shape_elem_type* result_strides_data = &dev_strides_data[0]; \
const shape_elem_type* input1_strides_data = &dev_strides_data[1]; \
const shape_elem_type* input1_strides_data = &dev_strides_data[result_ndim]; \
\
size_t input_id = 0; \
for (size_t i = 0; i < input1_ndim; ++i) \
Expand Down Expand Up @@ -995,8 +995,8 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
const size_t output_id = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */ \
{ \
const shape_elem_type* result_strides_data = &dev_strides_data[0]; \
const shape_elem_type* input1_strides_data = &dev_strides_data[1]; \
const shape_elem_type* input2_strides_data = &dev_strides_data[2]; \
const shape_elem_type* input1_strides_data = &dev_strides_data[result_ndim]; \
const shape_elem_type* input2_strides_data = &dev_strides_data[2 * result_ndim]; \
\
size_t input1_id = 0; \
size_t input2_id = 0; \
Expand Down
6 changes: 3 additions & 3 deletions dpnp/backend/kernels/dpnp_krnl_logic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ DPCTLSyclEventRef (*dpnp_any_ext_c)(DPCTLSyclQueueRef,
const size_t output_id = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */ \
{ \
const shape_elem_type *result_strides_data = &dev_strides_data[0]; \
const shape_elem_type *input1_strides_data = &dev_strides_data[1]; \
const shape_elem_type *input1_strides_data = &dev_strides_data[result_ndim]; \
\
size_t input1_id = 0; \
\
Expand Down Expand Up @@ -635,8 +635,8 @@ static void func_map_logic_1arg_1type_helper(func_map_t& fmap)
const size_t output_id = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */ \
{ \
const shape_elem_type *result_strides_data = &dev_strides_data[0]; \
const shape_elem_type *input1_strides_data = &dev_strides_data[1]; \
const shape_elem_type *input2_strides_data = &dev_strides_data[2]; \
const shape_elem_type *input1_strides_data = &dev_strides_data[result_ndim]; \
const shape_elem_type *input2_strides_data = &dev_strides_data[2 * result_ndim]; \
\
size_t input1_id = 0; \
size_t input2_id = 0; \
Expand Down
6 changes: 3 additions & 3 deletions dpnp/backend/kernels/dpnp_krnl_searching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,9 @@ DPCTLSyclEventRef dpnp_where_c(DPCTLSyclQueueRef q_ref,
const size_t output_id = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */
{
const shape_elem_type* result_strides_data = &dev_strides_data[0];
const shape_elem_type* condition_strides_data = &dev_strides_data[1];
const shape_elem_type* input1_strides_data = &dev_strides_data[2];
const shape_elem_type* input2_strides_data = &dev_strides_data[3];
const shape_elem_type* condition_strides_data = &dev_strides_data[result_ndim];
const shape_elem_type* input1_strides_data = &dev_strides_data[2 * result_ndim];
const shape_elem_type* input2_strides_data = &dev_strides_data[3 * result_ndim];

size_t condition_id = 0;
size_t input1_id = 0;
Expand Down
33 changes: 22 additions & 11 deletions dpnp/dpnp_algo/dpnp_algo.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -505,25 +505,33 @@ cdef utils.dpnp_descriptor call_fptr_2in_1out_strides(DPNPFuncName fptr_name,
return_type = kernel_data.return_type_no_fp64
func = < fptr_2in_1out_strides_t > kernel_data.ptr_no_fp64

if out is None:
""" Create result array with type given by FPTR data """
# check 'out' parameter data
if out is not None:
if out.shape != result_shape:
utils.checker_throw_value_error(func_name, 'out.shape', out.shape, result_shape)

utils.get_common_usm_allocation(x1_obj, out) # check USM allocation is common

if out is None or out.is_array_overlapped(x1_obj) or out.is_array_overlapped(x2_obj) or not out.match_ctype(return_type):
"""
Create result array with type given by FPTR data.
If 'out' array has another dtype than expected or overlaps a memory from any input array,
we have to create a temporary array and to copy data from the temporary into 'out' array,
once the computation is completed.
Otherwise simultaneously access to the same memory may cause a race condition issue
which will result into undefined behaviour.
"""
is_result_memory_allocated = True
result = utils.create_output_descriptor(result_shape,
return_type,
None,
device=result_sycl_device,
usm_type=result_usm_type,
sycl_queue=result_sycl_queue)
else:
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > return_type)
if out.dtype != result_type:
utils.checker_throw_value_error(func_name, 'out.dtype', out.dtype, result_type)
if out.shape != result_shape:
utils.checker_throw_value_error(func_name, 'out.shape', out.shape, result_shape)

is_result_memory_allocated = False
result = out

utils.get_common_usm_allocation(x1_obj, result) # check USM allocation is common

cdef shape_type_c result_strides = utils.strides_to_vector(result.strides, result_shape)

result_obj = result.get_array()
Expand Down Expand Up @@ -554,4 +562,7 @@ cdef utils.dpnp_descriptor call_fptr_2in_1out_strides(DPNPFuncName fptr_name,
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
c_dpctl.DPCTLEvent_Delete(event_ref)

return result
if out is not None and is_result_memory_allocated:
return out.get_result_desc(result)

return result.get_result_desc()
7 changes: 6 additions & 1 deletion dpnp/dpnp_iface.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,10 @@ def get_dpnp_descriptor(ext_obj,
if use_origin_backend():
return False

# It's required to keep track of input object if a non-strided copy is going to be created.
# Thus there will be an extra descriptor allocated to refer on original input.
orig_desc = None

# If input object is a scalar, it means it was allocated on host memory.
# We need to copy it to USM memory according to compute follows data paradigm.
if isscalar(ext_obj):
Expand All @@ -291,6 +295,7 @@ def get_dpnp_descriptor(ext_obj,
ext_obj_offset = 0

if ext_obj.strides != shape_offsets or ext_obj_offset != 0:
orig_desc = dpnp_descriptor(ext_obj)
ext_obj = array(ext_obj)

# while dpnp functions are based on DPNP_QUEUE
Expand All @@ -304,7 +309,7 @@ def get_dpnp_descriptor(ext_obj,
if not queue_is_default:
ext_obj = array(ext_obj, sycl_queue=default_queue)

dpnp_desc = dpnp_descriptor(ext_obj)
dpnp_desc = dpnp_descriptor(ext_obj, orig_desc)
if dpnp_desc.is_valid:
return dpnp_desc

Expand Down
8 changes: 5 additions & 3 deletions dpnp/dpnp_iface_bitwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@
def _check_nd_call(origin_func, dpnp_func, x1, x2, dtype=None, out=None, where=True, **kwargs):
"""Choose function to call based on input and call chosen fucntion."""

if where is not True:
if kwargs:
pass
elif where is not True:
pass
elif dtype is not None:
pass
Expand All @@ -85,7 +87,7 @@ def _check_nd_call(origin_func, dpnp_func, x1, x2, dtype=None, out=None, where=T
if out is not None:
if not isinstance(out, (dpnp.ndarray, dpt.usm_ndarray)):
raise TypeError("return array must be of supported array type")
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False)
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False) or None
else:
out_desc = None

Expand Down Expand Up @@ -273,7 +275,7 @@ def invert(x,
if out is not None:
if not isinstance(out, (dpnp.ndarray, dpt.usm_ndarray)):
raise TypeError("return array must be of supported array type")
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False)
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False) or None
else:
out_desc = None
return dpnp_invert(x1_desc, out_desc).get_pyobj()
Expand Down
2 changes: 1 addition & 1 deletion dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def dot(x1, x2, out=None, **kwargs):
if out is not None:
if not isinstance(out, (dpnp.ndarray, dpt.usm_ndarray)):
raise TypeError("return array must be of supported array type")
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False)
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False) or None
else:
out_desc = None
return dpnp_dot(x1_desc, x2_desc, out=out_desc).get_pyobj()
Expand Down
64 changes: 36 additions & 28 deletions dpnp/dpnp_iface_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,41 @@
]


def _check_nd_call(origin_func, dpnp_func, x1, x2, out=None, where=True, dtype=None, subok=True, **kwargs):
"""Choose function to call based on input and call chosen fucntion."""

if kwargs:
pass
elif where is not True:
pass
elif dtype is not None:
pass
elif subok is not True:
pass
elif dpnp.isscalar(x1) and dpnp.isscalar(x2):
# at least either x1 or x2 has to be an array
pass
else:
# get USM type and queue to copy scalar from the host memory into a USM allocation
usm_type, queue = get_usm_allocations([x1, x2]) if dpnp.isscalar(x1) or dpnp.isscalar(x2) else (None, None)

x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False,
alloc_usm_type=usm_type, alloc_queue=queue)
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False,
alloc_usm_type=usm_type, alloc_queue=queue)
if x1_desc and x2_desc:
if out is not None:
if not isinstance(out, (dpnp.ndarray, dpt.usm_ndarray)):
raise TypeError("return array must be of supported array type")
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False) or None
else:
out_desc = None

return dpnp_func(x1_desc, x2_desc, dtype=dtype, out=out_desc, where=where).get_pyobj()

return call_origin(origin_func, x1, x2, dtype=dtype, out=out, where=where, **kwargs)


def abs(*args, **kwargs):
"""
Calculate the absolute value element-wise.
Expand Down Expand Up @@ -1397,34 +1432,7 @@ def power(x1,

"""

if where is not True:
pass
elif dtype is not None:
pass
elif subok is not True:
pass
elif dpnp.isscalar(x1) and dpnp.isscalar(x2):
# at least either x1 or x2 has to be an array
pass
else:
# get USM type and queue to copy scalar from the host memory into a USM allocation
usm_type, queue = get_usm_allocations([x1, x2]) if dpnp.isscalar(x1) or dpnp.isscalar(x2) else (None, None)

x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False,
alloc_usm_type=usm_type, alloc_queue=queue)
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False,
alloc_usm_type=usm_type, alloc_queue=queue)
if x1_desc and x2_desc:
if out is not None:
if not isinstance(out, (dpnp.ndarray, dpt.usm_ndarray)):
raise TypeError("return array must be of supported array type")
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False)
else:
out_desc = None

return dpnp_power(x1_desc, x2_desc, dtype=dtype, out=out_desc, where=where).get_pyobj()

return call_origin(numpy.power, x1, x2, out=out, where=where, dtype=dtype, subok=subok, **kwargs)
return _check_nd_call(numpy.power, dpnp_power, x1, x2, out=out, where=where, dtype=dtype, subok=subok, **kwargs)


def prod(x1, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=True):
Expand Down
2 changes: 2 additions & 0 deletions dpnp/dpnp_utils/dpnp_algo_utils.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,13 @@ cdef class dpnp_descriptor:

cdef public: # TODO remove "public" as python accessible attribute
object origin_pyobj
dpnp_descriptor origin_desc
dict descriptor
Py_ssize_t dpnp_descriptor_data_size
cpp_bool dpnp_descriptor_is_scalar

cdef void * get_data(self)
cdef cpp_bool match_ctype(self, DPNPFuncType ctype)


cdef shape_type_c get_common_shape(shape_type_c input1_shape, shape_type_c input2_shape) except *
Expand Down
Loading