|
42 | 42 | from pymc.blocking import DictToArrayBijection, RaveledVars
|
43 | 43 | from pymc.distributions import Normal, transforms
|
44 | 44 | from pymc.distributions.distribution import PartialObservedRV
|
45 |
| -from pymc.distributions.transforms import log, simplex |
| 45 | +from pymc.distributions.transforms import Transform, log, simplex |
46 | 46 | from pymc.exceptions import ImputationWarning, ShapeError, ShapeWarning
|
47 | 47 | from pymc.logprob.basic import transformed_conditional_logp
|
48 | 48 | from pymc.logprob.transforms import IntervalTransform
|
@@ -527,6 +527,42 @@ def test_model_var_maps():
|
527 | 527 | assert model.rvs_to_transforms[x] is None
|
528 | 528 |
|
529 | 529 |
|
| 530 | +class TestTransformArgs: |
| 531 | + def test_transform_warning(self): |
| 532 | + with pm.Model(): |
| 533 | + with pytest.warns( |
| 534 | + UserWarning, |
| 535 | + match="To disable default transform," |
| 536 | + " please use default_transform=None" |
| 537 | + " instead of transform=None. Setting transform to" |
| 538 | + " None will not have any effect in future.", |
| 539 | + ): |
| 540 | + a = pm.Normal("a", transform=None) |
| 541 | + |
| 542 | + def test_transform_order(self): |
| 543 | + transform_order = [] |
| 544 | + |
| 545 | + class DummyTransform(Transform): |
| 546 | + name = "dummy1" |
| 547 | + ndim_supp = 0 |
| 548 | + |
| 549 | + def __init__(self, marker) -> None: |
| 550 | + super().__init__() |
| 551 | + self.marker = marker |
| 552 | + |
| 553 | + def forward(self, value, *inputs): |
| 554 | + nonlocal transform_order |
| 555 | + transform_order.append(self.marker) |
| 556 | + return value |
| 557 | + |
| 558 | + def backward(self, value, *inputs): |
| 559 | + return value |
| 560 | + |
| 561 | + with pm.Model() as model: |
| 562 | + x = pm.Normal("x", transform=DummyTransform(2), default_transform=DummyTransform(1)) |
| 563 | + assert transform_order == [1, 2] |
| 564 | + |
| 565 | + |
530 | 566 | def test_make_obs_var():
|
531 | 567 | """
|
532 | 568 | Check returned values for `data` given known inputs to `as_tensor()`.
|
|
0 commit comments