Skip to content

Commit 41ec378

Browse files
Added docstrings and getter methods for ElementwiseFunc classes
Added stable API to retrieve implementation functions in each elementwise function class instance to allow `dpnp` to access that information using stable API.
1 parent 421b270 commit 41ec378

File tree

2 files changed

+230
-0
lines changed

2 files changed

+230
-0
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,31 @@
3939
class UnaryElementwiseFunc:
4040
"""
4141
Class that implements unary element-wise functions.
42+
43+
Args:
44+
name (str):
45+
Name of the unary function
46+
result_type_resovler_fn (callable):
47+
Function that takes dtype of the input and
48+
returns the dtype of the result if the
49+
implementation functions supports it, or
50+
returns `None` otherwise.
51+
unary_dp_impl_fn (callable):
52+
Data-parallel implementation function with signature
53+
`impl_fn(src: usm_ndarray, dst: usm_ndarray,
54+
sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
55+
where the `src` is the argument array, `dst` is the
56+
array to be populated with function values, effectively
57+
evaluating `dst = func(src)`.
58+
The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
59+
The first event corresponds to data-management host tasks,
60+
including lifetime management of argument Python objects to ensure
61+
that their associated USM allocation is not freed before offloaded
62+
computational tasks complete execution, while the second event
63+
corresponds to computational tasks associated with function
64+
evaluation.
65+
docs (str):
66+
Documentation string for the unary function.
4267
"""
4368

4469
def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
@@ -55,8 +80,31 @@ def __str__(self):
5580
def __repr__(self):
5681
return f"<{self.__name__} '{self.name_}'>"
5782

83+
def get_implementation_function(self):
84+
"""Returns the implementation function for
85+
this elementwise unary function.
86+
87+
"""
88+
return self.unary_fn_
89+
90+
def get_type_result_resolver_function(self):
91+
"""Returns the type resolver function for this
92+
elementwise unary function.
93+
"""
94+
return self.result_type_resolver_fn_
95+
5896
@property
5997
def types(self):
98+
"""Returns information about types supported by
99+
implementation function, using NumPy's character
100+
encoding for data types, e.g.
101+
102+
:Example:
103+
.. code-block:: python
104+
105+
dpctl.tensor.sin.types
106+
# Outputs: ['e->e', 'f->f', 'd->d', 'F->F', 'D->D']
107+
"""
60108
types = self.types_
61109
if not types:
62110
types = []
@@ -363,6 +411,56 @@ def _get_shape(o):
363411
class BinaryElementwiseFunc:
364412
"""
365413
Class that implements binary element-wise functions.
414+
415+
Args:
416+
name (str):
417+
Name of the unary function
418+
result_type_resovle_fn (callable):
419+
Function that takes dtypes of the input and
420+
returns the dtype of the result if the
421+
implementation functions supports it, or
422+
returns `None` otherwise.
423+
binary_dp_impl_fn (callable):
424+
Data-parallel umplementation function with signature
425+
`impl_fn(src1: usm_ndarray, src2: usm_ndarray, dst: usm_ndarray,
426+
sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
427+
where the `src1` and `src2` are the argument arrays, `dst` is the
428+
array to be populated with function values,
429+
i.e. `dst=func(src1, src2)`.
430+
The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
431+
The first event corresponds to data-management host tasks,
432+
including lifetime management of argument Python objects to ensure
433+
that their associated USM allocation is not freed before offloaded
434+
computational tasks complete execution, while the second event
435+
corresponds to computational tasks associated with function
436+
evaluation.
437+
docs (str):
438+
Documentation string for the unary function.
439+
binary_inplace_fn (callable, optional):
440+
Data-parallel omplementation function with signature
441+
`impl_fn(src: usm_ndarray, dst: usm_ndarray,
442+
sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
443+
where the `src` is the argument array, `dst` is the
444+
array to be populated with function values,
445+
i.e. `dst=func(dst, src)`.
446+
The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
447+
The first event corresponds to data-management host tasks,
448+
including async lifetime management of Python arguments,
449+
while the second event corresponds to computational tasks
450+
associated with function evaluation.
451+
acceptance_fn (callable, optional):
452+
Function to influence type promotion behavior of this binary
453+
function. The function takes 6 arguments:
454+
arg1_dtype - Data type of the first argument
455+
arg2_dtype - Data type of the second argument
456+
ret_buf1_dtype - Data type the first argument would be cast to
457+
ret_buf2_dtype - Data type the second argument would be cast to
458+
res_dtype - Data type of the output array with function values
459+
sycl_dev - The :class:`dpctl.SyclDevice` where the function
460+
evaluation is carried out.
461+
The function is only called when both arguments of the binary
462+
function require casting, e.g. both arguments of
463+
`dpctl.tensor.logaddexp` are arrays with integral data type.
366464
"""
367465

368466
def __init__(
@@ -392,8 +490,60 @@ def __str__(self):
392490
def __repr__(self):
393491
return f"<{self.__name__} '{self.name_}'>"
394492

493+
def get_implementation_function(self):
494+
"""Returns the out-of-place implementation
495+
function for this elementwise binary function.
496+
497+
"""
498+
return self.binary_fn_
499+
500+
def get_implementation_inplace_function(self):
501+
"""Returns the in-place implementation
502+
function for this elementwise binary function.
503+
504+
"""
505+
return self.binary_inplace_fn_
506+
507+
def get_type_result_resolver_function(self):
508+
"""Returns the type resolver function for this
509+
elementwise binary function.
510+
"""
511+
return self.result_type_resolver_fn_
512+
513+
def get_type_promotion_path_acceptance_function(self):
514+
"""Returns the acceptance function for this
515+
elementwise binary function.
516+
517+
Acceptance function influences the type promotion
518+
behavior of this binary function.
519+
The function takes 6 arguments:
520+
arg1_dtype - Data type of the first argument
521+
arg2_dtype - Data type of the second argument
522+
ret_buf1_dtype - Data type the first argument would be cast to
523+
ret_buf2_dtype - Data type the second argument would be cast to
524+
res_dtype - Data type of the output array with function values
525+
sycl_dev - :class:`dpctl.SyclDevice` on which function evaluation
526+
is carried out.
527+
528+
The acceptance function is only invoked if both input arrays must be
529+
cast to intermediary data types, as would happen during call of
530+
`dpctl.tensor.hypot` with both arrays being of integral data type.
531+
"""
532+
return self.acceptance_fn_
533+
395534
@property
396535
def types(self):
536+
"""Returns information about types supported by
537+
implementation function, using NumPy's character
538+
encoding for data types, e.g.
539+
540+
:Example:
541+
.. code-block:: python
542+
543+
dpctl.tensor.divide.types
544+
# Outputs: ['ee->e', 'ff->f', 'fF->F', 'dd->d', 'dD->D',
545+
# 'Ff->F', 'FF->F', 'Dd->D', 'DD->D']
546+
"""
397547
types = self.types_
398548
if not types:
399549
types = []
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import dpctl.tensor as dpt
18+
19+
unary_fn = dpt.negative
20+
binary_fn = dpt.divide
21+
22+
23+
def test_unary_class_getters():
24+
fn = unary_fn.get_implementation_function()
25+
assert callable(fn)
26+
27+
fn = unary_fn.get_type_result_resolver_function()
28+
assert callable(fn)
29+
30+
31+
def test_unary_class_types_property():
32+
loop_types = unary_fn.types
33+
assert isinstance(loop_types, list)
34+
assert len(loop_types) > 0
35+
assert all(isinstance(sig, str) for sig in loop_types)
36+
assert all("->" in sig for sig in loop_types)
37+
38+
39+
def test_unary_class_str_repr():
40+
s = str(unary_fn)
41+
r = repr(unary_fn)
42+
43+
assert isinstance(s, str)
44+
assert isinstance(r, str)
45+
kl_n = unary_fn.__name__
46+
assert kl_n in s
47+
assert kl_n in r
48+
49+
50+
def test_binary_class_getters():
51+
fn = binary_fn.get_implementation_function()
52+
assert callable(fn)
53+
54+
fn = binary_fn.get_implementation_inplace_function()
55+
assert callable(fn)
56+
57+
fn = binary_fn.get_type_result_resolver_function()
58+
assert callable(fn)
59+
60+
fn = binary_fn.get_type_promotion_path_acceptance_function()
61+
assert callable(fn)
62+
63+
64+
def test_binary_class_types_property():
65+
loop_types = binary_fn.types
66+
assert isinstance(loop_types, list)
67+
assert len(loop_types) > 0
68+
assert all(isinstance(sig, str) for sig in loop_types)
69+
assert all("->" in sig for sig in loop_types)
70+
71+
72+
def test_binary_class_str_repr():
73+
s = str(binary_fn)
74+
r = repr(binary_fn)
75+
76+
assert isinstance(s, str)
77+
assert isinstance(r, str)
78+
kl_n = binary_fn.__name__
79+
assert kl_n in s
80+
assert kl_n in r

0 commit comments

Comments
 (0)