Skip to content

Commit f8536c1

Browse files
committed
Change acceptance function names per feedback
`_acceptance_fn_default1` and `_acceptance_fn_default2` are now `_acceptance_fn_default_unary` and `_acceptance_fn_default_binary`
1 parent 0018dfa commit f8536c1

File tree

3 files changed

+14
-10
lines changed

3 files changed

+14
-10
lines changed

dpctl/tensor/_elementwise_common.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828

2929
from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
3030
from ._type_utils import (
31-
_acceptance_fn_default1,
32-
_acceptance_fn_default2,
31+
_acceptance_fn_default_binary,
32+
_acceptance_fn_default_unary,
3333
_all_data_types,
3434
_find_buf_dtype,
3535
_find_buf_dtype2,
@@ -95,7 +95,7 @@ def __init__(
9595
if callable(acceptance_fn):
9696
self.acceptance_fn_ = acceptance_fn
9797
else:
98-
self.acceptance_fn_ = _acceptance_fn_default1
98+
self.acceptance_fn_ = _acceptance_fn_default_unary
9999

100100
def __str__(self):
101101
return f"<{self.__name__} '{self.name_}'>"
@@ -526,7 +526,7 @@ def __init__(
526526
if callable(acceptance_fn):
527527
self.acceptance_fn_ = acceptance_fn
528528
else:
529-
self.acceptance_fn_ = _acceptance_fn_default2
529+
self.acceptance_fn_ = _acceptance_fn_default_binary
530530

531531
def __str__(self):
532532
return f"<{self.__name__} '{self.name_}'>"

dpctl/tensor/_type_utils.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def _to_device_supported_dtype(dt, dev):
132132
return dt
133133

134134

135-
def _acceptance_fn_default1(arg_dtype, ret_buf_dt, res_dt, sycl_dev):
135+
def _acceptance_fn_default_unary(arg_dtype, ret_buf_dt, res_dt, sycl_dev):
136136
return True
137137

138138

@@ -187,7 +187,7 @@ def _get_device_default_dtype(dt_kind, sycl_dev):
187187
raise RuntimeError
188188

189189

190-
def _acceptance_fn_default2(
190+
def _acceptance_fn_default_binary(
191191
arg1_dtype, arg2_dtype, ret_buf1_dt, ret_buf2_dt, res_dt, sycl_dev
192192
):
193193
return True
@@ -254,8 +254,8 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
254254
"_find_buf_dtype",
255255
"_find_buf_dtype2",
256256
"_to_device_supported_dtype",
257-
"_acceptance_fn_default1",
257+
"_acceptance_fn_default_unary",
258258
"_acceptance_fn_reciprocal",
259-
"_acceptance_fn_default2",
259+
"_acceptance_fn_default_binary",
260260
"_acceptance_fn_divide",
261261
]

dpctl/tests/elementwise/test_type_utils.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def _denier_fn(dt):
125125
dev = MockDevice(fp16, fp64)
126126
arg_dt = dpt.float64
127127
r = tu._find_buf_dtype(
128-
arg_dt, _denier_fn, dev, tu._acceptance_fn_default1
128+
arg_dt, _denier_fn, dev, tu._acceptance_fn_default_unary
129129
)
130130
assert r == (
131131
None,
@@ -159,7 +159,11 @@ def _denier_fn(dt1, dt2):
159159
arg1_dt = dpt.float64
160160
arg2_dt = dpt.complex64
161161
r = tu._find_buf_dtype2(
162-
arg1_dt, arg2_dt, _denier_fn, dev, tu._acceptance_fn_default2
162+
arg1_dt,
163+
arg2_dt,
164+
_denier_fn,
165+
dev,
166+
tu._acceptance_fn_default_binary,
163167
)
164168
assert r == (
165169
None,

0 commit comments

Comments
 (0)