Skip to content

Commit fbe0081

Browse files
Merge pull request #1473 from IntelPython/fix_result_type
Fixed dpctl.tensor.result_type function for scalars
2 parents 700079f + 01e9d9c commit fbe0081

File tree

5 files changed

+238
-132
lines changed

5 files changed

+238
-132
lines changed

dpctl/tensor/_clip.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,24 @@
2424
_empty_like_triple_orderK,
2525
)
2626
from dpctl.tensor._elementwise_common import (
27-
WeakBooleanType,
28-
WeakComplexType,
29-
WeakFloatingType,
30-
WeakIntegralType,
3127
_get_dtype,
3228
_get_queue_usm_type,
3329
_get_shape,
34-
_strong_dtype_num_kind,
3530
_validate_dtype,
36-
_weak_type_num_kind,
3731
)
3832
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
3933
from dpctl.tensor._type_utils import _can_cast, _to_device_supported_dtype
4034
from dpctl.utils import ExecutionPlacementError
4135

36+
from ._type_utils import (
37+
WeakBooleanType,
38+
WeakComplexType,
39+
WeakFloatingType,
40+
WeakIntegralType,
41+
_strong_dtype_num_kind,
42+
_weak_type_num_kind,
43+
)
44+
4245

4346
def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
4447
"Resolves weak data types per NEP-0050,"

dpctl/tensor/_elementwise_common.py

+5-121
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,16 @@
2828

2929
from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
3030
from ._type_utils import (
31+
WeakBooleanType,
32+
WeakComplexType,
33+
WeakFloatingType,
34+
WeakIntegralType,
3135
_acceptance_fn_default_binary,
3236
_acceptance_fn_default_unary,
3337
_all_data_types,
3438
_find_buf_dtype,
3539
_find_buf_dtype2,
40+
_resolve_weak_types,
3641
_to_device_supported_dtype,
3742
)
3843

@@ -286,46 +291,6 @@ def _get_queue_usm_type(o):
286291
return None, None
287292

288293

289-
class WeakBooleanType:
290-
"Python type representing type of Python boolean objects"
291-
292-
def __init__(self, o):
293-
self.o_ = o
294-
295-
def get(self):
296-
return self.o_
297-
298-
299-
class WeakIntegralType:
300-
"Python type representing type of Python integral objects"
301-
302-
def __init__(self, o):
303-
self.o_ = o
304-
305-
def get(self):
306-
return self.o_
307-
308-
309-
class WeakFloatingType:
310-
"""Python type representing type of Python floating point objects"""
311-
312-
def __init__(self, o):
313-
self.o_ = o
314-
315-
def get(self):
316-
return self.o_
317-
318-
319-
class WeakComplexType:
320-
"""Python type representing type of Python complex floating point objects"""
321-
322-
def __init__(self, o):
323-
self.o_ = o
324-
325-
def get(self):
326-
return self.o_
327-
328-
329294
def _get_dtype(o, dev):
330295
if isinstance(o, dpt.usm_ndarray):
331296
return o.dtype
@@ -375,87 +340,6 @@ def _validate_dtype(dt) -> bool:
375340
)
376341

377342

378-
def _weak_type_num_kind(o):
379-
_map = {"?": 0, "i": 1, "f": 2, "c": 3}
380-
if isinstance(o, WeakBooleanType):
381-
return _map["?"]
382-
if isinstance(o, WeakIntegralType):
383-
return _map["i"]
384-
if isinstance(o, WeakFloatingType):
385-
return _map["f"]
386-
if isinstance(o, WeakComplexType):
387-
return _map["c"]
388-
raise TypeError(
389-
f"Unexpected type {o} while expecting "
390-
"`WeakBooleanType`, `WeakIntegralType`,"
391-
"`WeakFloatingType`, or `WeakComplexType`."
392-
)
393-
394-
395-
def _strong_dtype_num_kind(o):
396-
_map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 3}
397-
if not isinstance(o, dpt.dtype):
398-
raise TypeError
399-
k = o.kind
400-
if k in _map:
401-
return _map[k]
402-
raise ValueError(f"Unrecognized kind {k} for dtype {o}")
403-
404-
405-
def _resolve_weak_types(o1_dtype, o2_dtype, dev):
406-
"Resolves weak data type per NEP-0050"
407-
if isinstance(
408-
o1_dtype,
409-
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
410-
):
411-
if isinstance(
412-
o2_dtype,
413-
(
414-
WeakBooleanType,
415-
WeakIntegralType,
416-
WeakFloatingType,
417-
WeakComplexType,
418-
),
419-
):
420-
raise ValueError
421-
o1_kind_num = _weak_type_num_kind(o1_dtype)
422-
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
423-
if o1_kind_num > o2_kind_num:
424-
if isinstance(o1_dtype, WeakIntegralType):
425-
return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
426-
if isinstance(o1_dtype, WeakComplexType):
427-
if o2_dtype is dpt.float16 or o2_dtype is dpt.float32:
428-
return dpt.complex64, o2_dtype
429-
return (
430-
_to_device_supported_dtype(dpt.complex128, dev),
431-
o2_dtype,
432-
)
433-
return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
434-
else:
435-
return o2_dtype, o2_dtype
436-
elif isinstance(
437-
o2_dtype,
438-
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
439-
):
440-
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
441-
o2_kind_num = _weak_type_num_kind(o2_dtype)
442-
if o2_kind_num > o1_kind_num:
443-
if isinstance(o2_dtype, WeakIntegralType):
444-
return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
445-
if isinstance(o2_dtype, WeakComplexType):
446-
if o1_dtype is dpt.float16 or o1_dtype is dpt.float32:
447-
return o1_dtype, dpt.complex64
448-
return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev)
449-
return (
450-
o1_dtype,
451-
_to_device_supported_dtype(dpt.float64, dev),
452-
)
453-
else:
454-
return o1_dtype, o1_dtype
455-
else:
456-
return o1_dtype, o2_dtype
457-
458-
459343
def _get_shape(o):
460344
if isinstance(o, dpt.usm_ndarray):
461345
return o.shape

dpctl/tensor/_type_utils.py

+152-3
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,127 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
252252
return None, None, None
253253

254254

255+
class WeakBooleanType:
256+
"Python type representing type of Python boolean objects"
257+
258+
def __init__(self, o):
259+
self.o_ = o
260+
261+
def get(self):
262+
return self.o_
263+
264+
265+
class WeakIntegralType:
266+
"Python type representing type of Python integral objects"
267+
268+
def __init__(self, o):
269+
self.o_ = o
270+
271+
def get(self):
272+
return self.o_
273+
274+
275+
class WeakFloatingType:
276+
"""Python type representing type of Python floating point objects"""
277+
278+
def __init__(self, o):
279+
self.o_ = o
280+
281+
def get(self):
282+
return self.o_
283+
284+
285+
class WeakComplexType:
286+
"""Python type representing type of Python complex floating point objects"""
287+
288+
def __init__(self, o):
289+
self.o_ = o
290+
291+
def get(self):
292+
return self.o_
293+
294+
295+
def _weak_type_num_kind(o):
296+
_map = {"?": 0, "i": 1, "f": 2, "c": 3}
297+
if isinstance(o, WeakBooleanType):
298+
return _map["?"]
299+
if isinstance(o, WeakIntegralType):
300+
return _map["i"]
301+
if isinstance(o, WeakFloatingType):
302+
return _map["f"]
303+
if isinstance(o, WeakComplexType):
304+
return _map["c"]
305+
raise TypeError(
306+
f"Unexpected type {o} while expecting "
307+
"`WeakBooleanType`, `WeakIntegralType`,"
308+
"`WeakFloatingType`, or `WeakComplexType`."
309+
)
310+
311+
312+
def _strong_dtype_num_kind(o):
313+
_map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 3}
314+
if not isinstance(o, dpt.dtype):
315+
raise TypeError
316+
k = o.kind
317+
if k in _map:
318+
return _map[k]
319+
raise ValueError(f"Unrecognized kind {k} for dtype {o}")
320+
321+
322+
def _resolve_weak_types(o1_dtype, o2_dtype, dev):
323+
"Resolves weak data type per NEP-0050"
324+
if isinstance(
325+
o1_dtype,
326+
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
327+
):
328+
if isinstance(
329+
o2_dtype,
330+
(
331+
WeakBooleanType,
332+
WeakIntegralType,
333+
WeakFloatingType,
334+
WeakComplexType,
335+
),
336+
):
337+
raise ValueError
338+
o1_kind_num = _weak_type_num_kind(o1_dtype)
339+
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
340+
if o1_kind_num > o2_kind_num:
341+
if isinstance(o1_dtype, WeakIntegralType):
342+
return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
343+
if isinstance(o1_dtype, WeakComplexType):
344+
if o2_dtype is dpt.float16 or o2_dtype is dpt.float32:
345+
return dpt.complex64, o2_dtype
346+
return (
347+
_to_device_supported_dtype(dpt.complex128, dev),
348+
o2_dtype,
349+
)
350+
return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
351+
else:
352+
return o2_dtype, o2_dtype
353+
elif isinstance(
354+
o2_dtype,
355+
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
356+
):
357+
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
358+
o2_kind_num = _weak_type_num_kind(o2_dtype)
359+
if o2_kind_num > o1_kind_num:
360+
if isinstance(o2_dtype, WeakIntegralType):
361+
return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
362+
if isinstance(o2_dtype, WeakComplexType):
363+
if o1_dtype is dpt.float16 or o1_dtype is dpt.float32:
364+
return o1_dtype, dpt.complex64
365+
return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev)
366+
return (
367+
o1_dtype,
368+
_to_device_supported_dtype(dpt.float64, dev),
369+
)
370+
else:
371+
return o1_dtype, o1_dtype
372+
else:
373+
return o1_dtype, o2_dtype
374+
375+
255376
class finfo_object:
256377
"""
257378
`numpy.finfo` subclass which returns Python floating-point scalars for
@@ -407,17 +528,27 @@ def result_type(*arrays_and_dtypes):
407528
"""
408529
dtypes = []
409530
devices = []
531+
weak_dtypes = []
410532
for arg_i in arrays_and_dtypes:
411533
if isinstance(arg_i, dpt.usm_ndarray):
412534
devices.append(arg_i.sycl_device)
413535
dtypes.append(arg_i.dtype)
536+
elif isinstance(arg_i, int):
537+
weak_dtypes.append(WeakIntegralType(arg_i))
538+
elif isinstance(arg_i, float):
539+
weak_dtypes.append(WeakFloatingType(arg_i))
540+
elif isinstance(arg_i, complex):
541+
weak_dtypes.append(WeakComplexType(arg_i))
542+
elif isinstance(arg_i, bool):
543+
weak_dtypes.append(WeakBooleanType(arg_i))
414544
else:
415545
dt = dpt.dtype(arg_i)
416546
_supported_dtype([dt])
417547
dtypes.append(dt)
418548

419549
has_fp16 = True
420550
has_fp64 = True
551+
target_dev = None
421552
if devices:
422553
inspected = False
423554
for d in devices:
@@ -435,17 +566,28 @@ def result_type(*arrays_and_dtypes):
435566
else:
436567
has_fp16 = d.has_aspect_fp16
437568
has_fp64 = d.has_aspect_fp64
569+
target_dev = d
438570
inspected = True
439571

440572
if not (has_fp16 and has_fp64):
441573
for dt in dtypes:
442574
if not _dtype_supported_by_device_impl(dt, has_fp16, has_fp64):
443-
raise ValueError(f"Argument {dt} is not supported by ")
575+
raise ValueError(
576+
f"Argument {dt} is not supported by the device"
577+
)
444578
res_dt = np.result_type(*dtypes)
445579
res_dt = _to_device_supported_dtype_impl(res_dt, has_fp16, has_fp64)
446-
return res_dt
580+
for wdt in weak_dtypes:
581+
pair = _resolve_weak_types(wdt, res_dt, target_dev)
582+
res_dt = np.result_type(*pair)
583+
res_dt = _to_device_supported_dtype_impl(res_dt, has_fp16, has_fp64)
584+
else:
585+
res_dt = np.result_type(*dtypes)
586+
if weak_dtypes:
587+
weak_dt_obj = [wdt.get() for wdt in weak_dtypes]
588+
res_dt = np.result_type(res_dt, *weak_dt_obj)
447589

448-
return np.result_type(*dtypes)
590+
return res_dt
449591

450592

451593
def iinfo(dtype):
@@ -528,8 +670,15 @@ def _supported_dtype(dtypes):
528670
"_acceptance_fn_reciprocal",
529671
"_acceptance_fn_default_binary",
530672
"_acceptance_fn_divide",
673+
"_resolve_weak_types",
674+
"_weak_type_num_kind",
675+
"_strong_dtype_num_kind",
531676
"can_cast",
532677
"finfo",
533678
"iinfo",
534679
"result_type",
680+
"WeakBooleanType",
681+
"WeakIntegralType",
682+
"WeakFloatingType",
683+
"WeakComplexType",
535684
]

0 commit comments

Comments
 (0)