Skip to content

Commit 7d6ecf9

Browse files
committed
add tests for transfor args
1 parent def0f6e commit 7d6ecf9

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

tests/model/test_core.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from pymc.blocking import DictToArrayBijection, RaveledVars
4343
from pymc.distributions import Normal, transforms
4444
from pymc.distributions.distribution import PartialObservedRV
45-
from pymc.distributions.transforms import log, simplex
45+
from pymc.distributions.transforms import Transform, log, simplex
4646
from pymc.exceptions import ImputationWarning, ShapeError, ShapeWarning
4747
from pymc.logprob.basic import transformed_conditional_logp
4848
from pymc.logprob.transforms import IntervalTransform
@@ -527,6 +527,42 @@ def test_model_var_maps():
527527
assert model.rvs_to_transforms[x] is None
528528

529529

530+
class TestTransformArgs:
531+
def test_transform_warning(self):
532+
with pm.Model():
533+
with pytest.warns(
534+
UserWarning,
535+
match="To disable default transform,"
536+
" please use default_transform=None"
537+
" instead of transform=None. Setting transform to"
538+
" None will not have any effect in future.",
539+
):
540+
a = pm.Normal("a", transform=None)
541+
542+
def test_transform_order(self):
543+
transform_order = []
544+
545+
class DummyTransform(Transform):
546+
name = "dummy1"
547+
ndim_supp = 0
548+
549+
def __init__(self, marker) -> None:
550+
super().__init__()
551+
self.marker = marker
552+
553+
def forward(self, value, *inputs):
554+
nonlocal transform_order
555+
transform_order.append(self.marker)
556+
return value
557+
558+
def backward(self, value, *inputs):
559+
return value
560+
561+
with pm.Model() as model:
562+
x = pm.Normal("x", transform=DummyTransform(2), default_transform=DummyTransform(1))
563+
assert transform_order == [1, 2]
564+
565+
530566
def test_make_obs_var():
531567
"""
532568
Check returned values for `data` given known inputs to `as_tensor()`.

0 commit comments

Comments
 (0)