Skip to content

Commit 8d355da

Browse files
authored
Add warning if observed in DensityDist is dict (#6292)
* 🚸 add warning if observed is dict * Improve validation * 👌 update error message
1 parent 07388bc commit 8d355da

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

pymc/distributions/distribution.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,14 @@ def random(mu, rng=None, size=None):
577577

578578
def __new__(cls, name, *args, **kwargs):
579579
kwargs.setdefault("class_name", name)
580+
if isinstance(kwargs.get("observed", None), dict):
581+
raise TypeError(
582+
"Since ``v4.0.0`` the ``observed`` parameter should be of type"
583+
" ``pd.Series``, ``np.array``, or ``pm.Data``."
584+
" Previous versions allowed passing distribution parameters as"
585+
" a dictionary in ``observed``, in the current version these "
586+
"parameters are positional arguments."
587+
)
580588
return super().__new__(cls, name, *args, **kwargs)
581589

582590
@classmethod

pymc/tests/distributions/test_distribution.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,29 @@ def test_density_dist_with_random(self, size):
145145
random=lambda mu, rng=None, size=None: rng.normal(loc=mu, scale=1, size=size),
146146
observed=np.random.randn(100, *size),
147147
)
148-
149148
assert obs.eval().shape == (100,) + size
150149

150+
def test_density_dist_with_random_invalid_observed(self):
151+
with pytest.raises(
152+
TypeError,
153+
match=(
154+
"Since ``v4.0.0`` the ``observed`` parameter should be of type"
155+
" ``pd.Series``, ``np.array``, or ``pm.Data``."
156+
" Previous versions allowed passing distribution parameters as"
157+
" a dictionary in ``observed``, in the current version these "
158+
"parameters are positional arguments."
159+
),
160+
):
161+
size = (3,)
162+
with pm.Model() as model:
163+
mu = pm.Normal("mu", 0, 1)
164+
pm.DensityDist(
165+
"density_dist",
166+
mu,
167+
random=lambda mu, rng=None, size=None: rng.normal(loc=mu, scale=1, size=size),
168+
observed={"values": np.random.randn(100, *size)},
169+
)
170+
151171
def test_density_dist_without_random(self):
152172
with pm.Model() as model:
153173
mu = pm.Normal("mu", 0, 1)

0 commit comments

Comments
 (0)