Skip to content

Commit 4cb0d1b

Browse files
committed
Make mul_logp and add_logp work with transforms
1 parent e8f9602 commit 4cb0d1b

File tree

2 files changed

+63
-59
lines changed

2 files changed

+63
-59
lines changed

pymc3/distributions/logp.py

+42-40
Original file line numberDiff line numberDiff line change
@@ -318,37 +318,38 @@ def add_logp(op, var, rvs_to_values, *add_inputs, **kwargs):
318318
if len(add_inputs) != 2:
319319
raise ValueError(f"Expected 2 inputs but got: {len(add_inputs)}")
320320

321-
rv, loc = find_rv_branch(add_inputs)
321+
base_rv, loc = find_rv_branch(add_inputs)
322322

323-
if len(rv) != 1:
323+
if len(base_rv) != 1:
324324
raise NotImplementedError(
325-
f"Logp of addition requires one branch with an unregistered RandomVariable but got {len(rv)}"
325+
f"Logp of addition requires one branch with an unregistered RandomVariable but got {len(base_rv)}"
326326
)
327327

328-
rv = rv[0]
329-
rv_value = rvs_to_values.get(rv, getattr(rv.tag, "value_var", rv))
328+
var_value = rvs_to_values.get(var, var)
330329
loc = loc[0]
331-
loc_value = rvs_to_values.get(loc, getattr(loc.tag, "value_var", loc))
332-
333-
new_rvs_to_values = rvs_to_values.copy()
334-
new_rvs_to_values[rv] = rv_value
330+
base_rv = base_rv[0]
331+
base_value = base_rv.type()
335332

336-
logp_rv = logpt(rv, new_rvs_to_values, **kwargs)
333+
logp_base_rv = logpt(base_rv, {base_rv: base_value}, **kwargs)
337334
fgraph = FunctionGraph(
338-
[i for i in graph_inputs((logp_rv,)) if not isinstance(i, Constant)],
339-
[logp_rv],
335+
[i for i in graph_inputs((logp_base_rv,)) if not isinstance(i, Constant)],
336+
[logp_base_rv],
340337
clone=False,
341338
)
339+
fgraph.replace(base_value, var_value - loc, import_missing=True)
340+
logp_add_rv = fgraph.outputs[0]
341+
342+
# Replace rvs in graph
343+
# TODO: This shouldn't be here
344+
(logp_add_rv,), _ = rvs_to_value_vars(
345+
(logp_add_rv,),
346+
apply_transforms=True, # Change this
347+
initial_replacements=None,
348+
)
342349

343-
var_value = rvs_to_values.get(var, var)
344-
345-
fgraph.add_input(loc_value)
346-
fgraph.add_input(var_value)
347-
fgraph.replace(rv_value, var_value - loc_value)
348-
349-
logp_rv.name = f"__logp_{var.name}"
350+
logp_add_rv.name = f"__logp_{var.name}"
350351

351-
return logp_rv
352+
return logp_add_rv
352353

353354

354355
@_logp.register(Mul)
@@ -357,40 +358,41 @@ def mul_logp(op, var, rvs_to_values, *mul_inputs, **kwargs):
357358
if len(mul_inputs) != 2:
358359
raise ValueError(f"Expected 2 inputs but got: {len(mul_inputs)}")
359360

360-
rv, scale = find_rv_branch(mul_inputs)
361+
base_rv, scale = find_rv_branch(mul_inputs)
361362

362-
if len(rv) != 1:
363+
if len(base_rv) != 1:
363364
raise NotImplementedError(
364-
f"Logp of product requires one branch with an unregistered RandomVariable but got {len(rv)}"
365+
f"Logp of product requires one branch with an unregistered RandomVariable but got {len(base_rv)}"
365366
)
366367

367-
rv = rv[0]
368-
rv_value = rvs_to_values.get(rv, getattr(rv.tag, "value_var", rv))
368+
var_value = rvs_to_values.get(var, var)
369369
scale = scale[0]
370-
scale_value = rvs_to_values.get(scale, getattr(scale.tag, "value_var", scale))
371-
372-
new_rvs_to_values = rvs_to_values.copy()
373-
new_rvs_to_values[rv] = rv_value
370+
base_rv = base_rv[0]
371+
base_value = base_rv.type()
374372

375-
logp_rv = logpt(rv, new_rvs_to_values, **kwargs)
373+
logp_base_rv = logpt(base_rv, {base_rv: base_value}, **kwargs)
376374
fgraph = FunctionGraph(
377-
[i for i in graph_inputs((logp_rv,)) if not isinstance(i, Constant)],
378-
[logp_rv],
375+
[i for i in graph_inputs((logp_base_rv,)) if not isinstance(i, Constant)],
376+
[logp_base_rv],
379377
clone=False,
380378
)
381379

382-
var_value = rvs_to_values.get(var, var)
383-
384-
fgraph.add_input(scale_value)
385-
fgraph.add_input(var_value)
386380
# TODO: This is not correct for discrete variables
387381
# TODO: Undefined behavior for scale = 0
388-
fgraph.replace(rv_value, var_value / scale_value)
382+
fgraph.replace(base_value, var_value / scale, import_missing=True)
383+
logp_mul_rv = fgraph.outputs[0] - at.log(at.abs_(scale))
384+
385+
# Replace rvs in graph
386+
# TODO: This shouldn't be here
387+
(logp_mul_rv,), _ = rvs_to_value_vars(
388+
(logp_mul_rv,),
389+
apply_transforms=True, # Change this
390+
initial_replacements=None,
391+
)
389392

390-
logp_rv = fgraph.outputs[0] - at.log(at.abs_(scale_value))
391-
logp_rv.name = f"__logp_{var.name}"
393+
logp_mul_rv.name = f"__logp_{var.name}"
392394

393-
return logp_rv
395+
return logp_mul_rv
394396

395397

396398
def convert_indices(indices, entry):

pymc3/tests/test_logp.py

+21-19
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_logpt_add():
7777
random variable and ``loc`` is an tensor variable or a registered random variable
7878
"""
7979
with Model() as m:
80-
loc = Uniform("loc", 0, 1)
80+
loc = Exponential("loc", 10)
8181
x = Normal.dist(0, 1) + loc
8282
m.register_rv(x, "x")
8383

@@ -98,8 +98,8 @@ def test_logpt_add():
9898

9999
# Test logp is correct
100100
f_logp = aesara.function([x_value_var, loc_value_var], x_logp)
101-
np.testing.assert_almost_equal(f_logp(50, 50), sp.norm(50, 1).logpdf(50))
102-
np.testing.assert_almost_equal(f_logp(50, 0), sp.norm(0, 1).logpdf(50), decimal=5)
101+
np.testing.assert_almost_equal(f_logp(50, np.log(50)), sp.norm(50, 1).logpdf(50))
102+
np.testing.assert_almost_equal(f_logp(50, np.log(10)), sp.norm(10, 1).logpdf(50), decimal=5)
103103

104104

105105
def test_logpt_mul():
@@ -108,7 +108,7 @@ def test_logpt_mul():
108108
random variable and ``scale`` is an tensor variable or a registered random variable
109109
"""
110110
with Model() as m:
111-
scale = Uniform("scale", 0, 1)
111+
scale = Exponential("scale", 10)
112112
x = Exponential.dist(1) * scale
113113
m.register_rv(x, "x")
114114

@@ -129,9 +129,8 @@ def test_logpt_mul():
129129

130130
# Test logp is correct
131131
f_logp = aesara.function([x_value_var, scale_value_var], x_logp)
132-
np.testing.assert_almost_equal(f_logp(0, 5), sp.expon(scale=5).logpdf(0))
133-
np.testing.assert_almost_equal(f_logp(-2, -2), sp.expon(scale=2).logpdf(2))
134-
assert f_logp(2, -2) == -np.inf
132+
np.testing.assert_almost_equal(f_logp(0, np.log(5)), sp.expon(scale=5).logpdf(0))
133+
np.testing.assert_almost_equal(f_logp(2, np.log(2)), sp.expon(scale=2).logpdf(2))
135134

136135

137136
def test_logpt_mul_add():
@@ -140,8 +139,8 @@ def test_logpt_mul_add():
140139
random variable and ``loc`` and ``scale`` are tensor variables or registered random variables
141140
"""
142141
with Model() as m:
143-
loc = Uniform("loc", 0, 1)
144-
scale = Uniform("scale", 0, 1)
142+
loc = Exponential("loc", 10)
143+
scale = Exponential("scale", 10)
145144
x = loc + scale * Normal.dist(0, 1)
146145
m.register_rv(x, "x")
147146

@@ -164,11 +163,14 @@ def test_logpt_mul_add():
164163

165164
# Test logp is correct
166165
f_logp = aesara.function([x_value_var, loc_value_var, scale_value_var], x_logp)
167-
np.testing.assert_almost_equal(f_logp(-1, 0, 2), sp.norm(0, 2).logpdf(-1))
168-
np.testing.assert_almost_equal(f_logp(95, 100, 15), sp.norm(100, 15).logpdf(95), decimal=6)
166+
np.testing.assert_almost_equal(f_logp(-1, np.log(0), np.log(2)), sp.norm(0, 2).logpdf(-1))
167+
np.testing.assert_almost_equal(
168+
f_logp(95, np.log(100), np.log(15)), sp.norm(100, 15).logpdf(95), decimal=6
169+
)
169170

170171

171-
def test_logpt_not_implemented():
172+
@pytest.mark.parametrize("op", [at.add, at.mul])
173+
def test_logpt_not_implemented(op):
172174
"""Test that logpt for add and mul fail if inputs are 0 or 2 unregistered rvs"""
173175

174176
with Model() as m:
@@ -179,14 +181,14 @@ def test_logpt_not_implemented():
179181
registered1 = Normal("registered1", 0, 1)
180182
registered2 = Normal("registered2", 0, 1)
181183

182-
x_fail1 = variable1 + variable2
183-
x_fail2 = unregistered1 + unregistered2
184-
x_fail3 = registered1 + variable1
185-
x_fail4 = registered1 + registered2
184+
x_fail1 = op(variable1, variable2)
185+
x_fail2 = op(unregistered1, unregistered2)
186+
x_fail3 = op(registered1, variable1)
187+
x_fail4 = op(registered1, registered2)
186188

187-
x_pass1 = variable1 + unregistered2
188-
x_pass2 = unregistered1 + variable2
189-
x_pass3 = registered1 + unregistered1
189+
x_pass1 = op(variable1, unregistered2)
190+
x_pass2 = op(unregistered1, variable2)
191+
x_pass3 = op(registered1, unregistered1)
190192

191193
m.register_rv(x_fail1, "x_fail1")
192194
m.register_rv(x_fail2, "x_fail2")

0 commit comments

Comments
 (0)