@@ -67,7 +67,45 @@ def __init__(self, seed=None, device=None, sycl_queue=None):
67
67
self ._def_float_type = dpnp .float64
68
68
69
69
self ._random_state = MT19937 (self ._seed , self ._sycl_queue )
70
- self ._fallback_random_state = call_origin (numpy .random .RandomState , seed )
70
+ self ._fallback_random_state = call_origin (numpy .random .RandomState , seed , allow_fallback = True )
71
+
72
+
73
+ def _is_finite_scalar (self , x ):
74
+ """
75
+ Test a scalar for finiteness (not infinity and not Not a Number).
76
+
77
+ Parameters
78
+ -----------
79
+ x : input value for test, must be a scalar.
80
+
81
+ Returns
82
+ -------
83
+ True where ``x`` is not positive infinity, negative infinity, or NaN;
84
+ false otherwise.
85
+ """
86
+
87
+ # TODO: replace with dpnp.isfinite() once function is available in DPNP,
88
+ # but for now use direct numpy calls without call_origin() wrapper, since data is a scalar
89
+ return numpy .isfinite (x )
90
+
91
+
92
+ def _is_signbit_scalar (self , x ):
93
+ """
94
+ Test a scalar if sign bit is set for it (less than zero).
95
+
96
+ Parameters
97
+ -----------
98
+ x : input value for test, must be a scalar.
99
+
100
+ Returns
101
+ -------
102
+ True where sign bit is set for ``x`` (that is ``x`` is less than zero);
103
+ false otherwise.
104
+ """
105
+
106
+ # TODO: replace with dpnp.signbit() once function is available in DPNP,
107
+ # but for now use direct numpy calls without call_origin() wrapper, since data is a scalar
108
+ return numpy .signbit (x )
71
109
72
110
73
111
def get_state (self ):
@@ -125,13 +163,14 @@ def normal(self, loc=0.0, scale=1.0, size=None, dtype=None, usm_type="device"):
125
163
else :
126
164
min_double = numpy .finfo ('double' ).min
127
165
max_double = numpy .finfo ('double' ).max
128
- if (loc >= max_double or loc <= min_double ) and dpnp .isfinite (loc ):
166
+
167
+ if (loc >= max_double or loc <= min_double ) and self ._is_finite_scalar (loc ):
129
168
raise OverflowError (f"Range of loc={ loc } exceeds valid bounds" )
130
169
131
- if (scale >= max_double ) and dpnp . isfinite (scale ):
170
+ if (scale >= max_double ) and self . _is_finite_scalar (scale ):
132
171
raise OverflowError (f"Range of scale={ scale } exceeds valid bounds" )
133
- # # scale = -0.0 is cosidered as negative
134
- elif scale < 0 or scale == 0 and numpy . signbit (scale ):
172
+ # scale = -0.0 is cosidered as negative
173
+ elif scale < 0 or scale == 0 and self . _is_signbit_scalar (scale ):
135
174
raise ValueError (f"scale={ scale } , but must be non-negative." )
136
175
137
176
if dtype is None :
@@ -198,7 +237,8 @@ def randint(self, low, high=None, size=None, dtype=int, usm_type="device"):
198
237
Limitations
199
238
-----------
200
239
Parameters ``low`` and ``high`` are supported only as scalar.
201
- Parameter ``dtype`` is supported only as `int`.
240
+ Parameter ``dtype`` is supported only as :obj:`dpnp.int32` or `int`,
241
+ but `int` value is considered to be exactly equivalent to :obj:`dpnp.int32`.
202
242
Otherwise, :obj:`numpy.random.randint(low, high, size, dtype)` samples are drawn.
203
243
204
244
Examples
@@ -230,9 +270,10 @@ def randint(self, low, high=None, size=None, dtype=int, usm_type="device"):
230
270
231
271
min_int = numpy .iinfo ('int32' ).min
232
272
max_int = numpy .iinfo ('int32' ).max
233
- if not dpnp .isfinite (low ) or low > max_int or low < min_int :
273
+
274
+ if not self ._is_finite_scalar (low ) or low > max_int or low < min_int :
234
275
raise OverflowError (f"Range of low={ low } exceeds valid bounds" )
235
- elif not dpnp . isfinite (high ) or high > max_int or high < min_int :
276
+ elif not self . _is_finite_scalar (high ) or high > max_int or high < min_int :
236
277
raise OverflowError (f"Range of high={ high } exceeds valid bounds" )
237
278
238
279
low = int (low )
@@ -400,9 +441,10 @@ def uniform(self, low=0.0, high=1.0, size=None, dtype=None, usm_type="device"):
400
441
else :
401
442
min_double = numpy .finfo ('double' ).min
402
443
max_double = numpy .finfo ('double' ).max
403
- if not dpnp .isfinite (low ) or low >= max_double or low <= min_double :
444
+
445
+ if not self ._is_finite_scalar (low ) or low >= max_double or low <= min_double :
404
446
raise OverflowError (f"Range of low={ low } exceeds valid bounds" )
405
- elif not dpnp . isfinite (high ) or high >= max_double or high <= min_double :
447
+ elif not self . _is_finite_scalar (high ) or high >= max_double or high <= min_double :
406
448
raise OverflowError (f"Range of high={ high } exceeds valid bounds" )
407
449
408
450
if low > high :
0 commit comments