Skip to content

Commit 5b6c804

Browse files
committed
Allow measurable stack and join with interdependent inputs
1 parent 5ec481f commit 5b6c804

File tree

2 files changed

+143
-19
lines changed

2 files changed

+143
-19
lines changed

pymc/logprob/tensor.py

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,24 @@ class MeasurableMakeVector(MakeVector):
129129

130130

131131
@_logprob.register(MeasurableMakeVector)
132-
def logprob_make_vector(op, values, *base_vars, **kwargs):
132+
def logprob_make_vector(op, values, *base_rvs, **kwargs):
133133
"""Compute the log-likelihood graph for a `MeasurableMakeVector`."""
134+
# TODO: Sort out this circular dependency issue
135+
from pymc.pytensorf import replace_rvs_by_values
136+
134137
(value,) = values
135138

136-
return at.stack([logprob(base_var, value[i]) for i, base_var in enumerate(base_vars)])
139+
base_rvs_to_values = {base_rv: value[i] for i, base_rv in enumerate(base_rvs)}
140+
for i, (base_rv, value) in enumerate(base_rvs_to_values.items()):
141+
base_rv.name = f"base_rv[{i}]"
142+
value.name = f"value[{i}]"
143+
144+
logps = [logprob(base_rv, value) for base_rv, value in base_rvs_to_values.items()]
145+
146+
# If the stacked variables depend on each other, we have to replace them by their respective values
147+
logps = replace_rvs_by_values(logps, rvs_to_values=base_rvs_to_values)
148+
149+
return at.stack(logps)
137150

138151

139152
class MeasurableJoin(Join):
@@ -144,27 +157,28 @@ class MeasurableJoin(Join):
144157

145158

146159
@_logprob.register(MeasurableJoin)
147-
def logprob_join(op, values, axis, *base_vars, **kwargs):
160+
def logprob_join(op, values, axis, *base_rvs, **kwargs):
148161
"""Compute the log-likelihood graph for a `MeasurableJoin`."""
149-
(value,) = values
162+
# TODO: Find a better way to avoid circular dependency
163+
from pymc.pytensorf import constant_fold, replace_rvs_by_values
150164

151-
base_var_shapes = [base_var.shape[axis] for base_var in base_vars]
165+
(value,) = values
152166

153-
# TODO: Find better way to avoid circular dependency
154-
from pymc.pytensorf import constant_fold
167+
base_rv_shapes = [base_var.shape[axis] for base_var in base_rvs]
155168

156169
# We don't need the graph to be constant, just to have RandomVariables removed
157-
base_var_shapes = constant_fold(base_var_shapes, raise_not_constant=False)
170+
base_rv_shapes = constant_fold(base_rv_shapes, raise_not_constant=False)
158171

159172
split_values = at.split(
160173
value,
161-
splits_size=base_var_shapes,
162-
n_splits=len(base_vars),
174+
splits_size=base_rv_shapes,
175+
n_splits=len(base_rvs),
163176
axis=axis,
164177
)
165178

179+
base_rvs_to_split_values = {base_rv: value for base_rv, value in zip(base_rvs, split_values)}
166180
logps = [
167-
logprob(base_var, split_value) for base_var, split_value in zip(base_vars, split_values)
181+
logprob(base_var, split_value) for base_var, split_value in base_rvs_to_split_values.items()
168182
]
169183

170184
if len({logp.ndim for logp in logps}) != 1:
@@ -173,12 +187,12 @@ def logprob_join(op, values, axis, *base_vars, **kwargs):
173187
"joining univariate and multivariate distributions",
174188
)
175189

190+
# If the joined variables depend on each other, we have to replace them by their respective split values
191+
logps = replace_rvs_by_values(logps, rvs_to_values=base_rvs_to_split_values)
192+
176193
base_vars_ndim_supp = split_values[0].ndim - logps[0].ndim
177194
join_logprob = at.concatenate(
178-
[
179-
at.atleast_1d(logprob(base_var, split_value))
180-
for base_var, split_value in zip(base_vars, split_values)
181-
],
195+
[at.atleast_1d(logp) for logp in logps],
182196
axis=axis - base_vars_ndim_supp,
183197
)
184198

@@ -199,6 +213,8 @@ def find_measurable_stacks(
199213
if rv_map_feature is None:
200214
return None # pragma: no cover
201215

216+
rvs_to_values = rv_map_feature.rv_values
217+
202218
stack_out = node.outputs[0]
203219

204220
is_join = isinstance(node.op, Join)
@@ -211,18 +227,40 @@ def find_measurable_stacks(
211227
if not all(
212228
base_var.owner
213229
and isinstance(base_var.owner.op, MeasurableVariable)
214-
and base_var not in rv_map_feature.rv_values
230+
and base_var not in rvs_to_values
215231
for base_var in base_vars
216232
):
217233
return None # pragma: no cover
218234

219235
# Make base_vars unmeasurable
220-
base_vars = [assign_custom_measurable_outputs(base_var.owner) for base_var in base_vars]
236+
base_to_unmeasurable_vars = {
237+
base_var: assign_custom_measurable_outputs(base_var.owner).outputs[
238+
base_var.owner.outputs.index(base_var)
239+
]
240+
for base_var in base_vars
241+
}
242+
243+
def replacement_fn(var, replacements):
244+
if var in base_to_unmeasurable_vars:
245+
replacements[var] = base_to_unmeasurable_vars[var]
246+
# We don't want to clone valued nodes. Assigning a var to itself in the
247+
# replacements prevents this
248+
elif var in rvs_to_values:
249+
replacements[var] = var
250+
251+
return []
252+
253+
# TODO: Fix this import circularity!
254+
from pymc.pytensorf import _replace_rvs_in_graphs
255+
256+
unmeasurable_base_vars, _ = _replace_rvs_in_graphs(
257+
graphs=base_vars, replacement_fn=replacement_fn
258+
)
221259

222260
if is_join:
223-
measurable_stack = MeasurableJoin()(axis, *base_vars)
261+
measurable_stack = MeasurableJoin()(axis, *unmeasurable_base_vars)
224262
else:
225-
measurable_stack = MeasurableMakeVector(node.op.dtype)(*base_vars)
263+
measurable_stack = MeasurableMakeVector(node.op.dtype)(*unmeasurable_base_vars)
226264

227265
measurable_stack.name = stack_out.name
228266

pymc/tests/logprob/test_tensor.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from pymc.logprob import factorized_joint_logprob, joint_logprob
4949
from pymc.logprob.rewriting import logprob_rewrites_db
5050
from pymc.logprob.tensor import naive_bcast_rv_lift
51+
from pymc.tests.helpers import assert_no_rvs
5152

5253

5354
def test_naive_bcast_rv_lift():
@@ -109,6 +110,91 @@ def test_measurable_make_vector():
109110
assert np.isclose(make_vector_logp_eval.sum(), ref_logp_eval_eval)
110111

111112

113+
@pytest.mark.parametrize("reverse", (False, True))
114+
def test_measurable_make_vector_interdependent(reverse):
115+
"""Test that we can obtain a proper graph when stacked RVs depend on each other"""
116+
x = at.random.normal(name="x")
117+
y_rvs = []
118+
prev_rv = x
119+
for i in range(3):
120+
next_rv = at.random.normal(prev_rv + 1, name=f"y{i}")
121+
y_rvs.append(next_rv)
122+
prev_rv = next_rv
123+
124+
if reverse:
125+
y_rvs = y_rvs[::-1]
126+
127+
ys = at.stack(y_rvs)
128+
ys.name = "ys"
129+
130+
x_vv = x.clone()
131+
ys_vv = ys.clone()
132+
133+
logp = joint_logprob({x: x_vv, ys: ys_vv})
134+
assert_no_rvs(logp)
135+
136+
y0_vv = y_rvs[0].clone()
137+
y1_vv = y_rvs[1].clone()
138+
y2_vv = y_rvs[2].clone()
139+
140+
ref_logp = joint_logprob({x: x_vv, y_rvs[0]: y0_vv, y_rvs[1]: y1_vv, y_rvs[2]: y2_vv})
141+
142+
rng = np.random.default_rng()
143+
x_vv_test = rng.normal()
144+
ys_vv_test = rng.normal(size=3)
145+
np.testing.assert_allclose(
146+
logp.eval({x_vv: x_vv_test, ys_vv: ys_vv_test}),
147+
ref_logp.eval(
148+
{x_vv: x_vv_test, y0_vv: ys_vv_test[0], y1_vv: ys_vv_test[1], y2_vv: ys_vv_test[2]}
149+
),
150+
)
151+
152+
153+
@pytest.mark.parametrize("reverse", (False, True))
154+
def test_measurable_join_interdependent(reverse):
155+
"""Test that we can obtain a proper graph when stacked RVs depend on each other"""
156+
x = at.random.normal(name="x")
157+
y_rvs = []
158+
prev_rv = x
159+
for i in range(3):
160+
next_rv = at.random.normal(prev_rv + 1, name=f"y{i}", size=(1, 2))
161+
y_rvs.append(next_rv)
162+
prev_rv = next_rv
163+
164+
if reverse:
165+
y_rvs = y_rvs[::-1]
166+
167+
ys = at.concatenate(y_rvs, axis=0)
168+
ys.name = "ys"
169+
170+
x_vv = x.clone()
171+
ys_vv = ys.clone()
172+
173+
logp = joint_logprob({x: x_vv, ys: ys_vv})
174+
assert_no_rvs(logp)
175+
176+
y0_vv = y_rvs[0].clone()
177+
y1_vv = y_rvs[1].clone()
178+
y2_vv = y_rvs[2].clone()
179+
180+
ref_logp = joint_logprob({x: x_vv, y_rvs[0]: y0_vv, y_rvs[1]: y1_vv, y_rvs[2]: y2_vv})
181+
182+
rng = np.random.default_rng()
183+
x_vv_test = rng.normal()
184+
ys_vv_test = rng.normal(size=(3, 2))
185+
np.testing.assert_allclose(
186+
logp.eval({x_vv: x_vv_test, ys_vv: ys_vv_test}),
187+
ref_logp.eval(
188+
{
189+
x_vv: x_vv_test,
190+
y0_vv: ys_vv_test[0:1],
191+
y1_vv: ys_vv_test[1:2],
192+
y2_vv: ys_vv_test[2:3],
193+
}
194+
),
195+
)
196+
197+
112198
@pytest.mark.parametrize(
113199
"size1, size2, axis, concatenate",
114200
[

0 commit comments

Comments
 (0)