@@ -252,6 +252,127 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
252252 return None , None , None
253253
254254
255+ class WeakBooleanType :
256+ "Python type representing type of Python boolean objects"
257+
258+ def __init__ (self , o ):
259+ self .o_ = o
260+
261+ def get (self ):
262+ return self .o_
263+
264+
265+ class WeakIntegralType :
266+ "Python type representing type of Python integral objects"
267+
268+ def __init__ (self , o ):
269+ self .o_ = o
270+
271+ def get (self ):
272+ return self .o_
273+
274+
275+ class WeakFloatingType :
276+ """Python type representing type of Python floating point objects"""
277+
278+ def __init__ (self , o ):
279+ self .o_ = o
280+
281+ def get (self ):
282+ return self .o_
283+
284+
285+ class WeakComplexType :
286+ """Python type representing type of Python complex floating point objects"""
287+
288+ def __init__ (self , o ):
289+ self .o_ = o
290+
291+ def get (self ):
292+ return self .o_
293+
294+
295+ def _weak_type_num_kind (o ):
296+ _map = {"?" : 0 , "i" : 1 , "f" : 2 , "c" : 3 }
297+ if isinstance (o , WeakBooleanType ):
298+ return _map ["?" ]
299+ if isinstance (o , WeakIntegralType ):
300+ return _map ["i" ]
301+ if isinstance (o , WeakFloatingType ):
302+ return _map ["f" ]
303+ if isinstance (o , WeakComplexType ):
304+ return _map ["c" ]
305+ raise TypeError (
306+ f"Unexpected type { o } while expecting "
307+ "`WeakBooleanType`, `WeakIntegralType`,"
308+ "`WeakFloatingType`, or `WeakComplexType`."
309+ )
310+
311+
312+ def _strong_dtype_num_kind (o ):
313+ _map = {"b" : 0 , "i" : 1 , "u" : 1 , "f" : 2 , "c" : 3 }
314+ if not isinstance (o , dpt .dtype ):
315+ raise TypeError
316+ k = o .kind
317+ if k in _map :
318+ return _map [k ]
319+ raise ValueError (f"Unrecognized kind { k } for dtype { o } " )
320+
321+
322+ def _resolve_weak_types (o1_dtype , o2_dtype , dev ):
323+ "Resolves weak data type per NEP-0050"
324+ if isinstance (
325+ o1_dtype ,
326+ (WeakBooleanType , WeakIntegralType , WeakFloatingType , WeakComplexType ),
327+ ):
328+ if isinstance (
329+ o2_dtype ,
330+ (
331+ WeakBooleanType ,
332+ WeakIntegralType ,
333+ WeakFloatingType ,
334+ WeakComplexType ,
335+ ),
336+ ):
337+ raise ValueError
338+ o1_kind_num = _weak_type_num_kind (o1_dtype )
339+ o2_kind_num = _strong_dtype_num_kind (o2_dtype )
340+ if o1_kind_num > o2_kind_num :
341+ if isinstance (o1_dtype , WeakIntegralType ):
342+ return dpt .dtype (ti .default_device_int_type (dev )), o2_dtype
343+ if isinstance (o1_dtype , WeakComplexType ):
344+ if o2_dtype is dpt .float16 or o2_dtype is dpt .float32 :
345+ return dpt .complex64 , o2_dtype
346+ return (
347+ _to_device_supported_dtype (dpt .complex128 , dev ),
348+ o2_dtype ,
349+ )
350+ return _to_device_supported_dtype (dpt .float64 , dev ), o2_dtype
351+ else :
352+ return o2_dtype , o2_dtype
353+ elif isinstance (
354+ o2_dtype ,
355+ (WeakBooleanType , WeakIntegralType , WeakFloatingType , WeakComplexType ),
356+ ):
357+ o1_kind_num = _strong_dtype_num_kind (o1_dtype )
358+ o2_kind_num = _weak_type_num_kind (o2_dtype )
359+ if o2_kind_num > o1_kind_num :
360+ if isinstance (o2_dtype , WeakIntegralType ):
361+ return o1_dtype , dpt .dtype (ti .default_device_int_type (dev ))
362+ if isinstance (o2_dtype , WeakComplexType ):
363+ if o1_dtype is dpt .float16 or o1_dtype is dpt .float32 :
364+ return o1_dtype , dpt .complex64
365+ return o1_dtype , _to_device_supported_dtype (dpt .complex128 , dev )
366+ return (
367+ o1_dtype ,
368+ _to_device_supported_dtype (dpt .float64 , dev ),
369+ )
370+ else :
371+ return o1_dtype , o1_dtype
372+ else :
373+ return o1_dtype , o2_dtype
374+
375+
255376class finfo_object :
256377 """
257378 `numpy.finfo` subclass which returns Python floating-point scalars for
@@ -407,17 +528,27 @@ def result_type(*arrays_and_dtypes):
407528 """
408529 dtypes = []
409530 devices = []
531+ weak_dtypes = []
410532 for arg_i in arrays_and_dtypes :
411533 if isinstance (arg_i , dpt .usm_ndarray ):
412534 devices .append (arg_i .sycl_device )
413535 dtypes .append (arg_i .dtype )
536+ elif isinstance (arg_i , int ):
537+ weak_dtypes .append (WeakIntegralType (arg_i ))
538+ elif isinstance (arg_i , float ):
539+ weak_dtypes .append (WeakFloatingType (arg_i ))
540+ elif isinstance (arg_i , complex ):
541+ weak_dtypes .append (WeakComplexType (arg_i ))
542+ elif isinstance (arg_i , bool ):
543+ weak_dtypes .append (WeakBooleanType (arg_i ))
414544 else :
415545 dt = dpt .dtype (arg_i )
416546 _supported_dtype ([dt ])
417547 dtypes .append (dt )
418548
419549 has_fp16 = True
420550 has_fp64 = True
551+ target_dev = None
421552 if devices :
422553 inspected = False
423554 for d in devices :
@@ -435,17 +566,28 @@ def result_type(*arrays_and_dtypes):
435566 else :
436567 has_fp16 = d .has_aspect_fp16
437568 has_fp64 = d .has_aspect_fp64
569+ target_dev = d
438570 inspected = True
439571
440572 if not (has_fp16 and has_fp64 ):
441573 for dt in dtypes :
442574 if not _dtype_supported_by_device_impl (dt , has_fp16 , has_fp64 ):
443- raise ValueError (f"Argument { dt } is not supported by " )
575+ raise ValueError (
576+ f"Argument { dt } is not supported by the device"
577+ )
444578 res_dt = np .result_type (* dtypes )
445579 res_dt = _to_device_supported_dtype_impl (res_dt , has_fp16 , has_fp64 )
446- return res_dt
580+ for wdt in weak_dtypes :
581+ pair = _resolve_weak_types (wdt , res_dt , target_dev )
582+ res_dt = np .result_type (* pair )
583+ res_dt = _to_device_supported_dtype_impl (res_dt , has_fp16 , has_fp64 )
584+ else :
585+ res_dt = np .result_type (* dtypes )
586+ if weak_dtypes :
587+ weak_dt_obj = [wdt .get () for wdt in weak_dtypes ]
588+ res_dt = np .result_type (res_dt , * weak_dt_obj )
447589
448- return np . result_type ( * dtypes )
590+ return res_dt
449591
450592
451593def iinfo (dtype ):
@@ -528,8 +670,15 @@ def _supported_dtype(dtypes):
528670 "_acceptance_fn_reciprocal" ,
529671 "_acceptance_fn_default_binary" ,
530672 "_acceptance_fn_divide" ,
673+ "_resolve_weak_types" ,
674+ "_weak_type_num_kind" ,
675+ "_strong_dtype_num_kind" ,
531676 "can_cast" ,
532677 "finfo" ,
533678 "iinfo" ,
534679 "result_type" ,
680+ "WeakBooleanType" ,
681+ "WeakIntegralType" ,
682+ "WeakFloatingType" ,
683+ "WeakComplexType" ,
535684]
0 commit comments