Skip to content

Commit 05fe251

Browse files
committed
Change base_rv size when needed
1 parent e8d5aee commit 05fe251

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

pymc3/distributions/logp.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from aesara.tensor.var import TensorVariable
4141

4242
from pymc3.aesaraf import (
43+
change_rv_size,
4344
extract_rv_and_value_vars,
4445
floatX,
4546
rvs_to_value_vars,
@@ -287,7 +288,15 @@ def elemwise_logp(op, *args, **kwargs):
287288

288289
@_logp.register(Add)
289290
@_logp.register(Mul)
290-
def linear_logp(op, var, rvs_to_values, *linear_inputs, transformed=True, **kwargs):
291+
def linear_logp(
292+
op,
293+
var,
294+
rvs_to_values,
295+
*linear_inputs,
296+
transformed=True,
297+
sum=False,
298+
**kwargs,
299+
):
291300

292301
if len(linear_inputs) != 2:
293302
raise ValueError(f"Expected 2 inputs but got: {len(linear_inputs)}")
@@ -324,7 +333,17 @@ def linear_logp(op, var, rvs_to_values, *linear_inputs, transformed=True, **kwar
324333
base_value = var_value - constant
325334
else:
326335
base_value = var_value / constant
327-
var_logp = logpt(base_rv, {base_rv: base_value}, transformed=transformed, **kwargs)
336+
337+
# Change base rv shape if needed
338+
if isinstance(base_rv.owner.op, RandomVariable):
339+
ndim_supp = base_rv.owner.op.ndim_supp
340+
if ndim_supp > 0:
341+
new_size = base_value.shape[:-ndim_supp]
342+
else:
343+
new_size = base_value.shape
344+
base_rv = change_rv_size(base_rv, new_size)
345+
346+
var_logp = logpt(base_rv, {base_rv: base_value}, transformed=transformed, sum=False, **kwargs)
328347

329348
# Apply product jacobian correction for continuous rvs
330349
if isinstance(op, Mul) and "float" in base_rv.dtype:
@@ -337,6 +356,9 @@ def linear_logp(op, var, rvs_to_values, *linear_inputs, transformed=True, **kwar
337356
initial_replacements=None,
338357
)
339358

359+
if sum:
360+
var_logp = at.sum(var_logp)
361+
340362
var_logp.name = f"__logp_{var.name}"
341363
return var_logp
342364

0 commit comments

Comments
 (0)