Skip to content

dpnp.multiply() doesn't work properly with a scalar #1254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions dpnp/backend/kernels/dpnp_krnl_elemwise.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//*****************************************************************************
// Copyright (c) 2016-2020, Intel Corporation
// Copyright (c) 2016-2022, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -183,6 +183,7 @@
where, \
dep_event_vec_ref); \
DPCTLEvent_WaitAndThrow(event_ref); \
DPCTLEvent_Delete(event_ref); \
} \
\
template <typename _DataType_input, typename _DataType_output> \
Expand Down Expand Up @@ -690,6 +691,7 @@ static void func_map_init_elemwise_1arg_2type(func_map_t& fmap)
where, \
dep_event_vec_ref); \
DPCTLEvent_WaitAndThrow(event_ref); \
DPCTLEvent_Delete(event_ref); \
} \
\
template <typename _DataType> \
Expand Down Expand Up @@ -1067,6 +1069,7 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
dep_event_vec_ref \
); \
DPCTLEvent_WaitAndThrow(event_ref); \
DPCTLEvent_Delete(event_ref); \
} \
\
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2> \
Expand Down Expand Up @@ -1732,36 +1735,56 @@ static void func_map_init_elemwise_2arg_3type(func_map_t& fmap)
eft_FLT, (void*)dpnp_multiply_c_ext<float, bool, float>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_BLN][eft_DBL] = {
eft_DBL, (void*)dpnp_multiply_c_ext<double, bool, double>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_BLN][eft_C64] = {
eft_C64, (void*)dpnp_multiply_c_ext<std::complex<float>, bool, std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_BLN][eft_C128] = {
eft_C128, (void*)dpnp_multiply_c_ext<std::complex<double>, bool, std::complex<double>>};

fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_INT][eft_BLN] = {
eft_INT, (void*)dpnp_multiply_c_ext<int32_t, int32_t, bool>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_INT][eft_INT] = {
eft_INT, (void*)dpnp_multiply_c_ext<int32_t, int32_t, int32_t>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_INT][eft_LNG] = {
eft_LNG, (void*)dpnp_multiply_c_ext<int64_t, int32_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_INT][eft_FLT] = {
eft_DBL, (void*)dpnp_multiply_c_ext<double, int32_t, float>};
eft_FLT, (void*)dpnp_multiply_c_ext<float, int32_t, float>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_INT][eft_DBL] = {
eft_DBL, (void*)dpnp_multiply_c_ext<double, int32_t, double>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_INT][eft_C64] = {
eft_C64, (void*)dpnp_multiply_c_ext<std::complex<float>, int32_t, std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_INT][eft_C128] = {
eft_C128, (void*)dpnp_multiply_c_ext<std::complex<double>, int32_t, std::complex<double>>};

fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_LNG][eft_BLN] = {
eft_LNG, (void*)dpnp_multiply_c_ext<int64_t, int64_t, bool>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_LNG][eft_INT] = {
eft_LNG, (void*)dpnp_multiply_c_ext<int64_t, int64_t, int32_t>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_LNG][eft_LNG] = {
eft_LNG, (void*)dpnp_multiply_c_ext<int64_t, int64_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_LNG][eft_FLT] = {
eft_DBL, (void*)dpnp_multiply_c_ext<double, int64_t, float>};
eft_FLT, (void*)dpnp_multiply_c_ext<float, int64_t, float>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_LNG][eft_DBL] = {
eft_DBL, (void*)dpnp_multiply_c_ext<double, int64_t, double>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_LNG][eft_C64] = {
eft_C64, (void*)dpnp_multiply_c_ext<std::complex<float>, int64_t, std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_LNG][eft_C128] = {
eft_C128, (void*)dpnp_multiply_c_ext<std::complex<double>, int64_t, std::complex<double>>};

fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_FLT][eft_BLN] = {
eft_FLT, (void*)dpnp_multiply_c_ext<float, float, bool>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_FLT][eft_INT] = {
eft_DBL, (void*)dpnp_multiply_c_ext<double, float, int32_t>};
eft_FLT, (void*)dpnp_multiply_c_ext<float, float, int32_t>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_FLT][eft_LNG] = {
eft_DBL, (void*)dpnp_multiply_c_ext<double, float, int64_t>};
eft_FLT, (void*)dpnp_multiply_c_ext<float, float, int64_t>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void*)dpnp_multiply_c_ext<float, float, float>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_FLT][eft_DBL] = {
eft_DBL, (void*)dpnp_multiply_c_ext<double, float, double>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_FLT][eft_C64] = {
eft_C64, (void*)dpnp_multiply_c_ext<std::complex<float>, float, std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_FLT][eft_C128] = {
eft_C128, (void*)dpnp_multiply_c_ext<std::complex<double>, float, std::complex<double>>};

fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_DBL][eft_BLN] = {
eft_DBL, (void*)dpnp_multiply_c_ext<double, double, bool>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_DBL][eft_INT] = {
Expand All @@ -1772,6 +1795,10 @@ static void func_map_init_elemwise_2arg_3type(func_map_t& fmap)
eft_DBL, (void*)dpnp_multiply_c_ext<double, double, float>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_DBL][eft_DBL] = {
eft_DBL, (void*)dpnp_multiply_c_ext<double, double, double>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_DBL][eft_C64] = {
eft_C64, (void*)dpnp_multiply_c_ext<std::complex<float>, double, std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_DBL][eft_C128] = {
eft_C128, (void*)dpnp_multiply_c_ext<std::complex<double>, double, std::complex<double>>};

fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_C64][eft_BLN] = {
eft_C64, (void*)dpnp_multiply_c_ext<std::complex<float>, std::complex<float>, bool>};
Expand All @@ -1782,7 +1809,7 @@ static void func_map_init_elemwise_2arg_3type(func_map_t& fmap)
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_C64][eft_FLT] = {
eft_C64, (void*)dpnp_multiply_c_ext<std::complex<float>, std::complex<float>, float>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_C64][eft_DBL] = {
eft_C128, (void*)dpnp_multiply_c_ext<std::complex<double>, std::complex<float>, double>};
eft_C64, (void*)dpnp_multiply_c_ext<std::complex<float>, std::complex<float>, double>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_C64][eft_C64] = {
eft_C64, (void*)dpnp_multiply_c_ext<std::complex<float>, std::complex<float>, std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_C64][eft_C128] = {
Expand Down
1 change: 1 addition & 0 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ def asnumpy(self):

Returns
-------
numpy.ndarray
An instance of :class:`numpy.ndarray` populated with the array content.

"""
Expand Down
19 changes: 15 additions & 4 deletions dpnp/dpnp_iface.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@
"get_normalized_queue_device"
]

from dpnp import (
isscalar
)

from dpnp.dpnp_iface_arraycreation import *
from dpnp.dpnp_iface_bitwise import *
from dpnp.dpnp_iface_counting import *
Expand Down Expand Up @@ -187,7 +191,10 @@ def convert_single_elem_array_to_scalar(obj, keepdims=False):
return obj


def get_dpnp_descriptor(ext_obj, copy_when_strides=True, copy_when_nondefault_queue=True):
def get_dpnp_descriptor(ext_obj,
copy_when_strides=True,
copy_when_nondefault_queue=True,
alloc_queue=None):
"""
Return True:
never
Expand All @@ -206,6 +213,11 @@ def get_dpnp_descriptor(ext_obj, copy_when_strides=True, copy_when_nondefault_qu
if use_origin_backend():
return False

# If the input object is a scalar, it was allocated in host memory.
# We need to copy it to device memory, following the compute-follows-data paradigm.
if isscalar(ext_obj):
ext_obj = array(ext_obj, sycl_queue=alloc_queue)

# while dpnp functions have no implementation with strides support
# we need to create a non-strided copy
# if function get implementation for strides case
Expand All @@ -226,13 +238,12 @@ def get_dpnp_descriptor(ext_obj, copy_when_strides=True, copy_when_nondefault_qu
# we need to create a copy on the device associated with DPNP_QUEUE;
# if a function gets an implementation for a different queue,
# then this behavior can be disabled by setting "copy_when_nondefault_queue"
arr_obj = unwrap_array(ext_obj)
queue = getattr(arr_obj, "sycl_queue", None)
queue = getattr(ext_obj, "sycl_queue", None)
if queue is not None and copy_when_nondefault_queue:
default_queue = dpctl.SyclQueue()
queue_is_default = dpctl.utils.get_execution_queue([queue, default_queue]) is not None
if not queue_is_default:
ext_obj = array(arr_obj, sycl_queue=default_queue)
ext_obj = array(ext_obj, sycl_queue=default_queue)

dpnp_desc = dpnp_descriptor(ext_obj)
if dpnp_desc.is_valid:
Expand Down
28 changes: 15 additions & 13 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# distutils: language = c++
# -*- coding: utf-8 -*-
# *****************************************************************************
# Copyright (c) 2016-2020, Intel Corporation
# Copyright (c) 2016-2022, Intel Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -131,12 +131,13 @@ def atleast_2d(*arys):
all_is_array = True
arys_desc = []
for ary in arys:
ary_desc = dpnp.get_dpnp_descriptor(ary, copy_when_nondefault_queue=False)
if ary_desc:
arys_desc.append(ary_desc)
else:
all_is_array = False
break
if not dpnp.isscalar(ary):
ary_desc = dpnp.get_dpnp_descriptor(ary, copy_when_nondefault_queue=False)
if ary_desc:
arys_desc.append(ary_desc)
continue
all_is_array = False
break

if not use_origin_backend(arys[0]) and all_is_array:
result = []
Expand Down Expand Up @@ -166,12 +167,13 @@ def atleast_3d(*arys):
all_is_array = True
arys_desc = []
for ary in arys:
ary_desc = dpnp.get_dpnp_descriptor(ary, copy_when_nondefault_queue=False)
if ary_desc:
arys_desc.append(ary_desc)
else:
all_is_array = False
break
if not dpnp.isscalar(ary):
ary_desc = dpnp.get_dpnp_descriptor(ary, copy_when_nondefault_queue=False)
if ary_desc:
arys_desc.append(ary_desc)
continue
all_is_array = False
break

if not use_origin_backend(arys[0]) and all_is_array:
result = []
Expand Down
73 changes: 42 additions & 31 deletions dpnp/dpnp_iface_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# distutils: language = c++
# -*- coding: utf-8 -*-
# *****************************************************************************
# Copyright (c) 2016-2020, Intel Corporation
# Copyright (c) 2016-2022, Intel Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -850,9 +850,9 @@ def fmod(x1, x2, dtype=None, out=None, where=True, **kwargs):
pass
elif x1_is_scalar and x2_is_scalar:
pass
elif x1_desc and x1.ndim == 0:
elif x1_desc and x1_desc.ndim == 0:
pass
elif x2_desc and x2.ndim == 0:
elif x2_desc and x2_desc.ndim == 0:
pass
elif dtype is not None:
pass
Expand Down Expand Up @@ -1075,51 +1075,62 @@ def modf(x1, **kwargs):
return call_origin(numpy.modf, x1, **kwargs)


def multiply(x1, x2, dtype=None, out=None, where=True, **kwargs):
def multiply(x1,
x2,
/,
out=None,
*,
where=True,
dtype=None,
subok=True,
**kwargs):
"""
Multiply arguments element-wise.

For full documentation refer to :obj:`numpy.multiply`.

Returns
-------
y : {dpnp.ndarray, scalar}
The product of `x1` and `x2`, element-wise.
The result is a scalar if both x1 and x2 are scalars.

Limitations
-----------
Parameters ``x1`` and ``x2`` are supported as either :obj:`dpnp.ndarray` or scalar.
Parameters ``dtype``, ``out`` and ``where`` are supported with their default values.
Parameters `x1` and `x2` are supported as either :class:`dpnp.ndarray` or scalar.
Parameters `out`, `where`, `dtype` and `subok` are supported with their default values.
Keyword arguments ``kwargs`` are currently unsupported.
Otherwise the function will be executed sequentially on CPU.
Input array data types are limited by supported DPNP :ref:`Data types`.

Examples
--------
>>> import dpnp as np
>>> a = np.array([1, 2, 3, 4, 5])
>>> result = np.multiply(a, a)
>>> [x for x in result]
>>> import dpnp as dp
>>> a = dp.array([1, 2, 3, 4, 5])
>>> result = dp.multiply(a, a)
>>> print(result)
[1, 4, 9, 16, 25]

"""

x1_is_scalar = dpnp.isscalar(x1)
x2_is_scalar = dpnp.isscalar(x2)
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False)
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False)

if x1_desc and x2_desc and not kwargs:
if not x2_desc and not x2_is_scalar:
pass
elif x1_is_scalar and x2_is_scalar:
pass
elif x1_desc and x1_desc.ndim == 0:
pass
elif x2_desc and x2_desc.ndim == 0:
pass
elif dtype is not None:
pass
elif out is not None:
pass
elif not where:
pass
else:
if out is not None:
pass
elif where is not True:
pass
elif dtype is not None:
pass
elif subok is not True:
pass
elif dpnp.isscalar(x1) and dpnp.isscalar(x2):
# keep the result in host memory, if both inputs are scalars
return x1 * x2
else:
# get a common queue to copy data from the host into a device if any input is scalar
queue = get_common_allocation_queue([x1, x2]) if dpnp.isscalar(x1) or dpnp.isscalar(x2) else None

x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False, alloc_queue=queue)
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False, alloc_queue=queue)
if x1_desc and x2_desc:
return dpnp_multiply(x1_desc, x2_desc, dtype=dtype, out=out, where=where).get_pyobj()

return call_origin(numpy.multiply, x1, x2, dtype=dtype, out=out, where=where, **kwargs)
Expand Down
Loading