Skip to content

Commit eb1d63a

Browse files
authored
Use numpy testing utilities instead of custom close_to* (#6961)
* Use np.testing asserts in test_transform.py * Fix test_simplex_accuracy - The previous identity_f function returned an empty array. This was missed by our check but captured by numpy. * Replace custom close_to function in tests/helpers.py * Replace custom close_to function in test_dist_math.py * Replace custom close_to function in tuning/test_starting.py * Remove reduntant file tests/checks.py * Use absolute tolerance in check_transform
1 parent 03f9f72 commit eb1d63a

File tree

6 files changed

+49
-75
lines changed

6 files changed

+49
-75
lines changed

tests/checks.py

-27
This file was deleted.

tests/distributions/test_dist_math.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import numpy as np
15-
import numpy.testing as npt
1615
import pytensor
1716
import pytensor.tensor as pt
1817
import pytest
@@ -36,7 +35,6 @@
3635
)
3736
from pymc.logprob.utils import ParameterValueError
3837
from pymc.pytensorf import floatX
39-
from tests.checks import close_to
4038
from tests.helpers import verify_grad
4139

4240

@@ -160,7 +158,7 @@ def test_clipped_beta_rvs(dtype):
160158

161159
def check_vals(fn1, fn2, *args):
162160
v = fn1(*args)
163-
close_to(v, fn2(*args), 1e-6 if v.dtype == np.float64 else 1e-4)
161+
np.testing.assert_allclose(v, fn2(*args), atol=1e-6 if v.dtype == np.float64 else 1e-4)
164162

165163

166164
def test_multigamma():

tests/distributions/test_transform.py

+22-22
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import pytensor.tensor as pt
2121
import pytest
2222

23+
from numpy.testing import assert_allclose, assert_array_equal
2324
from pytensor.tensor.variable import TensorConstant
2425

2526
import pymc as pm
@@ -40,7 +41,6 @@
4041
UnitSortedVector,
4142
Vector,
4243
)
43-
from tests.checks import close_to, close_to_logical
4444

4545
# some transforms (stick breaking) require addition of small slack in order to be numerically
4646
# stable. The minimal addable slack for float32 is higher thus we need to be less strict
@@ -61,7 +61,7 @@ def check_transform(transform, domain, constructor=pt.scalar, test=0, rv_var=Non
6161
assert z.type == x.type
6262
identity_f = pytensor.function([x], z, *rv_inputs)
6363
for val in domain.vals:
64-
close_to(val, identity_f(val), tol)
64+
assert_allclose(val, identity_f(val), atol=tol)
6565

6666

6767
def check_vector_transform(transform, domain, rv_var=None):
@@ -117,7 +117,7 @@ def check_jacobian_det(
117117
)
118118

119119
for yval in domain.vals:
120-
np.testing.assert_allclose(actual_ljd(yval), computed_ljd(yval), rtol=tol)
120+
assert_allclose(actual_ljd(yval), computed_ljd(yval), rtol=tol)
121121

122122

123123
def test_simplex():
@@ -132,9 +132,9 @@ def test_simplex():
132132
def test_simplex_bounds():
133133
vals = get_values(tr.simplex, Vector(R, 2), pt.vector, floatX(np.array([0, 0])))
134134

135-
close_to(vals.sum(axis=1), 1, tol)
136-
close_to_logical(vals > 0, True, tol)
137-
close_to_logical(vals < 1, True, tol)
135+
assert_allclose(vals.sum(axis=1), 1, tol)
136+
assert_array_equal(vals > 0, True)
137+
assert_array_equal(vals < 1, True)
138138

139139
check_jacobian_det(
140140
tr.simplex, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), lambda x: x[:-1]
@@ -145,8 +145,8 @@ def test_simplex_accuracy():
145145
val = floatX(np.array([-30]))
146146
x = pt.vector("x")
147147
x.tag.test_value = val
148-
identity_f = pytensor.function([x], tr.simplex.forward(x, tr.simplex.backward(x, x)))
149-
close_to(val, identity_f(val), tol)
148+
identity_f = pytensor.function([x], tr.simplex.forward(tr.simplex.backward(x)))
149+
assert_allclose(val, identity_f(val), tol)
150150

151151

152152
def test_sum_to_1():
@@ -179,7 +179,7 @@ def test_log():
179179
check_jacobian_det(tr.log, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True)
180180

181181
vals = get_values(tr.log)
182-
close_to_logical(vals > 0, True, tol)
182+
assert_array_equal(vals > 0, True)
183183

184184

185185
@pytest.mark.skipif(
@@ -192,7 +192,7 @@ def test_log_exp_m1():
192192
check_jacobian_det(tr.log_exp_m1, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True)
193193

194194
vals = get_values(tr.log_exp_m1)
195-
close_to_logical(vals > 0, True, tol)
195+
assert_array_equal(vals > 0, True)
196196

197197

198198
def test_logodds():
@@ -202,8 +202,8 @@ def test_logodds():
202202
check_jacobian_det(tr.logodds, Vector(Unit, 2), pt.vector, [0.5, 0.5], elemwise=True)
203203

204204
vals = get_values(tr.logodds)
205-
close_to_logical(vals > 0, True, tol)
206-
close_to_logical(vals < 1, True, tol)
205+
assert_array_equal(vals > 0, True)
206+
assert_array_equal(vals < 1, True)
207207

208208

209209
def test_lowerbound():
@@ -214,7 +214,7 @@ def test_lowerbound():
214214
check_jacobian_det(trans, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True)
215215

216216
vals = get_values(trans)
217-
close_to_logical(vals > 0, True, tol)
217+
assert_array_equal(vals > 0, True)
218218

219219

220220
def test_upperbound():
@@ -225,7 +225,7 @@ def test_upperbound():
225225
check_jacobian_det(trans, Vector(Rminusbig, 2), pt.vector, [-1, -1], elemwise=True)
226226

227227
vals = get_values(trans)
228-
close_to_logical(vals < 0, True, tol)
228+
assert_array_equal(vals < 0, True)
229229

230230

231231
def test_interval():
@@ -238,8 +238,8 @@ def test_interval():
238238
check_jacobian_det(trans, domain, elemwise=True)
239239

240240
vals = get_values(trans)
241-
close_to_logical(vals > a, True, tol)
242-
close_to_logical(vals < b, True, tol)
241+
assert_array_equal(vals > a, True)
242+
assert_array_equal(vals < b, True)
243243

244244

245245
@pytest.mark.skipif(
@@ -254,7 +254,7 @@ def test_interval_near_boundary():
254254
pm.Uniform("x", initval=x0, lower=lb, upper=ub)
255255

256256
log_prob = model.point_logps()
257-
np.testing.assert_allclose(list(log_prob.values()), floatX(np.array([-52.68])))
257+
assert_allclose(list(log_prob.values()), floatX(np.array([-52.68])))
258258

259259

260260
def test_circular():
@@ -264,8 +264,8 @@ def test_circular():
264264
check_jacobian_det(trans, Circ)
265265

266266
vals = get_values(trans)
267-
close_to_logical(vals > -np.pi, True, tol)
268-
close_to_logical(vals < np.pi, True, tol)
267+
assert_array_equal(vals > -np.pi, True)
268+
assert_array_equal(vals < np.pi, True)
269269

270270
assert isinstance(trans.forward(1, None), TensorConstant)
271271

@@ -281,13 +281,13 @@ def test_ordered():
281281
)
282282

283283
vals = get_values(tr.ordered, Vector(R, 3), pt.vector, floatX(np.zeros(3)))
284-
close_to_logical(np.diff(vals) >= 0, True, tol)
284+
assert_array_equal(np.diff(vals) >= 0, True)
285285

286286

287287
def test_chain_values():
288288
chain_tranf = tr.Chain([tr.logodds, tr.ordered])
289289
vals = get_values(chain_tranf, Vector(R, 5), pt.vector, floatX(np.zeros(5)))
290-
close_to_logical(np.diff(vals) >= 0, True, tol)
290+
assert_array_equal(np.diff(vals) >= 0, True)
291291

292292

293293
def test_chain_vector_transform():
@@ -339,7 +339,7 @@ def check_transform_elementwise_logp(self, model, vector_transform=False):
339339
untransform_logp_eval = untransform_logp.eval({x_val_untransf: test_array_untransf})
340340
log_jac_det_eval = log_jac_det.eval({x_val_transf: test_array_transf})
341341
# Summing the log_jac_det separately from the untransform_logp ensures there is no broadcasting between terms
342-
np.testing.assert_allclose(
342+
assert_allclose(
343343
transform_logp_eval.sum(),
344344
untransform_logp_eval.sum() + log_jac_det_eval.sum(),
345345
rtol=tol,

tests/helpers.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@
2323
import numpy.random as nr
2424
import pytensor
2525

26+
from numpy.testing import assert_array_less
2627
from pytensor.gradient import verify_grad as at_verify_grad
2728

2829
import pymc as pm
2930

3031
from pymc.testing import fast_unstable_sampling_mode
31-
from tests.checks import close_to
3232
from tests.models import mv_simple, mv_simple_coarse
3333

3434

@@ -118,11 +118,11 @@ def setup_class(self):
118118
def teardown_class(self):
119119
shutil.rmtree(self.temp_dir)
120120

121-
def check_stat(self, check, idata, name):
121+
def check_stat(self, check, idata):
122122
group = idata.posterior
123123
for var, stat, value, bound in check:
124124
s = stat(group[var].sel(chain=0), axis=0)
125-
close_to(s, value, bound, name)
125+
assert_array_less(np.abs(s.values - value), bound)
126126

127127
def check_stat_dtype(self, step, idata):
128128
# TODO: This check does not confirm the announced dtypes are correct as the
@@ -156,7 +156,7 @@ def step_continuous(self, step_fn, draws, chains=1, tune=1000):
156156
assert idata.warmup_posterior.sizes["draw"] == tune
157157
assert idata.posterior.sizes["chain"] == chains
158158
assert idata.posterior.sizes["draw"] == draws
159-
self.check_stat(check, idata, step.__class__.__name__)
159+
self.check_stat(check, idata)
160160
self.check_stat_dtype(idata, step)
161161

162162

tests/step_methods/test_metropolis.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def test_step_discrete(self):
302302
model=model,
303303
random_seed=1,
304304
)
305-
self.check_stat(check, idata, step.__class__.__name__)
305+
self.check_stat(check, idata)
306306
self.check_stat_dtype(idata, step)
307307

308308
@pytest.mark.parametrize("proposal", ["uniform", "proportional"])
@@ -321,7 +321,7 @@ def test_step_categorical(self, proposal):
321321
model=model,
322322
random_seed=1,
323323
)
324-
self.check_stat(check, idata, step.__class__.__name__)
324+
self.check_stat(check, idata)
325325
self.check_stat_dtype(idata, step)
326326

327327
@pytest.mark.parametrize(

tests/tuning/test_starting.py

+20-17
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@
1616
import numpy as np
1717
import pytest
1818

19+
from numpy.testing import assert_allclose
20+
1921
import pymc as pm
2022

2123
from pymc.exceptions import ImputationWarning
2224
from pymc.step_methods.metropolis import tune
2325
from pymc.testing import select_by_precision
2426
from pymc.tuning import find_MAP
2527
from tests import models
26-
from tests.checks import close_to
2728
from tests.models import non_normal, simple_arbitrary_det, simple_model
2829

2930

@@ -36,7 +37,7 @@ def test_mle_jacobian(bounded):
3637
start, model, _ = models.simple_normal(bounded_prior=bounded)
3738
with model:
3839
map_estimate = find_MAP(method="BFGS", model=model)
39-
np.testing.assert_allclose(map_estimate["mu_i"], truth, rtol=rtol)
40+
assert_allclose(map_estimate["mu_i"], truth, rtol=rtol)
4041

4142

4243
def test_tune_not_inplace():
@@ -50,14 +51,16 @@ def test_accuracy_normal():
5051
_, model, (mu, _) = simple_model()
5152
with model:
5253
newstart = find_MAP(pm.Point(x=[-10.5, 100.5]))
53-
close_to(newstart["x"], [mu, mu], select_by_precision(float64=1e-5, float32=1e-4))
54+
assert_allclose(
55+
newstart["x"], [mu, mu], atol=select_by_precision(float64=1e-5, float32=1e-4)
56+
)
5457

5558

5659
def test_accuracy_non_normal():
5760
_, model, (mu, _) = non_normal(4)
5861
with model:
5962
newstart = find_MAP(pm.Point(x=[0.5, 0.01, 0.95, 0.99]))
60-
close_to(newstart["x"], mu, select_by_precision(float64=1e-5, float32=1e-4))
63+
assert_allclose(newstart["x"], mu, atol=select_by_precision(float64=1e-5, float32=1e-4))
6164

6265

6366
def test_find_MAP_discrete():
@@ -76,9 +79,9 @@ def test_find_MAP_discrete():
7679
map_est1 = find_MAP()
7780
map_est2 = find_MAP(vars=model.value_vars)
7881

79-
close_to(map_est1["p"], 0.6086956533498806, tol1)
82+
assert_allclose(map_est1["p"], 0.6086956533498806, atol=tol1, rtol=0)
8083

81-
close_to(map_est2["p"], 0.695642178810167, tol2)
84+
assert_allclose(map_est2["p"], 0.695642178810167, atol=tol2, rtol=0)
8285
assert map_est2["ss"] == 14
8386

8487

@@ -105,11 +108,11 @@ def test_find_MAP():
105108
# Test non-gradient minimization
106109
map_est2 = find_MAP(progressbar=False, method="Powell")
107110

108-
close_to(map_est1["mu"], 0, tol)
109-
close_to(map_est1["sigma"], 1, tol)
111+
assert_allclose(map_est1["mu"], 0, atol=tol)
112+
assert_allclose(map_est1["sigma"], 1, atol=tol)
110113

111-
close_to(map_est2["mu"], 0, tol)
112-
close_to(map_est2["sigma"], 1, tol)
114+
assert_allclose(map_est2["mu"], 0, atol=tol)
115+
assert_allclose(map_est2["sigma"], 1, atol=tol)
113116

114117

115118
def test_find_MAP_issue_5923():
@@ -131,11 +134,11 @@ def test_find_MAP_issue_5923():
131134
map_est1 = find_MAP(progressbar=False, vars=[mu, sigma], start=start)
132135
map_est2 = find_MAP(progressbar=False, vars=[sigma, mu], start=start)
133136

134-
close_to(map_est1["mu"], 0, tol)
135-
close_to(map_est1["sigma"], 1, tol)
137+
assert_allclose(map_est1["mu"], 0, atol=tol)
138+
assert_allclose(map_est1["sigma"], 1, atol=tol)
136139

137-
close_to(map_est2["mu"], 0, tol)
138-
close_to(map_est2["sigma"], 1, tol)
140+
assert_allclose(map_est2["mu"], 0, atol=tol)
141+
assert_allclose(map_est2["sigma"], 1, atol=tol)
139142

140143

141144
def test_find_MAP_issue_4488():
@@ -147,8 +150,8 @@ def test_find_MAP_issue_4488():
147150
map_estimate = find_MAP()
148151

149152
assert not set.difference({"x_unobserved", "x_unobserved_log__", "y"}, set(map_estimate.keys()))
150-
np.testing.assert_allclose(map_estimate["x_unobserved"], 0.2, rtol=1e-4, atol=1e-4)
151-
np.testing.assert_allclose(map_estimate["y"], [2.0, map_estimate["x_unobserved"][0] + 1])
153+
assert_allclose(map_estimate["x_unobserved"], 0.2, rtol=1e-4, atol=1e-4)
154+
assert_allclose(map_estimate["y"], [2.0, map_estimate["x_unobserved"][0] + 1])
152155

153156

154157
def test_find_MAP_warning_non_free_RVs():
@@ -161,4 +164,4 @@ def test_find_MAP_warning_non_free_RVs():
161164
msg = "Intermediate variables (such as Deterministic or Potential) were passed"
162165
with pytest.warns(UserWarning, match=re.escape(msg)):
163166
r = pm.find_MAP(vars=[det])
164-
np.testing.assert_allclose([r["x"], r["y"], r["det"]], [50, 50, 100])
167+
assert_allclose([r["x"], r["y"], r["det"]], [50, 50, 100])

0 commit comments

Comments
 (0)