Skip to content

Commit b60eb41

Browse files
committed
address reviewer's comments
1 parent 15e2611 commit b60eb41

File tree

3 files changed

+95
-90
lines changed

3 files changed

+95
-90
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

+41-41
Original file line numberDiff line numberDiff line change
@@ -69,47 +69,47 @@ enum class DPNPFuncName : size_t
6969
DPNP_FN_ALLCLOSE_EXT, /**< Used in numpy.allclose() impl, requires extra
7070
parameters */
7171
DPNP_FN_ANY, /**< Used in numpy.any() impl */
72-
DPNP_FN_ARANGE, /**< Used in numpy.arange() impl */
73-
DPNP_FN_ARCCOS, /**< Used in numpy.arccos() impl */
74-
DPNP_FN_ARCCOS_EXT, /**< Used in numpy.arccos() impl, requires extra
75-
parameters */
76-
DPNP_FN_ARCCOSH, /**< Used in numpy.arccosh() impl */
77-
DPNP_FN_ARCCOSH_EXT, /**< Used in numpy.arccosh() impl, requires extra
78-
parameters */
79-
DPNP_FN_ARCSIN, /**< Used in numpy.arcsin() impl */
80-
DPNP_FN_ARCSIN_EXT, /**< Used in numpy.arcsin() impl, requires extra
81-
parameters */
82-
DPNP_FN_ARCSINH, /**< Used in numpy.arcsinh() impl */
83-
DPNP_FN_ARCSINH_EXT, /**< Used in numpy.arcsinh() impl, requires extra
84-
parameters */
85-
DPNP_FN_ARCTAN, /**< Used in numpy.arctan() impl */
86-
DPNP_FN_ARCTAN_EXT, /**< Used in numpy.arctan() impl, requires extra
87-
parameters */
88-
DPNP_FN_ARCTAN2, /**< Used in numpy.arctan2() impl */
89-
DPNP_FN_ARCTAN2_EXT, /**< Used in numpy.arctan2() impl, requires extra
90-
parameters */
91-
DPNP_FN_ARCTANH, /**< Used in numpy.arctanh() impl */
92-
DPNP_FN_ARCTANH_EXT, /**< Used in numpy.arctanh() impl, requires extra
93-
parameters */
94-
DPNP_FN_ARGMAX, /**< Used in numpy.argmax() impl */
95-
DPNP_FN_ARGMAX_EXT, /**< Used in numpy.argmax() impl, requires extra
96-
parameters */
97-
DPNP_FN_ARGMIN, /**< Used in numpy.argmin() impl */
98-
DPNP_FN_ARGMIN_EXT, /**< Used in numpy.argmin() impl, requires extra
99-
parameters */
100-
DPNP_FN_ARGSORT, /**< Used in numpy.argsort() impl */
101-
DPNP_FN_ARGSORT_EXT, /**< Used in numpy.argsort() impl, requires extra
102-
parameters */
103-
DPNP_FN_AROUND, /**< Used in numpy.around() impl */
104-
DPNP_FN_AROUND_EXT, /**< Used in numpy.around() impl, requires extra
105-
parameters */
106-
DPNP_FN_ASTYPE, /**< Used in numpy.astype() impl */
107-
DPNP_FN_ASTYPE_EXT, /**< Used in numpy.astype() impl, requires extra
108-
parameters */
109-
DPNP_FN_BITWISE_AND, /**< Used in numpy.bitwise_and() impl */
110-
DPNP_FN_BITWISE_OR, /**< Used in numpy.bitwise_or() impl */
111-
DPNP_FN_BITWISE_XOR, /**< Used in numpy.bitwise_xor() impl */
112-
DPNP_FN_CBRT, /**< Used in numpy.cbrt() impl */
72+
DPNP_FN_ARANGE, /**< Used in numpy.arange() impl */
73+
DPNP_FN_ARCCOS, /**< Used in numpy.arccos() impl */
74+
DPNP_FN_ARCCOS_EXT, /**< Used in numpy.arccos() impl, requires extra
75+
parameters */
76+
DPNP_FN_ARCCOSH, /**< Used in numpy.arccosh() impl */
77+
DPNP_FN_ARCCOSH_EXT, /**< Used in numpy.arccosh() impl, requires extra
78+
parameters */
79+
DPNP_FN_ARCSIN, /**< Used in numpy.arcsin() impl */
80+
DPNP_FN_ARCSIN_EXT, /**< Used in numpy.arcsin() impl, requires extra
81+
parameters */
82+
DPNP_FN_ARCSINH, /**< Used in numpy.arcsinh() impl */
83+
DPNP_FN_ARCSINH_EXT, /**< Used in numpy.arcsinh() impl, requires extra
84+
parameters */
85+
DPNP_FN_ARCTAN, /**< Used in numpy.arctan() impl */
86+
DPNP_FN_ARCTAN_EXT, /**< Used in numpy.arctan() impl, requires extra
87+
parameters */
88+
DPNP_FN_ARCTAN2, /**< Used in numpy.arctan2() impl */
89+
DPNP_FN_ARCTAN2_EXT, /**< Used in numpy.arctan2() impl, requires extra
90+
parameters */
91+
DPNP_FN_ARCTANH, /**< Used in numpy.arctanh() impl */
92+
DPNP_FN_ARCTANH_EXT, /**< Used in numpy.arctanh() impl, requires extra
93+
parameters */
94+
DPNP_FN_ARGMAX, /**< Used in numpy.argmax() impl */
95+
DPNP_FN_ARGMAX_EXT, /**< Used in numpy.argmax() impl, requires extra
96+
parameters */
97+
DPNP_FN_ARGMIN, /**< Used in numpy.argmin() impl */
98+
DPNP_FN_ARGMIN_EXT, /**< Used in numpy.argmin() impl, requires extra
99+
parameters */
100+
DPNP_FN_ARGSORT, /**< Used in numpy.argsort() impl */
101+
DPNP_FN_ARGSORT_EXT, /**< Used in numpy.argsort() impl, requires extra
102+
parameters */
103+
DPNP_FN_AROUND, /**< Used in numpy.around() impl */
104+
DPNP_FN_AROUND_EXT, /**< Used in numpy.around() impl, requires extra
105+
parameters */
106+
DPNP_FN_ASTYPE, /**< Used in numpy.astype() impl */
107+
DPNP_FN_ASTYPE_EXT, /**< Used in numpy.astype() impl, requires extra
108+
parameters */
109+
DPNP_FN_BITWISE_AND, /**< Used in numpy.bitwise_and() impl */
110+
DPNP_FN_BITWISE_OR, /**< Used in numpy.bitwise_or() impl */
111+
DPNP_FN_BITWISE_XOR, /**< Used in numpy.bitwise_xor() impl */
112+
DPNP_FN_CBRT, /**< Used in numpy.cbrt() impl */
113113
DPNP_FN_CBRT_EXT, /**< Used in numpy.cbrt() impl, requires extra parameters
114114
*/
115115
DPNP_FN_CEIL, /**< Used in numpy.ceil() impl */

dpnp/dpnp_algo/dpnp_elementwise_common.py

+49-49
Original file line numberDiff line numberDiff line change
@@ -186,55 +186,6 @@ def dpnp_add(x1, x2, out=None, order="K"):
186186
return dpnp_array._create_from_usm_ndarray(res_usm)
187187

188188

189-
_cos_docstring = """
190-
cos(x, out=None, order='K')
191-
Computes cosine for each element `x_i` for input array `x`.
192-
Args:
193-
x (dpnp.ndarray):
194-
Input array, expected to have numeric data type.
195-
out ({None, dpnp.ndarray}, optional):
196-
Output array to populate. Array must have the correct
197-
shape and the expected data type.
198-
order ("C","F","A","K", optional): memory layout of the new
199-
output array, if parameter `out` is `None`.
200-
Default: "K".
201-
Return:
202-
dpnp.ndarray:
203-
An array containing the element-wise cosine. The data type
204-
of the returned array is determined by the Type Promotion Rules.
205-
"""
206-
207-
208-
def dpnp_cos(x, out=None, order="K"):
209-
"""
210-
Invokes cos() function from pybind11 extension of OneMKL VM if possible.
211-
212-
Otherwise fully relies on dpctl.tensor implementation for cos() function.
213-
214-
"""
215-
216-
def _call_cos(src, dst, sycl_queue, depends=None):
217-
"""A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
218-
219-
if depends is None:
220-
depends = []
221-
222-
if vmi._mkl_cos_to_call(sycl_queue, src, dst):
223-
# call pybind11 extension for cos() function from OneMKL VM
224-
return vmi._cos(sycl_queue, src, dst, depends)
225-
return ti._cos(src, dst, sycl_queue, depends)
226-
227-
# dpctl.tensor only works with usm_ndarray
228-
x1_usm = dpnp.get_usm_ndarray(x)
229-
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
230-
231-
func = UnaryElementwiseFunc(
232-
"cos", ti._cos_result_type, _call_cos, _cos_docstring
233-
)
234-
res_usm = func(x1_usm, out=out_usm, order=order)
235-
return dpnp_array._create_from_usm_ndarray(res_usm)
236-
237-
238189
_bitwise_and_docstring_ = """
239190
bitwise_and(x1, x2, out=None, order='K')
240191
@@ -367,6 +318,55 @@ def dpnp_bitwise_xor(x1, x2, out=None, order="K"):
367318
return dpnp_array._create_from_usm_ndarray(res_usm)
368319

369320

321+
_cos_docstring = """
322+
cos(x, out=None, order='K')
323+
Computes cosine for each element `x_i` for input array `x`.
324+
Args:
325+
x (dpnp.ndarray):
326+
Input array, expected to have numeric data type.
327+
out ({None, dpnp.ndarray}, optional):
328+
Output array to populate. Array must have the correct
329+
shape and the expected data type.
330+
order ("C","F","A","K", optional): memory layout of the new
331+
output array, if parameter `out` is `None`.
332+
Default: "K".
333+
Return:
334+
dpnp.ndarray:
335+
An array containing the element-wise cosine. The data type
336+
of the returned array is determined by the Type Promotion Rules.
337+
"""
338+
339+
340+
def dpnp_cos(x, out=None, order="K"):
341+
"""
342+
Invokes cos() function from pybind11 extension of OneMKL VM if possible.
343+
344+
Otherwise fully relies on dpctl.tensor implementation for cos() function.
345+
346+
"""
347+
348+
def _call_cos(src, dst, sycl_queue, depends=None):
349+
"""A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
350+
351+
if depends is None:
352+
depends = []
353+
354+
if vmi._mkl_cos_to_call(sycl_queue, src, dst):
355+
# call pybind11 extension for cos() function from OneMKL VM
356+
return vmi._cos(sycl_queue, src, dst, depends)
357+
return ti._cos(src, dst, sycl_queue, depends)
358+
359+
# dpctl.tensor only works with usm_ndarray
360+
x1_usm = dpnp.get_usm_ndarray(x)
361+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
362+
363+
func = UnaryElementwiseFunc(
364+
"cos", ti._cos_result_type, _call_cos, _cos_docstring
365+
)
366+
res_usm = func(x1_usm, out=out_usm, order=order)
367+
return dpnp_array._create_from_usm_ndarray(res_usm)
368+
369+
370370
_divide_docstring_ = """
371371
divide(x1, x2, out=None, order="K")
372372

tests/test_bitwise.py

+5
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def test_bitwise_and(self, lhs, rhs, dtype):
6868
assert_array_equal(dp_a & dp_b, np_a & np_b)
6969

7070
"""
71+
TODO: unmute once dpctl support that
7172
if (
7273
not (inp.isscalar(dp_a) or inp.isscalar(dp_b))
7374
and dp_a.shape == dp_b.shape
@@ -84,6 +85,7 @@ def test_bitwise_or(self, lhs, rhs, dtype):
8485
assert_array_equal(dp_a | dp_b, np_a | np_b)
8586

8687
"""
88+
TODO: unmute once dpctl support that
8789
if (
8890
not (inp.isscalar(dp_a) or inp.isscalar(dp_b))
8991
and dp_a.shape == dp_b.shape
@@ -100,6 +102,7 @@ def test_bitwise_xor(self, lhs, rhs, dtype):
100102
assert_array_equal(dp_a ^ dp_b, np_a ^ np_b)
101103

102104
"""
105+
TODO: unmute once dpctl support that
103106
if (
104107
not (inp.isscalar(dp_a) or inp.isscalar(dp_b))
105108
and dp_a.shape == dp_b.shape
@@ -120,6 +123,7 @@ def test_left_shift(self, lhs, rhs, dtype):
120123
assert_array_equal(dp_a << dp_b, np_a << np_b)
121124

122125
"""
126+
TODO: unmute once dpctl support that
123127
if (
124128
not (inp.isscalar(dp_a) or inp.isscalar(dp_b))
125129
and dp_a.shape == dp_b.shape
@@ -136,6 +140,7 @@ def test_right_shift(self, lhs, rhs, dtype):
136140
assert_array_equal(dp_a >> dp_b, np_a >> np_b)
137141

138142
"""
143+
TODO: unmute once dpctl support that
139144
if (
140145
not (inp.isscalar(dp_a) or inp.isscalar(dp_b))
141146
and dp_a.shape == dp_b.shape

0 commit comments

Comments
 (0)