|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
| 14 | +import aesara.tensor as at |
14 | 15 | import numpy as np
|
15 | 16 | import pytest
|
16 | 17 |
|
@@ -108,3 +109,33 @@ def test_basic(self):
|
108 | 109 | assert np.all(get_moment(rv).eval() == np.zeros((2, 4)))
|
109 | 110 | rv = pm.HalfFlat.dist(size=(2, 4))
|
110 | 111 | assert np.all(get_moment(rv).eval() == np.ones((2, 4)))
|
| 112 | + |
| 113 | + @pytest.mark.xfail(reason="Test values are still used for initvals.") |
| 114 | + @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) |
| 115 | + def test_numeric_moment_shape(self, rv_cls): |
| 116 | + rv = rv_cls.dist(shape=(2,)) |
| 117 | + assert not hasattr(rv.tag, "test_value") |
| 118 | + assert tuple(get_moment(rv).shape.eval()) == (2,) |
| 119 | + |
| 120 | + @pytest.mark.xfail(reason="Test values are still used for initvals.") |
| 121 | + @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) |
| 122 | + def test_symbolic_moment_shape(self, rv_cls): |
| 123 | + s = at.scalar() |
| 124 | + rv = rv_cls.dist(shape=(s,)) |
| 125 | + assert not hasattr(rv.tag, "test_value") |
| 126 | + assert tuple(get_moment(rv).shape.eval({s: 4})) == (4,) |
| 127 | + pass |
| 128 | + |
| 129 | + @pytest.mark.xfail(reason="Test values are still used for initvals.") |
| 130 | + @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) |
| 131 | + def test_moment_from_dims(self, rv_cls): |
| 132 | + with pm.Model( |
| 133 | + coords={ |
| 134 | + "year": [2019, 2020, 2021, 2022], |
| 135 | + "city": ["Bonn", "Paris", "Lisbon"], |
| 136 | + } |
| 137 | + ): |
| 138 | + rv = rv_cls("rv", dims=("year", "city")) |
| 139 | + assert not hasattr(rv.tag, "test_value") |
| 140 | + assert tuple(get_moment(rv).shape.eval()) == (4, 3) |
| 141 | + pass |
0 commit comments