@@ -252,6 +252,127 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
252
252
return None , None , None
253
253
254
254
255
+ class WeakBooleanType :
256
+ "Python type representing type of Python boolean objects"
257
+
258
+ def __init__ (self , o ):
259
+ self .o_ = o
260
+
261
+ def get (self ):
262
+ return self .o_
263
+
264
+
265
+ class WeakIntegralType :
266
+ "Python type representing type of Python integral objects"
267
+
268
+ def __init__ (self , o ):
269
+ self .o_ = o
270
+
271
+ def get (self ):
272
+ return self .o_
273
+
274
+
275
+ class WeakFloatingType :
276
+ """Python type representing type of Python floating point objects"""
277
+
278
+ def __init__ (self , o ):
279
+ self .o_ = o
280
+
281
+ def get (self ):
282
+ return self .o_
283
+
284
+
285
+ class WeakComplexType :
286
+ """Python type representing type of Python complex floating point objects"""
287
+
288
+ def __init__ (self , o ):
289
+ self .o_ = o
290
+
291
+ def get (self ):
292
+ return self .o_
293
+
294
+
295
+ def _weak_type_num_kind (o ):
296
+ _map = {"?" : 0 , "i" : 1 , "f" : 2 , "c" : 3 }
297
+ if isinstance (o , WeakBooleanType ):
298
+ return _map ["?" ]
299
+ if isinstance (o , WeakIntegralType ):
300
+ return _map ["i" ]
301
+ if isinstance (o , WeakFloatingType ):
302
+ return _map ["f" ]
303
+ if isinstance (o , WeakComplexType ):
304
+ return _map ["c" ]
305
+ raise TypeError (
306
+ f"Unexpected type { o } while expecting "
307
+ "`WeakBooleanType`, `WeakIntegralType`,"
308
+ "`WeakFloatingType`, or `WeakComplexType`."
309
+ )
310
+
311
+
312
+ def _strong_dtype_num_kind (o ):
313
+ _map = {"b" : 0 , "i" : 1 , "u" : 1 , "f" : 2 , "c" : 3 }
314
+ if not isinstance (o , dpt .dtype ):
315
+ raise TypeError
316
+ k = o .kind
317
+ if k in _map :
318
+ return _map [k ]
319
+ raise ValueError (f"Unrecognized kind { k } for dtype { o } " )
320
+
321
+
322
+ def _resolve_weak_types (o1_dtype , o2_dtype , dev ):
323
+ "Resolves weak data type per NEP-0050"
324
+ if isinstance (
325
+ o1_dtype ,
326
+ (WeakBooleanType , WeakIntegralType , WeakFloatingType , WeakComplexType ),
327
+ ):
328
+ if isinstance (
329
+ o2_dtype ,
330
+ (
331
+ WeakBooleanType ,
332
+ WeakIntegralType ,
333
+ WeakFloatingType ,
334
+ WeakComplexType ,
335
+ ),
336
+ ):
337
+ raise ValueError
338
+ o1_kind_num = _weak_type_num_kind (o1_dtype )
339
+ o2_kind_num = _strong_dtype_num_kind (o2_dtype )
340
+ if o1_kind_num > o2_kind_num :
341
+ if isinstance (o1_dtype , WeakIntegralType ):
342
+ return dpt .dtype (ti .default_device_int_type (dev )), o2_dtype
343
+ if isinstance (o1_dtype , WeakComplexType ):
344
+ if o2_dtype is dpt .float16 or o2_dtype is dpt .float32 :
345
+ return dpt .complex64 , o2_dtype
346
+ return (
347
+ _to_device_supported_dtype (dpt .complex128 , dev ),
348
+ o2_dtype ,
349
+ )
350
+ return _to_device_supported_dtype (dpt .float64 , dev ), o2_dtype
351
+ else :
352
+ return o2_dtype , o2_dtype
353
+ elif isinstance (
354
+ o2_dtype ,
355
+ (WeakBooleanType , WeakIntegralType , WeakFloatingType , WeakComplexType ),
356
+ ):
357
+ o1_kind_num = _strong_dtype_num_kind (o1_dtype )
358
+ o2_kind_num = _weak_type_num_kind (o2_dtype )
359
+ if o2_kind_num > o1_kind_num :
360
+ if isinstance (o2_dtype , WeakIntegralType ):
361
+ return o1_dtype , dpt .dtype (ti .default_device_int_type (dev ))
362
+ if isinstance (o2_dtype , WeakComplexType ):
363
+ if o1_dtype is dpt .float16 or o1_dtype is dpt .float32 :
364
+ return o1_dtype , dpt .complex64
365
+ return o1_dtype , _to_device_supported_dtype (dpt .complex128 , dev )
366
+ return (
367
+ o1_dtype ,
368
+ _to_device_supported_dtype (dpt .float64 , dev ),
369
+ )
370
+ else :
371
+ return o1_dtype , o1_dtype
372
+ else :
373
+ return o1_dtype , o2_dtype
374
+
375
+
255
376
class finfo_object :
256
377
"""
257
378
`numpy.finfo` subclass which returns Python floating-point scalars for
@@ -407,17 +528,27 @@ def result_type(*arrays_and_dtypes):
407
528
"""
408
529
dtypes = []
409
530
devices = []
531
+ weak_dtypes = []
410
532
for arg_i in arrays_and_dtypes :
411
533
if isinstance (arg_i , dpt .usm_ndarray ):
412
534
devices .append (arg_i .sycl_device )
413
535
dtypes .append (arg_i .dtype )
536
+ elif isinstance (arg_i , int ):
537
+ weak_dtypes .append (WeakIntegralType (arg_i ))
538
+ elif isinstance (arg_i , float ):
539
+ weak_dtypes .append (WeakFloatingType (arg_i ))
540
+ elif isinstance (arg_i , complex ):
541
+ weak_dtypes .append (WeakComplexType (arg_i ))
542
+ elif isinstance (arg_i , bool ):
543
+ weak_dtypes .append (WeakBooleanType (arg_i ))
414
544
else :
415
545
dt = dpt .dtype (arg_i )
416
546
_supported_dtype ([dt ])
417
547
dtypes .append (dt )
418
548
419
549
has_fp16 = True
420
550
has_fp64 = True
551
+ target_dev = None
421
552
if devices :
422
553
inspected = False
423
554
for d in devices :
@@ -435,17 +566,28 @@ def result_type(*arrays_and_dtypes):
435
566
else :
436
567
has_fp16 = d .has_aspect_fp16
437
568
has_fp64 = d .has_aspect_fp64
569
+ target_dev = d
438
570
inspected = True
439
571
440
572
if not (has_fp16 and has_fp64 ):
441
573
for dt in dtypes :
442
574
if not _dtype_supported_by_device_impl (dt , has_fp16 , has_fp64 ):
443
- raise ValueError (f"Argument { dt } is not supported by " )
575
+ raise ValueError (
576
+ f"Argument { dt } is not supported by the device"
577
+ )
444
578
res_dt = np .result_type (* dtypes )
445
579
res_dt = _to_device_supported_dtype_impl (res_dt , has_fp16 , has_fp64 )
446
- return res_dt
580
+ for wdt in weak_dtypes :
581
+ pair = _resolve_weak_types (wdt , res_dt , target_dev )
582
+ res_dt = np .result_type (* pair )
583
+ res_dt = _to_device_supported_dtype_impl (res_dt , has_fp16 , has_fp64 )
584
+ else :
585
+ res_dt = np .result_type (* dtypes )
586
+ if weak_dtypes :
587
+ weak_dt_obj = [wdt .get () for wdt in weak_dtypes ]
588
+ res_dt = np .result_type (res_dt , * weak_dt_obj )
447
589
448
- return np . result_type ( * dtypes )
590
+ return res_dt
449
591
450
592
451
593
def iinfo (dtype ):
@@ -528,8 +670,15 @@ def _supported_dtype(dtypes):
528
670
"_acceptance_fn_reciprocal" ,
529
671
"_acceptance_fn_default_binary" ,
530
672
"_acceptance_fn_divide" ,
673
+ "_resolve_weak_types" ,
674
+ "_weak_type_num_kind" ,
675
+ "_strong_dtype_num_kind" ,
531
676
"can_cast" ,
532
677
"finfo" ,
533
678
"iinfo" ,
534
679
"result_type" ,
680
+ "WeakBooleanType" ,
681
+ "WeakIntegralType" ,
682
+ "WeakFloatingType" ,
683
+ "WeakComplexType" ,
535
684
]
0 commit comments