|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
| 14 | +import aesara.tensor as at |
14 | 15 | import numpy as np
|
15 | 16 | import pytest
|
16 | 17 |
|
@@ -92,6 +93,27 @@ def test_automatically_assigned_test_values(self):
|
92 | 93 | assert hasattr(rv.tag, "test_value")
|
93 | 94 | pass
|
94 | 95 |
|
| 96 | + @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) |
| 97 | + def test_initval_shape(self, rv_cls): |
| 98 | + rv = rv_cls.dist(shape=(2,)) |
| 99 | + assert np.shape(rv.tag.test_value) == (2,) |
| 100 | + # Can't set numeric test values when dimensionality is symbolic |
| 101 | + rv = rv_cls.dist(shape=(at.scalar(),)) |
| 102 | + assert rv.tag.test_value is None |
| 103 | + pass |
| 104 | + |
| 105 | + @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) |
| 106 | + def test_initval_dims(self, rv_cls): |
| 107 | + with pm.Model( |
| 108 | + coords={ |
| 109 | + "year": [2019, 2020, 2021], |
| 110 | + } |
| 111 | + ) as pmodel: |
| 112 | + rv = rv_cls("rv", dims=("year",)) |
| 113 | + # Can't set numeric test values when dimensionality is symbolic |
| 114 | + assert rv.tag.test_value is None |
| 115 | + pass |
| 116 | + |
95 | 117 |
|
96 | 118 | class TestMoment:
|
97 | 119 | def test_basic(self):
|
|
0 commit comments