Skip to content

Implements logaddexp and hypot #1272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 24, 2023
4 changes: 4 additions & 0 deletions dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
floor_divide,
greater,
greater_equal,
hypot,
imag,
isfinite,
isinf,
Expand All @@ -115,6 +116,7 @@
log1p,
log2,
log10,
logaddexp,
logical_and,
logical_not,
logical_or,
Expand Down Expand Up @@ -222,6 +224,7 @@
"floor_divide",
"greater",
"greater_equal",
"hypot",
"imag",
"isfinite",
"isinf",
Expand All @@ -241,6 +244,7 @@
"not_equal",
"positive",
"pow",
"logaddexp",
"proj",
"real",
"sin",
Expand Down
31 changes: 24 additions & 7 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from dpctl.utils import ExecutionPlacementError

from ._type_utils import (
_acceptance_fn_default,
_empty_like_orderK,
_empty_like_pair_orderK,
_find_buf_dtype,
Expand All @@ -48,6 +49,12 @@ def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
self.unary_fn_ = unary_dp_impl_fn
self.__doc__ = docs

def __str__(self):
return f"<{self.__name__} '{self.name_}'>"

def __repr__(self):
return f"<{self.__name__} '{self.name_}'>"

def __call__(self, x, out=None, order="K"):
if not isinstance(x, dpt.usm_ndarray):
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
Expand Down Expand Up @@ -357,27 +364,33 @@ def __init__(
binary_dp_impl_fn,
docs,
binary_inplace_fn=None,
acceptance_fn=None,
):
self.__name__ = "BinaryElementwiseFunc"
self.name_ = name
self.result_type_resolver_fn_ = result_type_resolver_fn
self.binary_fn_ = binary_dp_impl_fn
self.binary_inplace_fn_ = binary_inplace_fn
self.__doc__ = docs
if callable(acceptance_fn):
self.acceptance_fn_ = acceptance_fn
else:
self.acceptance_fn_ = _acceptance_fn_default

def __str__(self):
return f"<BinaryElementwiseFunc '{self.name_}'>"
return f"<{self.__name__} '{self.name_}'>"

def __repr__(self):
return f"<BinaryElementwiseFunc '{self.name_}'>"
return f"<{self.__name__} '{self.name_}'>"

def __call__(self, o1, o2, out=None, order="K"):
# FIXME: replace with check against base array
# when views can be identified
if o1 is out:
return self._inplace(o1, o2)
elif o2 is out:
return self._inplace(o2, o1)
if self.binary_inplace_fn_:
if o1 is out:
return self._inplace(o1, o2)
elif o2 is out:
return self._inplace(o2, o1)

if order not in ["K", "C", "F", "A"]:
order = "K"
Expand Down Expand Up @@ -445,7 +458,11 @@ def __call__(self, o1, o2, out=None, order="K"):
o1_dtype, o2_dtype = _resolve_weak_types(o1_dtype, o2_dtype, sycl_dev)

buf1_dt, buf2_dt, res_dt = _find_buf_dtype2(
o1_dtype, o2_dtype, self.result_type_resolver_fn_, sycl_dev
o1_dtype,
o2_dtype,
self.result_type_resolver_fn_,
sycl_dev,
acceptance_fn=self.acceptance_fn_,
)

if res_dt is None:
Expand Down
66 changes: 62 additions & 4 deletions dpctl/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import dpctl.tensor._tensor_impl as ti

from ._elementwise_common import BinaryElementwiseFunc, UnaryElementwiseFunc
from ._type_utils import _acceptance_fn_divide

# U01: ==== ABS (x)
_abs_docstring_ = """
Expand Down Expand Up @@ -215,7 +216,11 @@
"""

divide = BinaryElementwiseFunc(
"divide", ti._divide_result_type, ti._divide, _divide_docstring_
"divide",
ti._divide_result_type,
ti._divide,
_divide_docstring_,
acceptance_fn=_acceptance_fn_divide,
)

# B09: ==== EQUAL (x1, x2)
Expand Down Expand Up @@ -661,7 +666,32 @@
)

# B15: ==== LOGADDEXP (x1, x2)
# FIXME: implement B15
_logaddexp_docstring_ = """
logaddexp(x1, x2, out=None, order='K')

Calculates the ratio for each element `x1_i` of the input array `x1` with
the respective element `x2_i` of the input array `x2`.

Args:
x1 (usm_ndarray):
First input array, expected to have numeric data type.
x2 (usm_ndarray):
Second input array, also expected to have numeric data type.
out ({None, usm_ndarray}, optional):
Output array to populate.
Array have the correct shape and the expected data type.
order ("C","F","A","K", optional):
Memory layout of the newly output array, if parameter `out` is `None`.
Default: "K".
Returns:
usm_narray:
An array containing the result of element-wise division. The data type
of the returned array is determined by the Type Promotion Rules.
"""

logaddexp = BinaryElementwiseFunc(
"logaddexp", ti._logaddexp_result_type, ti._logaddexp, _logaddexp_docstring_
)

# B16: ==== LOGICAL_AND (x1, x2)
_logical_and_docstring_ = """
Expand Down Expand Up @@ -1094,12 +1124,40 @@
order ("C","F","A","K", optional):
Memory layout of the newly output array, if parameter `out` is `None`.
Default: "K".
Returns:
usm_narray:
An array containing the result of element-wise division. The data type
of the returned array is determined by the Type Promotion Rules.
"""
trunc = UnaryElementwiseFunc(
"trunc", ti._trunc_result_type, ti._trunc, _trunc_docstring
)


# B24: ==== HYPOT (x1, x2)
_hypot_docstring_ = """
hypot(x1, x2, out=None, order='K')

Calculates the ratio for each element `x1_i` of the input array `x1` with
the respective element `x2_i` of the input array `x2`.

Args:
x1 (usm_ndarray):
First input array, expected to have numeric data type.
x2 (usm_ndarray):
Second input array, also expected to have numeric data type.
out ({None, usm_ndarray}, optional):
Output array to populate.
Array have the correct shape and the expected data type.
order ("C","F","A","K", optional):
Memory layout of the newly output array, if parameter `out` is `None`.
Default: "K".
Returns:
usm_narray:
An array containing the element-wise truncated value of input array.
The returned array has the same data type as `x`.
"""

trunc = UnaryElementwiseFunc(
"trunc", ti._trunc_result_type, ti._trunc, _trunc_docstring
hypot = BinaryElementwiseFunc(
"hypot", ti._hypot_result_type, ti._hypot, _hypot_docstring_
)
56 changes: 41 additions & 15 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,34 @@ def _get_device_default_dtype(dt_kind, sycl_dev):
raise RuntimeError


def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev):
def _acceptance_fn_default(
arg1_dtype, arg2_dtype, ret_buf1_dt, ret_buf2_dt, res_dt, sycl_dev
):
return True


def _acceptance_fn_divide(
arg1_dtype, arg2_dtype, ret_buf1_dt, ret_buf2_dt, res_dt, sycl_dev
):
# both are being promoted, if the kind of result is
# different than the kind of original input dtypes,
# we use default dtype for the resulting kind.
# This covers, e.g. (array_dtype_i1 / array_dtype_u1)
# result of which in divide is double (in NumPy), but
# regular type promotion rules peg at float16
if (ret_buf1_dt.kind != arg1_dtype.kind) and (
ret_buf2_dt.kind != arg2_dtype.kind
):
default_dt = _get_device_default_dtype(res_dt.kind, sycl_dev)
if res_dt == default_dt:
return True
else:
return False
else:
return True


def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
res_dt = query_fn(arg1_dtype, arg2_dtype)
if res_dt:
return None, None, res_dt
Expand All @@ -275,21 +302,18 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev):
if ret_buf1_dt is None or ret_buf2_dt is None:
return ret_buf1_dt, ret_buf2_dt, res_dt
else:
# both are being promoted, if the kind of result is
# different than the kind of original input dtypes,
# we must use default dtype for the resulting kind.
if (res_dt.kind != arg1_dtype.kind) and (
res_dt.kind != arg2_dtype.kind
):
default_dt = _get_device_default_dtype(
res_dt.kind, sycl_dev
)
if res_dt == default_dt:
return ret_buf1_dt, ret_buf2_dt, res_dt
else:
continue
else:
acceptable = acceptance_fn(
arg1_dtype,
arg2_dtype,
ret_buf1_dt,
ret_buf2_dt,
res_dt,
sycl_dev,
)
if acceptable:
return ret_buf1_dt, ret_buf2_dt, res_dt
else:
continue

return None, None, None

Expand Down Expand Up @@ -318,4 +342,6 @@ def _find_inplace_dtype(lhs_dtype, rhs_dtype, query_fn, sycl_dev):
"_empty_like_orderK",
"_empty_like_pair_orderK",
"_to_device_supported_dtype",
"_acceptance_fn_default",
"_acceptance_fn_divide",
]
Loading