|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 |
| -import sys |
16 | 15 | import unittest.mock as mock
|
17 | 16 |
|
18 | 17 | from contextlib import ExitStack as does_not_raise
|
|
39 | 38 | from pymc.tests.helpers import SeededTest
|
40 | 39 | from pymc.tests.models import simple_init
|
41 | 40 |
|
42 |
| -IS_LINUX = sys.platform == "linux" |
43 |
| -IS_FLOAT32 = aesara.config.floatX == "float32" |
44 |
| - |
45 | 41 |
|
46 | 42 | class TestInitNuts(SeededTest):
|
47 | 43 | def setup_method(self):
|
@@ -705,20 +701,16 @@ def test_model_shared_variable(self):
|
705 | 701 | assert post_pred["obs"].shape == (samples, 3)
|
706 | 702 | npt.assert_allclose(post_pred["p"], expected_p)
|
707 | 703 |
|
708 |
| - @pytest.mark.xfail( |
709 |
| - condition=IS_FLOAT32 and IS_LINUX, |
710 |
| - reason="Test fails on linux float32 systems. See https://github.com/pymc-devs/pymc/issues/5088", |
711 |
| - ) |
712 | 704 | def test_deterministic_of_observed(self):
|
713 | 705 | rng = np.random.RandomState(8442)
|
714 | 706 |
|
715 | 707 | meas_in_1 = pm.aesaraf.floatX(2 + 4 * rng.randn(10))
|
716 | 708 | meas_in_2 = pm.aesaraf.floatX(5 + 4 * rng.randn(10))
|
717 | 709 | nchains = 2
|
718 | 710 | with pm.Model(rng_seeder=rng) as model:
|
719 |
| - mu_in_1 = pm.Normal("mu_in_1", 0, 1) |
| 711 | + mu_in_1 = pm.Normal("mu_in_1", 0, 2) |
720 | 712 | sigma_in_1 = pm.HalfNormal("sd_in_1", 1)
|
721 |
| - mu_in_2 = pm.Normal("mu_in_2", 0, 1) |
| 713 | + mu_in_2 = pm.Normal("mu_in_2", 0, 2) |
722 | 714 | sigma_in_2 = pm.HalfNormal("sd__in_2", 1)
|
723 | 715 |
|
724 | 716 | in_1 = pm.Normal("in_1", mu_in_1, sigma_in_1, observed=meas_in_1)
|
|
0 commit comments