@@ -594,7 +594,8 @@ def eye(N, M=None, k=0, dtype=None, order='C', **kwargs):
594
594
elif len (kwargs ) != 0 :
595
595
pass
596
596
else :
597
- return dpnp_eye (N , M = M , k = k , dtype = dtype ).get_pyobj ()
597
+ _dtype = dpnp .default_float_type () if dtype is None else dtype
598
+ return dpnp_eye (N , M = M , k = k , dtype = _dtype ).get_pyobj ()
598
599
599
600
return call_origin (numpy .eye , N , M = M , k = k , dtype = dtype , order = order , ** kwargs )
600
601
@@ -859,10 +860,8 @@ def identity(n, dtype=None, *, like=None):
859
860
elif n < 0 :
860
861
pass
861
862
else :
862
- if dtype is None :
863
- sycl_queue = dpnp .get_normalized_queue_device (sycl_queue = None , device = None )
864
- dtype = map_dtype_to_device (dpnp .float64 , sycl_queue .sycl_device )
865
- return dpnp_identity (n , dtype ).get_pyobj ()
863
+ _dtype = dpnp .default_float_type () if dtype is None else dtype
864
+ return dpnp_identity (n , _dtype ).get_pyobj ()
866
865
867
866
return call_origin (numpy .identity , n , dtype = dtype , like = like )
868
867
@@ -1315,10 +1314,8 @@ def tri(N, M=None, k=0, dtype=dpnp.float, **kwargs):
1315
1314
elif not isinstance (k , int ):
1316
1315
pass
1317
1316
else :
1318
- if dtype is dpnp .float :
1319
- sycl_queue = dpnp .get_normalized_queue_device (sycl_queue = None , device = None )
1320
- dtype = map_dtype_to_device (dpnp .float64 , sycl_queue .sycl_device )
1321
- return dpnp_tri (N , M , k , dtype ).get_pyobj ()
1317
+ _dtype = dpnp .default_float_type () if dtype in (dpnp .float , None ) else dtype
1318
+ return dpnp_tri (N , M , k , _dtype ).get_pyobj ()
1322
1319
1323
1320
return call_origin (numpy .tri , N , M , k , dtype , ** kwargs )
1324
1321
0 commit comments