Skip to content

Commit 30040cd

Browse files
Add regression tests for #4993
1 parent 59ce9af commit 30040cd

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

pymc3/tests/test_initvals.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import aesara.tensor as at
1415
import numpy as np
1516
import pytest
1617

@@ -92,6 +93,27 @@ def test_automatically_assigned_test_values(self):
9293
assert hasattr(rv.tag, "test_value")
9394
pass
9495

96+
@pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat])
97+
def test_initval_shape(self, rv_cls):
98+
rv = rv_cls.dist(shape=(2,))
99+
assert np.shape(rv.tag.test_value) == (2,)
100+
# Can't set numeric test values when dimensionality is symbolic
101+
rv = rv_cls.dist(shape=(at.scalar(),))
102+
assert rv.tag.test_value is None
103+
pass
104+
105+
@pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat])
106+
def test_initval_dims(self, rv_cls):
107+
with pm.Model(
108+
coords={
109+
"year": [2019, 2020, 2021],
110+
}
111+
) as pmodel:
112+
rv = rv_cls("rv", dims=("year",))
113+
# Can't set numeric test values when dimensionality is symbolic
114+
assert rv.tag.test_value is None
115+
pass
116+
95117

96118
class TestMoment:
97119
def test_basic(self):

0 commit comments

Comments
 (0)