39
39
class UnaryElementwiseFunc :
40
40
"""
41
41
Class that implements unary element-wise functions.
42
+
43
+ Args:
44
+ name (str):
45
+ Name of the unary function
46
+ result_type_resovler_fn (callable):
47
+ Function that takes dtype of the input and
48
+ returns the dtype of the result if the
49
+ implementation functions supports it, or
50
+ returns `None` otherwise.
51
+ unary_dp_impl_fn (callable):
52
+ Data-parallel implementation function with signature
53
+ `impl_fn(src: usm_ndarray, dst: usm_ndarray,
54
+ sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
55
+ where the `src` is the argument array, `dst` is the
56
+ array to be populated with function values, effectively
57
+ evaluating `dst = func(src)`.
58
+ The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
59
+ The first event corresponds to data-management host tasks,
60
+ including lifetime management of argument Python objects to ensure
61
+ that their associated USM allocation is not freed before offloaded
62
+ computational tasks complete execution, while the second event
63
+ corresponds to computational tasks associated with function
64
+ evaluation.
65
+ docs (str):
66
+ Documentation string for the unary function.
42
67
"""
43
68
44
69
def __init__ (self , name , result_type_resolver_fn , unary_dp_impl_fn , docs ):
@@ -55,8 +80,31 @@ def __str__(self):
55
80
def __repr__ (self ):
56
81
return f"<{ self .__name__ } '{ self .name_ } '>"
57
82
83
+ def get_implementation_function (self ):
84
+ """Returns the implementation function for
85
+ this elementwise unary function.
86
+
87
+ """
88
+ return self .unary_fn_
89
+
90
+ def get_type_result_resolver_function (self ):
91
+ """Returns the type resolver function for this
92
+ elementwise unary function.
93
+ """
94
+ return self .result_type_resolver_fn_
95
+
58
96
@property
59
97
def types (self ):
98
+ """Returns information about types supported by
99
+ implementation function, using NumPy's character
100
+ encoding for data types, e.g.
101
+
102
+ :Example:
103
+ .. code-block:: python
104
+
105
+ dpctl.tensor.sin.types
106
+ # Outputs: ['e->e', 'f->f', 'd->d', 'F->F', 'D->D']
107
+ """
60
108
types = self .types_
61
109
if not types :
62
110
types = []
@@ -363,6 +411,56 @@ def _get_shape(o):
363
411
class BinaryElementwiseFunc :
364
412
"""
365
413
Class that implements binary element-wise functions.
414
+
415
+ Args:
416
+ name (str):
417
+ Name of the unary function
418
+ result_type_resovle_fn (callable):
419
+ Function that takes dtypes of the input and
420
+ returns the dtype of the result if the
421
+ implementation functions supports it, or
422
+ returns `None` otherwise.
423
+ binary_dp_impl_fn (callable):
424
+ Data-parallel implementation function with signature
425
+ `impl_fn(src1: usm_ndarray, src2: usm_ndarray, dst: usm_ndarray,
426
+ sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
427
+ where the `src1` and `src2` are the argument arrays, `dst` is the
428
+ array to be populated with function values,
429
+ i.e. `dst=func(src1, src2)`.
430
+ The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
431
+ The first event corresponds to data-management host tasks,
432
+ including lifetime management of argument Python objects to ensure
433
+ that their associated USM allocation is not freed before offloaded
434
+ computational tasks complete execution, while the second event
435
+ corresponds to computational tasks associated with function
436
+ evaluation.
437
+ docs (str):
438
+ Documentation string for the unary function.
439
+ binary_inplace_fn (callable, optional):
440
+ Data-parallel implementation function with signature
441
+ `impl_fn(src: usm_ndarray, dst: usm_ndarray,
442
+ sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
443
+ where the `src` is the argument array, `dst` is the
444
+ array to be populated with function values,
445
+ i.e. `dst=func(dst, src)`.
446
+ The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
447
+ The first event corresponds to data-management host tasks,
448
+ including async lifetime management of Python arguments,
449
+ while the second event corresponds to computational tasks
450
+ associated with function evaluation.
451
+ acceptance_fn (callable, optional):
452
+ Function to influence type promotion behavior of this binary
453
+ function. The function takes 6 arguments:
454
+ arg1_dtype - Data type of the first argument
455
+ arg2_dtype - Data type of the second argument
456
+ ret_buf1_dtype - Data type the first argument would be cast to
457
+ ret_buf2_dtype - Data type the second argument would be cast to
458
+ res_dtype - Data type of the output array with function values
459
+ sycl_dev - The :class:`dpctl.SyclDevice` where the function
460
+ evaluation is carried out.
461
+ The function is only called when both arguments of the binary
462
+ function require casting, e.g. both arguments of
463
+ `dpctl.tensor.logaddexp` are arrays with integral data type.
366
464
"""
367
465
368
466
def __init__ (
@@ -392,8 +490,60 @@ def __str__(self):
392
490
def __repr__ (self ):
393
491
return f"<{ self .__name__ } '{ self .name_ } '>"
394
492
493
+ def get_implementation_function (self ):
494
+ """Returns the out-of-place implementation
495
+ function for this elementwise binary function.
496
+
497
+ """
498
+ return self .binary_fn_
499
+
500
+ def get_implementation_inplace_function (self ):
501
+ """Returns the in-place implementation
502
+ function for this elementwise binary function.
503
+
504
+ """
505
+ return self .binary_inplace_fn_
506
+
507
+ def get_type_result_resolver_function (self ):
508
+ """Returns the type resolver function for this
509
+ elementwise binary function.
510
+ """
511
+ return self .result_type_resolver_fn_
512
+
513
+ def get_type_promotion_path_acceptance_function (self ):
514
+ """Returns the acceptance function for this
515
+ elementwise binary function.
516
+
517
+ Acceptance function influences the type promotion
518
+ behavior of this binary function.
519
+ The function takes 6 arguments:
520
+ arg1_dtype - Data type of the first argument
521
+ arg2_dtype - Data type of the second argument
522
+ ret_buf1_dtype - Data type the first argument would be cast to
523
+ ret_buf2_dtype - Data type the second argument would be cast to
524
+ res_dtype - Data type of the output array with function values
525
+ sycl_dev - :class:`dpctl.SyclDevice` on which function evaluation
526
+ is carried out.
527
+
528
+ The acceptance function is only invoked if both input arrays must be
529
+ cast to intermediary data types, as would happen during call of
530
+ `dpctl.tensor.hypot` with both arrays being of integral data type.
531
+ """
532
+ return self .acceptance_fn_
533
+
395
534
@property
396
535
def types (self ):
536
+ """Returns information about types supported by
537
+ implementation function, using NumPy's character
538
+ encoding for data types, e.g.
539
+
540
+ :Example:
541
+ .. code-block:: python
542
+
543
+ dpctl.tensor.divide.types
544
+ # Outputs: ['ee->e', 'ff->f', 'fF->F', 'dd->d', 'dD->D',
545
+ # 'Ff->F', 'FF->F', 'Dd->D', 'DD->D']
546
+ """
397
547
types = self .types_
398
548
if not types :
399
549
types = []
0 commit comments