20
20
import pytensor .tensor as pt
21
21
import pytest
22
22
23
+ from numpy .testing import assert_allclose , assert_array_equal
23
24
from pytensor .tensor .variable import TensorConstant
24
25
25
26
import pymc as pm
40
41
UnitSortedVector ,
41
42
Vector ,
42
43
)
43
- from tests .checks import close_to , close_to_logical
44
44
45
45
# some transforms (stick breaking) require addition of small slack in order to be numerically
46
46
# stable. The minimal addable slack for float32 is higher thus we need to be less strict
@@ -61,7 +61,7 @@ def check_transform(transform, domain, constructor=pt.scalar, test=0, rv_var=Non
61
61
assert z .type == x .type
62
62
identity_f = pytensor .function ([x ], z , * rv_inputs )
63
63
for val in domain .vals :
64
- close_to (val , identity_f (val ), tol )
64
+ assert_allclose (val , identity_f (val ), atol = tol )
65
65
66
66
67
67
def check_vector_transform (transform , domain , rv_var = None ):
@@ -117,7 +117,7 @@ def check_jacobian_det(
117
117
)
118
118
119
119
for yval in domain .vals :
120
- np . testing . assert_allclose (actual_ljd (yval ), computed_ljd (yval ), rtol = tol )
120
+ assert_allclose (actual_ljd (yval ), computed_ljd (yval ), rtol = tol )
121
121
122
122
123
123
def test_simplex ():
@@ -132,9 +132,9 @@ def test_simplex():
132
132
def test_simplex_bounds ():
133
133
vals = get_values (tr .simplex , Vector (R , 2 ), pt .vector , floatX (np .array ([0 , 0 ])))
134
134
135
- close_to (vals .sum (axis = 1 ), 1 , tol )
136
- close_to_logical (vals > 0 , True , tol )
137
- close_to_logical (vals < 1 , True , tol )
135
+ assert_allclose (vals .sum (axis = 1 ), 1 , tol )
136
+ assert_array_equal (vals > 0 , True )
137
+ assert_array_equal (vals < 1 , True )
138
138
139
139
check_jacobian_det (
140
140
tr .simplex , Vector (R , 2 ), pt .vector , floatX (np .array ([0 , 0 ])), lambda x : x [:- 1 ]
@@ -145,8 +145,8 @@ def test_simplex_accuracy():
145
145
val = floatX (np .array ([- 30 ]))
146
146
x = pt .vector ("x" )
147
147
x .tag .test_value = val
148
- identity_f = pytensor .function ([x ], tr .simplex .forward (x , tr .simplex .backward (x , x )))
149
- close_to (val , identity_f (val ), tol )
148
+ identity_f = pytensor .function ([x ], tr .simplex .forward (tr .simplex .backward (x )))
149
+ assert_allclose (val , identity_f (val ), tol )
150
150
151
151
152
152
def test_sum_to_1 ():
@@ -179,7 +179,7 @@ def test_log():
179
179
check_jacobian_det (tr .log , Vector (Rplusbig , 2 ), pt .vector , [0 , 0 ], elemwise = True )
180
180
181
181
vals = get_values (tr .log )
182
- close_to_logical (vals > 0 , True , tol )
182
+ assert_array_equal (vals > 0 , True )
183
183
184
184
185
185
@pytest .mark .skipif (
@@ -192,7 +192,7 @@ def test_log_exp_m1():
192
192
check_jacobian_det (tr .log_exp_m1 , Vector (Rplusbig , 2 ), pt .vector , [0 , 0 ], elemwise = True )
193
193
194
194
vals = get_values (tr .log_exp_m1 )
195
- close_to_logical (vals > 0 , True , tol )
195
+ assert_array_equal (vals > 0 , True )
196
196
197
197
198
198
def test_logodds ():
@@ -202,8 +202,8 @@ def test_logodds():
202
202
check_jacobian_det (tr .logodds , Vector (Unit , 2 ), pt .vector , [0.5 , 0.5 ], elemwise = True )
203
203
204
204
vals = get_values (tr .logodds )
205
- close_to_logical (vals > 0 , True , tol )
206
- close_to_logical (vals < 1 , True , tol )
205
+ assert_array_equal (vals > 0 , True )
206
+ assert_array_equal (vals < 1 , True )
207
207
208
208
209
209
def test_lowerbound ():
@@ -214,7 +214,7 @@ def test_lowerbound():
214
214
check_jacobian_det (trans , Vector (Rplusbig , 2 ), pt .vector , [0 , 0 ], elemwise = True )
215
215
216
216
vals = get_values (trans )
217
- close_to_logical (vals > 0 , True , tol )
217
+ assert_array_equal (vals > 0 , True )
218
218
219
219
220
220
def test_upperbound ():
@@ -225,7 +225,7 @@ def test_upperbound():
225
225
check_jacobian_det (trans , Vector (Rminusbig , 2 ), pt .vector , [- 1 , - 1 ], elemwise = True )
226
226
227
227
vals = get_values (trans )
228
- close_to_logical (vals < 0 , True , tol )
228
+ assert_array_equal (vals < 0 , True )
229
229
230
230
231
231
def test_interval ():
@@ -238,8 +238,8 @@ def test_interval():
238
238
check_jacobian_det (trans , domain , elemwise = True )
239
239
240
240
vals = get_values (trans )
241
- close_to_logical (vals > a , True , tol )
242
- close_to_logical (vals < b , True , tol )
241
+ assert_array_equal (vals > a , True )
242
+ assert_array_equal (vals < b , True )
243
243
244
244
245
245
@pytest .mark .skipif (
@@ -254,7 +254,7 @@ def test_interval_near_boundary():
254
254
pm .Uniform ("x" , initval = x0 , lower = lb , upper = ub )
255
255
256
256
log_prob = model .point_logps ()
257
- np . testing . assert_allclose (list (log_prob .values ()), floatX (np .array ([- 52.68 ])))
257
+ assert_allclose (list (log_prob .values ()), floatX (np .array ([- 52.68 ])))
258
258
259
259
260
260
def test_circular ():
@@ -264,8 +264,8 @@ def test_circular():
264
264
check_jacobian_det (trans , Circ )
265
265
266
266
vals = get_values (trans )
267
- close_to_logical (vals > - np .pi , True , tol )
268
- close_to_logical (vals < np .pi , True , tol )
267
+ assert_array_equal (vals > - np .pi , True )
268
+ assert_array_equal (vals < np .pi , True )
269
269
270
270
assert isinstance (trans .forward (1 , None ), TensorConstant )
271
271
@@ -281,13 +281,13 @@ def test_ordered():
281
281
)
282
282
283
283
vals = get_values (tr .ordered , Vector (R , 3 ), pt .vector , floatX (np .zeros (3 )))
284
- close_to_logical (np .diff (vals ) >= 0 , True , tol )
284
+ assert_array_equal (np .diff (vals ) >= 0 , True )
285
285
286
286
287
287
def test_chain_values ():
288
288
chain_tranf = tr .Chain ([tr .logodds , tr .ordered ])
289
289
vals = get_values (chain_tranf , Vector (R , 5 ), pt .vector , floatX (np .zeros (5 )))
290
- close_to_logical (np .diff (vals ) >= 0 , True , tol )
290
+ assert_array_equal (np .diff (vals ) >= 0 , True )
291
291
292
292
293
293
def test_chain_vector_transform ():
@@ -339,7 +339,7 @@ def check_transform_elementwise_logp(self, model, vector_transform=False):
339
339
untransform_logp_eval = untransform_logp .eval ({x_val_untransf : test_array_untransf })
340
340
log_jac_det_eval = log_jac_det .eval ({x_val_transf : test_array_transf })
341
341
# Summing the log_jac_det separately from the untransform_logp ensures there is no broadcasting between terms
342
- np . testing . assert_allclose (
342
+ assert_allclose (
343
343
transform_logp_eval .sum (),
344
344
untransform_logp_eval .sum () + log_jac_det_eval .sum (),
345
345
rtol = tol ,
0 commit comments