15
15
)
16
16
17
17
18
+ def assert_cfd (data , exp_sycl_queue , exp_usm_type = None ):
19
+ assert exp_sycl_queue == data .sycl_queue
20
+ if exp_usm_type :
21
+ assert exp_usm_type == data .usm_type
22
+
23
+
18
24
class TestNormal :
19
25
@pytest .mark .parametrize ("dtype" ,
20
26
[dpnp .float32 , dpnp .float64 , None ],
@@ -47,7 +53,7 @@ def test_distr(self, dtype, usm_type):
47
53
assert_array_almost_equal (dpnp .asnumpy (data ), desired , decimal = precision )
48
54
49
55
# check if compute follows data isn't broken
50
- assert sycl_queue == data . sycl_queue
56
+ assert_cfd ( data , sycl_queue , usm_type )
51
57
52
58
53
59
@pytest .mark .parametrize ("dtype" ,
@@ -138,7 +144,7 @@ def test_fallback(self, loc, scale):
138
144
assert_array_almost_equal (actual , desired , decimal = precision )
139
145
140
146
# check if compute follows data isn't broken
141
- assert sycl_queue == data . sycl_queue
147
+ assert_cfd ( data , sycl_queue )
142
148
143
149
144
150
@pytest .mark .parametrize ("dtype" ,
@@ -174,17 +180,17 @@ def test_distr(self, usm_type):
174
180
175
181
precision = numpy .finfo (dtype = numpy .float64 ).precision
176
182
assert_array_almost_equal (dpnp .asnumpy (data ), desired , decimal = precision )
177
- assert sycl_queue == data . sycl_queue
183
+ assert_cfd ( data , sycl_queue , usm_type )
178
184
179
185
# call with the same seed has to draw the same values
180
186
data = RandomState (seed , sycl_queue = sycl_queue ).rand (3 , 2 , usm_type = usm_type )
181
187
assert_array_almost_equal (dpnp .asnumpy (data ), desired , decimal = precision )
182
- assert sycl_queue == data . sycl_queue
188
+ assert_cfd ( data , sycl_queue , usm_type )
183
189
184
190
# call with omitted dimensions has to draw the first element from desired
185
191
data = RandomState (seed , sycl_queue = sycl_queue ).rand (usm_type = usm_type )
186
192
assert_array_almost_equal (dpnp .asnumpy (data ), desired [0 , 0 ], decimal = precision )
187
- assert sycl_queue == data . sycl_queue
193
+ assert_cfd ( data , sycl_queue , usm_type )
188
194
189
195
# rand() is an alias on random_sample(), map arguments
190
196
with mock .patch ('dpnp.random.RandomState.random_sample' ) as m :
@@ -245,7 +251,7 @@ def test_distr(self, dtype, usm_type):
245
251
[5 , 3 ],
246
252
[5 , 7 ]], dtype = numpy .int32 )
247
253
assert_array_equal (dpnp .asnumpy (data ), desired )
248
- assert sycl_queue == data . sycl_queue
254
+ assert_cfd ( data , sycl_queue , usm_type )
249
255
250
256
# call with the same seed has to draw the same values
251
257
data = RandomState (seed , sycl_queue = sycl_queue ).randint (low = low ,
@@ -254,15 +260,15 @@ def test_distr(self, dtype, usm_type):
254
260
dtype = dtype ,
255
261
usm_type = usm_type )
256
262
assert_array_equal (dpnp .asnumpy (data ), desired )
257
- assert sycl_queue == data . sycl_queue
263
+ assert_cfd ( data , sycl_queue , usm_type )
258
264
259
265
# call with omitted dimensions has to draw the first element from desired
260
266
data = RandomState (seed , sycl_queue = sycl_queue ).randint (low = low ,
261
267
high = high ,
262
268
dtype = dtype ,
263
269
usm_type = usm_type )
264
270
assert_array_equal (dpnp .asnumpy (data ), desired [0 , 0 ])
265
- assert sycl_queue == data . sycl_queue
271
+ assert_cfd ( data , sycl_queue , usm_type )
266
272
267
273
# rand() is an alias on random_sample(), map arguments
268
274
with mock .patch ('dpnp.random.RandomState.uniform' ) as m :
@@ -701,7 +707,7 @@ def test_distr(self, bounds, dtype, usm_type):
701
707
assert_array_equal (dpnp .asnumpy (data ), desired )
702
708
703
709
# check if compute follows data isn't broken
704
- assert sycl_queue == data . sycl_queue
710
+ assert_cfd ( data , sycl_queue , usm_type )
705
711
706
712
707
713
@pytest .mark .parametrize ("dtype" ,
@@ -766,7 +772,7 @@ def test_fallback(self, low, high):
766
772
assert_array_almost_equal (actual , desired , decimal = precision )
767
773
768
774
# check if compute follows data isn't broken
769
- assert sycl_queue == data . sycl_queue
775
+ assert_cfd ( data , sycl_queue )
770
776
771
777
772
778
@pytest .mark .parametrize ("dtype" ,
0 commit comments