@@ -40,6 +40,7 @@
 from aesara.tensor.var import TensorVariable
 
 from pymc3.aesaraf import (
+    change_rv_size,
     extract_rv_and_value_vars,
     floatX,
     rvs_to_value_vars,
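`change_rv_size` (newly imported above) rebuilds a random variable with a different `size`, so its draws match a target shape; it is used further down to align `base_rv` with the shape of the incoming value. A minimal sketch of the behavior this diff relies on, using an illustrative normal RV that is not part of the change itself:

    import aesara.tensor as at
    from pymc3.aesaraf import change_rv_size

    x = at.random.normal(0, 1)       # scalar draw, shape ()
    x5 = change_rv_size(x, (5,))     # same distribution, resized to 5 iid draws
    assert x5.eval().shape == (5,)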
@@ -287,7 +288,15 @@ def elemwise_logp(op, *args, **kwargs):
 
 @_logp.register(Add)
 @_logp.register(Mul)
-def linear_logp(op, var, rvs_to_values, *linear_inputs, transformed=True, **kwargs):
+def linear_logp(
+    op,
+    var,
+    rvs_to_values,
+    *linear_inputs,
+    transformed=True,
+    sum=False,
+    **kwargs,
+):
 
     if len(linear_inputs) != 2:
         raise ValueError(f"Expected 2 inputs but got: {len(linear_inputs)}")
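For context: `_logp.register(Add)` and `_logp.register(Mul)` route the logp of `x + c` and `x * c` graphs through `linear_logp`, which inverts the linear transformation and defers to the base variable's logp. A hypothetical end-to-end sketch, assuming `logpt` forwards the new `sum` keyword down to this dispatcher (the call form mirrors the recursive `logpt` call inside this function):

    import aesara.tensor as at

    x = at.random.normal(0, 1, name="x")
    y = x + 5                 # Add Op, so logpt dispatches to linear_logp
    y_value = y.clone()
    y_logp = logpt(y, {y: y_value}, sum=True)  # scalar logp, new in this change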
@@ -324,7 +333,17 @@ def linear_logp(op, var, rvs_to_values, *linear_inputs, transformed=True, **kwargs):
         base_value = var_value - constant
     else:
         base_value = var_value / constant
-    var_logp = logpt(base_rv, {base_rv: base_value}, transformed=transformed, **kwargs)
+
+    # Change base rv shape if needed
+    if isinstance(base_rv.owner.op, RandomVariable):
+        ndim_supp = base_rv.owner.op.ndim_supp
+        if ndim_supp > 0:
+            new_size = base_value.shape[:-ndim_supp]
+        else:
+            new_size = base_value.shape
+        base_rv = change_rv_size(base_rv, new_size)
+
+    var_logp = logpt(base_rv, {base_rv: base_value}, transformed=transformed, sum=False, **kwargs)
 
     # Apply product jacobian correction for continuous rvs
     if isinstance(op, Mul) and "float" in base_rv.dtype:
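The resize block recovers the implied `size` from the value's shape: for a multivariate base RV, the trailing `ndim_supp` support dimensions belong to the distribution itself and are trimmed off; for a univariate RV the full value shape is the size. Just the shape arithmetic, as a hypothetical standalone helper:

    def implied_size(value_shape, ndim_supp):
        # batch dims only; the last ndim_supp dims are owned by the distribution
        return value_shape[:-ndim_supp] if ndim_supp > 0 else value_shape

    assert implied_size((5, 3), ndim_supp=1) == (5,)    # e.g. length-3 multivariate
    assert implied_size((5, 3), ndim_supp=0) == (5, 3)  # univariate, 5x3 batch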
@@ -337,6 +356,9 @@ def linear_logp(op, var, rvs_to_values, *linear_inputs, transformed=True, **kwargs):
         initial_replacements=None,
     )
 
+    if sum:
+        var_logp = at.sum(var_logp)
+
     var_logp.name = f"__logp_{var.name}"
     return var_logp
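Note that the recursive `logpt` call above always passes `sum=False`, so the Jacobian correction is applied elementwise; only at the end, when the caller asked for it, is the result collapsed with `at.sum`. Continuing the earlier sketch:

    y_logp_elem = logpt(y, {y: y_value}, sum=False)  # elementwise, follows y's shape
    y_logp_total = logpt(y, {y: y_value}, sum=True)  # scalar: at.sum of the above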