diff --git a/tests/variational/test_inference.py b/tests/variational/test_inference.py index f4561f6a60..6f04819ba1 100644 --- a/tests/variational/test_inference.py +++ b/tests/variational/test_inference.py @@ -14,6 +14,7 @@ import io import operator +import warnings from contextlib import nullcontext @@ -196,18 +197,26 @@ def test_fit_start(inference_spec, simple_model): # Minibatch data can't be extracted into the `observed_data` group in the final InferenceData [observed_value] = [simple_model.rvs_to_values[obs] for obs in simple_model.observed_RVs] - if observed_value.name.startswith("minibatch"): - warn_ctxt = pytest.warns( - UserWarning, match="Could not extract data from symbolic observation" - ) - else: - warn_ctxt = nullcontext() - try: - with warn_ctxt: + # We can`t use pytest.warns here because after version 8.0 it`s still check for warning when + # exception raised and test failed instead being skipped + warning_raised = False + expected_warning = observed_value.name.startswith("minibatch") + with warnings.catch_warnings(record=True) as record: + warnings.simplefilter("always") + try: trace = inference.fit(n=0).sample(10000) - except NotImplementedInference as e: - pytest.skip(str(e)) + except NotImplementedInference as e: + pytest.skip(str(e)) + + if expected_warning: + assert len(record) > 0 + for item in record: + assert issubclass(item.category, UserWarning) + assert "Could not extract data from symbolic observation" in str(item.message) + if not expected_warning: + assert not record + np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_init, rtol=0.05) if has_start_sigma: np.testing.assert_allclose(np.std(trace.posterior["mu"]), mu_sigma_init, rtol=0.05)