diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 4508e1e3e3..cdb701e1cb 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -64,16 +64,12 @@ from dpctl.tensor._manipulation_functions import ( broadcast_arrays, broadcast_to, - can_cast, concat, expand_dims, - finfo, flip, - iinfo, moveaxis, permute_dims, repeat, - result_type, roll, squeeze, stack, @@ -180,6 +176,7 @@ sum, ) from ._testing import allclose +from ._type_utils import can_cast, finfo, iinfo, result_type __all__ = [ "Device", diff --git a/dpctl/tensor/_data_types.py b/dpctl/tensor/_data_types.py index bf8a5f59c2..bee557cf18 100644 --- a/dpctl/tensor/_data_types.py +++ b/dpctl/tensor/_data_types.py @@ -120,6 +120,7 @@ def _get_dtype(inp_dt, sycl_obj, ref_type=None): __all__ = [ "dtype", + "_get_dtype", "isdtype", "bool", "int8", diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index 7135304b58..0fc288a0f1 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -27,7 +27,7 @@ import dpctl.utils as dputils from ._copy_utils import _broadcast_strides -from ._type_utils import _to_device_supported_dtype +from ._type_utils import _supported_dtype, _to_device_supported_dtype __doc__ = ( "Implementation module for array manipulation " @@ -35,93 +35,6 @@ ) -class finfo_object: - """ - `numpy.finfo` subclass which returns Python floating-point scalars for - `eps`, `max`, `min`, and `smallest_normal` attributes. - """ - - def __init__(self, dtype): - _supported_dtype([dpt.dtype(dtype)]) - self._finfo = np.finfo(dtype) - - @property - def bits(self): - """ - number of bits occupied by the real-valued floating-point data type. - """ - return int(self._finfo.bits) - - @property - def smallest_normal(self): - """ - smallest positive real-valued floating-point number with full - precision. - """ - return float(self._finfo.smallest_normal) - - @property - def tiny(self): - """an alias for `smallest_normal`""" - return float(self._finfo.tiny) - - @property - def eps(self): - """ - difference between 1.0 and the next smallest representable real-valued - floating-point number larger than 1.0 according to the IEEE-754 - standard. - """ - return float(self._finfo.eps) - - @property - def epsneg(self): - """ - difference between 1.0 and the next smallest representable real-valued - floating-point number smaller than 1.0 according to the IEEE-754 - standard. - """ - return float(self._finfo.epsneg) - - @property - def min(self): - """smallest representable real-valued number.""" - return float(self._finfo.min) - - @property - def max(self): - "largest representable real-valued number." - return float(self._finfo.max) - - @property - def resolution(self): - "the approximate decimal resolution of this type." - return float(self._finfo.resolution) - - @property - def precision(self): - """ - the approximate number of decimal digits to which this kind of - floating point type is precise. - """ - return float(self._finfo.precision) - - @property - def dtype(self): - """ - the dtype for which finfo returns information. For complex input, the - returned dtype is the associated floating point dtype for its real and - complex components. - """ - return self._finfo.dtype - - def __str__(self): - return self._finfo.__str__() - - def __repr__(self): - return self._finfo.__repr__() - - def _broadcast_shape_impl(shapes): if len(set(shapes)) == 1: return shapes[0] @@ -681,127 +594,6 @@ def stack(arrays, axis=0): return res -def can_cast(from_, to, casting="safe"): - """ can_cast(from, to, casting="safe") - - Determines if one data type can be cast to another data type according \ - to Type Promotion Rules. - - Args: - from (usm_ndarray, dtype): source data type - to (dtype): target data type - casting ({'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional): - controls what kind of data casting may occur. - - Returns: - bool: - Gives `True` if cast can occur according to the casting rule. - """ - if isinstance(to, dpt.usm_ndarray): - raise TypeError("Expected dtype type.") - - dtype_to = dpt.dtype(to) - - dtype_from = ( - from_.dtype if isinstance(from_, dpt.usm_ndarray) else dpt.dtype(from_) - ) - - _supported_dtype([dtype_from, dtype_to]) - - return np.can_cast(dtype_from, dtype_to, casting) - - -def result_type(*arrays_and_dtypes): - """ - result_type(arrays_and_dtypes) - - Returns the dtype that results from applying the Type Promotion Rules to \ - the arguments. - - Args: - arrays_and_dtypes (object): - An arbitrary length sequence of arrays or dtypes. - - Returns: - dtype: - The dtype resulting from an operation involving the - input arrays and dtypes. - """ - dtypes = [ - X.dtype if isinstance(X, dpt.usm_ndarray) else dpt.dtype(X) - for X in arrays_and_dtypes - ] - - _supported_dtype(dtypes) - - return np.result_type(*dtypes) - - -def iinfo(dtype): - """iinfo(dtype) - - Returns machine limits for integer data types. - - Args: - dtype (dtype, usm_ndarray): - integer dtype or - an array with integer dtype. - - Returns: - iinfo_object: - An object with the following attributes - * bits: int - number of bits occupied by the data type - * max: int - largest representable number. - * min: int - smallest representable number. - * dtype: dtype - integer data type. - """ - if isinstance(dtype, dpt.usm_ndarray): - dtype = dtype.dtype - _supported_dtype([dpt.dtype(dtype)]) - return np.iinfo(dtype) - - -def finfo(dtype): - """finfo(type) - - Returns machine limits for floating-point data types. - - Args: - dtype (dtype, usm_ndarray): floating-point dtype or - an array with floating point data type. - If complex, the information is about its component - data type. - - Returns: - finfo_object: - an object have the following attributes - * bits: int - number of bits occupied by dtype. - * eps: float - difference between 1.0 and the next smallest representable - real-valued floating-point number larger than 1.0 according - to the IEEE-754 standard. - * max: float - largest representable real-valued number. - * min: float - smallest representable real-valued number. - * smallest_normal: float - smallest positive real-valued floating-point number with - full precision. - * dtype: dtype - real-valued floating-point data type. - - """ - if isinstance(dtype, dpt.usm_ndarray): - dtype = dtype.dtype - _supported_dtype([dpt.dtype(dtype)]) - return finfo_object(dtype) - - def unstack(X, axis=0): """unstack(x, axis=0) @@ -1229,10 +1021,3 @@ def tile(x, repetitions): ) hev.wait() return dpt.reshape(res, res_shape) - - -def _supported_dtype(dtypes): - for dtype in dtypes: - if dtype.char not in "?bBhHiIlLqQefdFD": - raise ValueError(f"Dpctl doesn't support dtype {dtype}.") - return True diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index bacd488226..c1f6027ccf 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -14,23 +14,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np + import dpctl.tensor as dpt import dpctl.tensor._tensor_impl as ti def _all_data_types(_fp16, _fp64): + _non_fp_types = [ + dpt.bool, + dpt.int8, + dpt.uint8, + dpt.int16, + dpt.uint16, + dpt.int32, + dpt.uint32, + dpt.int64, + dpt.uint64, + ] if _fp64: if _fp16: - return [ - dpt.bool, - dpt.int8, - dpt.uint8, - dpt.int16, - dpt.uint16, - dpt.int32, - dpt.uint32, - dpt.int64, - dpt.uint64, + return _non_fp_types + [ dpt.float16, dpt.float32, dpt.float64, @@ -38,16 +42,7 @@ def _all_data_types(_fp16, _fp64): dpt.complex128, ] else: - return [ - dpt.bool, - dpt.int8, - dpt.uint8, - dpt.int16, - dpt.uint16, - dpt.int32, - dpt.uint32, - dpt.int64, - dpt.uint64, + return _non_fp_types + [ dpt.float32, dpt.float64, dpt.complex64, @@ -55,31 +50,13 @@ def _all_data_types(_fp16, _fp64): ] else: if _fp16: - return [ - dpt.bool, - dpt.int8, - dpt.uint8, - dpt.int16, - dpt.uint16, - dpt.int32, - dpt.uint32, - dpt.int64, - dpt.uint64, + return _non_fp_types + [ dpt.float16, dpt.float32, dpt.complex64, ] else: - return [ - dpt.bool, - dpt.int8, - dpt.uint8, - dpt.int16, - dpt.uint16, - dpt.int32, - dpt.uint32, - dpt.int64, - dpt.uint64, + return _non_fp_types + [ dpt.float32, dpt.complex64, ] @@ -95,12 +72,33 @@ def _is_maximal_inexact_type(dt: dpt.dtype, _fp16: bool, _fp64: bool): return dt in [dpt.float32, dpt.complex64] -def _can_cast(from_: dpt.dtype, to_: dpt.dtype, _fp16: bool, _fp64: bool): +def _dtype_supported_by_device_impl( + dt: dpt.dtype, has_fp16: bool, has_fp64: bool +) -> bool: + if has_fp64: + if not has_fp16: + if dt is dpt.float16: + return False + else: + if dt is dpt.float64: + return False + elif dt is dpt.complex128: + return False + if not has_fp16 and dt is dpt.float16: + return False + return True + + +def _can_cast( + from_: dpt.dtype, to_: dpt.dtype, _fp16: bool, _fp64: bool, casting="safe" +) -> bool: """ Can `from_` be cast to `to_` safely on a device with fp16 and fp64 aspects as given? """ - can_cast_v = dpt.can_cast(from_, to_) # ask NumPy + if not _dtype_supported_by_device_impl(to_, _fp16, _fp64): + return False + can_cast_v = np.can_cast(from_, to_, casting=casting) # ask NumPy if _fp16 and _fp64: return can_cast_v if not can_cast_v: @@ -114,10 +112,7 @@ def _can_cast(from_: dpt.dtype, to_: dpt.dtype, _fp16: bool, _fp64: bool): return can_cast_v -def _to_device_supported_dtype(dt, dev): - has_fp16 = dev.has_aspect_fp16 - has_fp64 = dev.has_aspect_fp64 - +def _to_device_supported_dtype_impl(dt, has_fp16, has_fp64): if has_fp64: if not has_fp16: if dt is dpt.float16: @@ -132,6 +127,13 @@ def _to_device_supported_dtype(dt, dev): return dt +def _to_device_supported_dtype(dt, dev): + has_fp16 = dev.has_aspect_fp16 + has_fp64 = dev.has_aspect_fp64 + + return _to_device_supported_dtype_impl(dt, has_fp16, has_fp64) + + def _acceptance_fn_default_unary(arg_dtype, ret_buf_dt, res_dt, sycl_dev): return True @@ -250,6 +252,274 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn): return None, None, None +class finfo_object: + """ + `numpy.finfo` subclass which returns Python floating-point scalars for + `eps`, `max`, `min`, and `smallest_normal` attributes. + """ + + def __init__(self, dtype): + _supported_dtype([dpt.dtype(dtype)]) + self._finfo = np.finfo(dtype) + + @property + def bits(self): + """ + number of bits occupied by the real-valued floating-point data type. + """ + return int(self._finfo.bits) + + @property + def smallest_normal(self): + """ + smallest positive real-valued floating-point number with full + precision. + """ + return float(self._finfo.smallest_normal) + + @property + def tiny(self): + """an alias for `smallest_normal`""" + return float(self._finfo.tiny) + + @property + def eps(self): + """ + difference between 1.0 and the next smallest representable real-valued + floating-point number larger than 1.0 according to the IEEE-754 + standard. + """ + return float(self._finfo.eps) + + @property + def epsneg(self): + """ + difference between 1.0 and the next smallest representable real-valued + floating-point number smaller than 1.0 according to the IEEE-754 + standard. + """ + return float(self._finfo.epsneg) + + @property + def min(self): + """smallest representable real-valued number.""" + return float(self._finfo.min) + + @property + def max(self): + "largest representable real-valued number." + return float(self._finfo.max) + + @property + def resolution(self): + "the approximate decimal resolution of this type." + return float(self._finfo.resolution) + + @property + def precision(self): + """ + the approximate number of decimal digits to which this kind of + floating point type is precise. + """ + return float(self._finfo.precision) + + @property + def dtype(self): + """ + the dtype for which finfo returns information. For complex input, the + returned dtype is the associated floating point dtype for its real and + complex components. + """ + return self._finfo.dtype + + def __str__(self): + return self._finfo.__str__() + + def __repr__(self): + return self._finfo.__repr__() + + +def can_cast(from_, to, casting="safe"): + """ can_cast(from, to, casting="safe") + + Determines if one data type can be cast to another data type according \ + to Type Promotion Rules. + + Args: + from_ (Union[usm_ndarray, dtype]): + source data type. If `from_` is an array, a device-specific type + promotion rules apply. + to (dtype): + target data type + casting (Optional[str]): + controls what kind of data casting may occur. + * "no" means data types should not be cast at all. + * "safe" means only casts that preserve values are allowed. + * "same_kind" means only safe casts and casts within a kind, + like `float64` to `float32`, are allowed. + * "unsafe" means any data conversion can be done. + Default: `"safe"`. + + Returns: + bool: + Gives `True` if cast can occur according to the casting rule. + + Device-specific type promotion rules take into account which data type are + and are not supported by a specific device. + """ + if isinstance(to, dpt.usm_ndarray): + raise TypeError(f"Expected `dpt.dtype` type, got {type(to)}.") + + dtype_to = dpt.dtype(to) + _supported_dtype([dtype_to]) + + if isinstance(from_, dpt.usm_ndarray): + dtype_from = from_.dtype + return _can_cast( + dtype_from, + dtype_to, + from_.sycl_device.has_aspect_fp16, + from_.sycl_device.has_aspect_fp64, + casting=casting, + ) + else: + dtype_from = dpt.dtype(from_) + _supported_dtype([dtype_from]) + # query casting as if all dtypes are supported + return _can_cast(dtype_from, dtype_to, True, True, casting=casting) + + +def result_type(*arrays_and_dtypes): + """ + result_type(*arrays_and_dtypes) + + Returns the dtype that results from applying the Type Promotion Rules to \ + the arguments. + + Args: + arrays_and_dtypes (Union[usm_ndarray, dtype]): + An arbitrary length sequence of usm_ndarray objects or dtypes. + + Returns: + dtype: + The dtype resulting from an operation involving the + input arrays and dtypes. + """ + dtypes = [] + devices = [] + for arg_i in arrays_and_dtypes: + if isinstance(arg_i, dpt.usm_ndarray): + devices.append(arg_i.sycl_device) + dtypes.append(arg_i.dtype) + else: + dt = dpt.dtype(arg_i) + _supported_dtype([dt]) + dtypes.append(dt) + + has_fp16 = True + has_fp64 = True + if devices: + inspected = False + for d in devices: + if inspected: + unsame_fp16_support = d.has_aspect_fp16 != has_fp16 + unsame_fp64_support = d.has_aspect_fp64 != has_fp64 + if unsame_fp16_support or unsame_fp64_support: + raise ValueError( + "Input arrays reside on devices " + "with different device supports; " + "unable to determine which " + "device-specific type promotion rules " + "to use." + ) + else: + has_fp16 = d.has_aspect_fp16 + has_fp64 = d.has_aspect_fp64 + inspected = True + + if not (has_fp16 and has_fp64): + for dt in dtypes: + if not _dtype_supported_by_device_impl(dt, has_fp16, has_fp64): + raise ValueError(f"Argument {dt} is not supported by ") + res_dt = np.result_type(*dtypes) + res_dt = _to_device_supported_dtype_impl(res_dt, has_fp16, has_fp64) + return res_dt + + return np.result_type(*dtypes) + + +def iinfo(dtype): + """iinfo(dtype) + + Returns machine limits for integer data types. + + Args: + dtype (dtype, usm_ndarray): + integer dtype or + an array with integer dtype. + + Returns: + iinfo_object: + An object with the following attributes + * bits: int + number of bits occupied by the data type + * max: int + largest representable number. + * min: int + smallest representable number. + * dtype: dtype + integer data type. + """ + if isinstance(dtype, dpt.usm_ndarray): + dtype = dtype.dtype + _supported_dtype([dpt.dtype(dtype)]) + return np.iinfo(dtype) + + +def finfo(dtype): + """finfo(type) + + Returns machine limits for floating-point data types. + + Args: + dtype (dtype, usm_ndarray): floating-point dtype or + an array with floating point data type. + If complex, the information is about its component + data type. + + Returns: + finfo_object: + an object have the following attributes + * bits: int + number of bits occupied by dtype. + * eps: float + difference between 1.0 and the next smallest representable + real-valued floating-point number larger than 1.0 according + to the IEEE-754 standard. + * max: float + largest representable real-valued number. + * min: float + smallest representable real-valued number. + * smallest_normal: float + smallest positive real-valued floating-point number with + full precision. + * dtype: dtype + real-valued floating-point data type. + + """ + if isinstance(dtype, dpt.usm_ndarray): + dtype = dtype.dtype + _supported_dtype([dpt.dtype(dtype)]) + return finfo_object(dtype) + + +def _supported_dtype(dtypes): + for dtype in dtypes: + if dtype.char not in "?bBhHiIlLqQefdFD": + raise ValueError(f"Dpctl doesn't support dtype {dtype}.") + return True + + __all__ = [ "_find_buf_dtype", "_find_buf_dtype2", @@ -258,4 +528,8 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn): "_acceptance_fn_reciprocal", "_acceptance_fn_default_binary", "_acceptance_fn_divide", + "can_cast", + "finfo", + "iinfo", + "result_type", ] diff --git a/dpctl/tests/test_tensor_dtype_routines.py b/dpctl/tests/test_tensor_dtype_routines.py index acb1bb6d8b..98f84076ab 100644 --- a/dpctl/tests/test_tensor_dtype_routines.py +++ b/dpctl/tests/test_tensor_dtype_routines.py @@ -17,6 +17,7 @@ import pytest +import dpctl import dpctl.tensor as dpt list_dtypes = [ @@ -127,3 +128,31 @@ def test_isdtype_kind_tuple_dtypes(dtype_str): def test_isdtype_invalid_kind(kind): with pytest.raises((TypeError, ValueError)): dpt.isdtype(dpt.int32, kind) + + +def test_finfo_array(): + try: + x = dpt.empty(tuple(), dtype="f4") + except dpctl.SyclDeviceCreationError: + pytest.skip("Default-selected SYCL device unavailable") + o = dpt.finfo(x) + assert o.dtype == dpt.float32 + + +def test_iinfo_array(): + try: + x = dpt.empty(tuple(), dtype="i4") + except dpctl.SyclDeviceCreationError: + pytest.skip("Default-selected SYCL device unavailable") + o = dpt.iinfo(x) + assert o.dtype == dpt.int32 + + +def test_iinfo_validation(): + with pytest.raises(ValueError): + dpt.iinfo("O") + + +def test_finfo_validation(): + with pytest.raises(ValueError): + dpt.iinfo("O") diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index f3704274d4..54c2c2380a 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -922,10 +922,10 @@ def test_can_cast(): q = get_queue_or_skip() # incorrect input - X = dpt.ones((2, 2), dtype=dpt.int64, sycl_queue=q) + X = dpt.ones((2, 2), dtype=dpt.int16, sycl_queue=q) pytest.raises(TypeError, dpt.can_cast, X, 1) pytest.raises(TypeError, dpt.can_cast, X, X) - X_np = np.ones((2, 2), dtype=np.int64) + X_np = np.ones((2, 2), dtype=np.int16) assert dpt.can_cast(X, "float32") == np.can_cast(X_np, "float32") assert dpt.can_cast(X, dpt.int32) == np.can_cast(X_np, np.int32) @@ -935,8 +935,8 @@ def test_can_cast(): def test_result_type(): q = get_queue_or_skip() - X = [dpt.ones((2), dtype=dpt.int64, sycl_queue=q), dpt.int32, "float16"] - X_np = [np.ones((2), dtype=np.int64), np.int32, "float16"] + X = [dpt.ones((2), dtype=dpt.int16, sycl_queue=q), dpt.int32, "int64"] + X_np = [np.ones((2), dtype=np.int16), np.int32, "int64"] assert dpt.result_type(*X) == np.result_type(*X_np)