Skip to content

Commit d74537c

Browse files
committed
Tweak prior in test_deterministic_of_observed
This test was failing due to the logp assertions introduced by aeppl (see aesara-devs/aeppl#84), which combined with the strongly informative prior away from the likelihood led to underflow of sigma to 0, when using float32, and a subsequent AssertionError. Since this test is only concerned with the accuracy of the deterministic I have simply made the prior less strict.
1 parent e9a06e6 commit d74537c

File tree

1 file changed

+2
-10
lines changed

1 file changed

+2
-10
lines changed

pymc/tests/test_sampling.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import sys
1615
import unittest.mock as mock
1716

1817
from contextlib import ExitStack as does_not_raise
@@ -39,9 +38,6 @@
3938
from pymc.tests.helpers import SeededTest
4039
from pymc.tests.models import simple_init
4140

42-
IS_LINUX = sys.platform == "linux"
43-
IS_FLOAT32 = aesara.config.floatX == "float32"
44-
4541

4642
class TestInitNuts(SeededTest):
4743
def setup_method(self):
@@ -705,20 +701,16 @@ def test_model_shared_variable(self):
705701
assert post_pred["obs"].shape == (samples, 3)
706702
npt.assert_allclose(post_pred["p"], expected_p)
707703

708-
@pytest.mark.xfail(
709-
condition=IS_FLOAT32 and IS_LINUX,
710-
reason="Test fails on linux float32 systems. See https://github.com/pymc-devs/pymc/issues/5088",
711-
)
712704
def test_deterministic_of_observed(self):
713705
rng = np.random.RandomState(8442)
714706

715707
meas_in_1 = pm.aesaraf.floatX(2 + 4 * rng.randn(10))
716708
meas_in_2 = pm.aesaraf.floatX(5 + 4 * rng.randn(10))
717709
nchains = 2
718710
with pm.Model(rng_seeder=rng) as model:
719-
mu_in_1 = pm.Normal("mu_in_1", 0, 1)
711+
mu_in_1 = pm.Normal("mu_in_1", 0, 2)
720712
sigma_in_1 = pm.HalfNormal("sd_in_1", 1)
721-
mu_in_2 = pm.Normal("mu_in_2", 0, 1)
713+
mu_in_2 = pm.Normal("mu_in_2", 0, 2)
722714
sigma_in_2 = pm.HalfNormal("sd__in_2", 1)
723715

724716
in_1 = pm.Normal("in_1", mu_in_1, sigma_in_1, observed=meas_in_1)

0 commit comments

Comments
 (0)