@@ -271,10 +271,8 @@ def _logp(
271271
272272
273273@_logp .register (Elemwise )
274- def logp_elemwise (op , * args , ** kwargs ):
275- if hasattr (op , "scalar_op" ):
276- return _logp (op .scalar_op , * args , ** kwargs )
277- raise NotImplementedError
274+ def elemwise_logp (op , * args , ** kwargs ):
275+ return _logp (op .scalar_op , * args , ** kwargs )
278276
279277
280278# TODO: Implement DimShuffle logp?
@@ -287,14 +285,17 @@ def logp_elemwise(op, *args, **kwargs):
287285# raise NotImplementedError
288286
289287
290- def find_rv_branch ( inputs ):
291- """
292- Helper function to find which input branch(es) contain unregistered random variables
293- """
294- rv_branch = []
295- no_rv_branch = []
288+ @ _logp . register ( Add )
289+ @ _logp . register ( Mul )
290+ def linear_logp ( op , var , rvs_to_values , * linear_inputs , ** kwargs ):
291+
292+ if len ( linear_inputs ) != 2 :
293+ raise ValueError ( f"Expected 2 inputs but got: { len ( linear_inputs ) } " )
296294
297- for inp in inputs :
295+ # Find base_rv and constant inputs
296+ base_rv = []
297+ constant = []
298+ for inp in linear_inputs :
298299 res_ancestors = list (walk_model ((inp ,), walk_past_rvs = True ))
299300 # unregistered variables do not contain a value_var tag
300301 res_unregistered_ancestors = [
@@ -305,94 +306,47 @@ def find_rv_branch(inputs):
305306 and not getattr (v .tag , "value_var" , False )
306307 ]
307308 if res_unregistered_ancestors :
308- rv_branch .append (inp )
309+ base_rv .append (inp )
309310 else :
310- no_rv_branch .append (inp )
311-
312- return rv_branch , no_rv_branch
313-
314-
315- @_logp .register (Add )
316- def add_logp (op , var , rvs_to_values , * add_inputs , ** kwargs ):
317-
318- if len (add_inputs ) != 2 :
319- raise ValueError (f"Expected 2 inputs but got: { len (add_inputs )} " )
320-
321- base_rv , loc = find_rv_branch (add_inputs )
311+ constant .append (inp )
322312
323313 if len (base_rv ) != 1 :
324314 raise NotImplementedError (
325- f"Logp of addition requires one branch with an unregistered RandomVariable but got { len (base_rv )} "
315+ f"Logp of linear transform requires one branch with an unregistered RandomVariable but got { len (base_rv )} "
326316 )
327317
328- var_value = rvs_to_values .get (var , var )
329- loc = loc [0 ]
330318 base_rv = base_rv [0 ]
331- base_value = base_rv .type ()
332-
333- logp_base_rv = logpt (base_rv , {base_rv : base_value }, ** kwargs )
334- fgraph = FunctionGraph (
335- [i for i in graph_inputs ((logp_base_rv ,)) if not isinstance (i , Constant )],
336- [logp_base_rv ],
337- clone = False ,
338- )
339- fgraph .replace (base_value , var_value - loc , import_missing = True )
340- logp_add_rv = fgraph .outputs [0 ]
341-
342- # Replace rvs in graph
343- # TODO: This shouldn't be here
344- (logp_add_rv ,), _ = rvs_to_value_vars (
345- (logp_add_rv ,),
346- apply_transforms = True , # Change this
347- initial_replacements = None ,
348- )
349-
350- logp_add_rv .name = f"__logp_{ var .name } "
351-
352- return logp_add_rv
353-
354-
355- @_logp .register (Mul )
356- def mul_logp (op , var , rvs_to_values , * mul_inputs , ** kwargs ):
357-
358- if len (mul_inputs ) != 2 :
359- raise ValueError (f"Expected 2 inputs but got: { len (mul_inputs )} " )
360-
361- base_rv , scale = find_rv_branch (mul_inputs )
362-
363- if len (base_rv ) != 1 :
364- raise NotImplementedError (
365- f"Logp of product requires one branch with an unregistered RandomVariable but got { len (base_rv )} "
366- )
367-
319+ constant = constant [0 ]
368320 var_value = rvs_to_values .get (var , var )
369- scale = scale [0 ]
370- base_rv = base_rv [0 ]
371- base_value = base_rv .type ()
372321
322+ # Get logp of base_rv
323+ base_value = base_rv .type ()
373324 logp_base_rv = logpt (base_rv , {base_rv : base_value }, ** kwargs )
374325 fgraph = FunctionGraph (
375326 [i for i in graph_inputs ((logp_base_rv ,)) if not isinstance (i , Constant )],
376- [logp_base_rv ],
327+ outputs = [logp_base_rv ],
377328 clone = False ,
378329 )
379330
380- # TODO: This is not correct for discrete variables
381- # TODO: Undefined behavior for scale = 0
382- fgraph .replace (base_value , var_value / scale , import_missing = True )
383- logp_mul_rv = fgraph .outputs [0 ] - at .log (at .abs_ (scale ))
331+ # Transform base_rv and apply jacobian correction (for continuous rvs)
332+ if isinstance (op , Add ):
333+ fgraph .replace (base_value , var_value - constant , import_missing = True )
334+ logp_linear_rv = fgraph .outputs [0 ]
335+ elif isinstance (op , Mul ):
336+ fgraph .replace (base_value , var_value / constant , import_missing = True )
337+ logp_linear_rv = fgraph .outputs [0 ]
338+ if "float" in base_rv .dtype :
339+ logp_linear_rv -= at .log (at .abs_ (constant ))
384340
385341 # Replace rvs in graph
386- # TODO: This shouldn't be here
387- (logp_mul_rv ,), _ = rvs_to_value_vars (
388- (logp_mul_rv ,),
389- apply_transforms = True , # Change this
342+ (logp_linear_rv ,), _ = rvs_to_value_vars (
343+ (logp_linear_rv ,),
344+ apply_transforms = kwargs .get ("transformed" , True ),
390345 initial_replacements = None ,
391346 )
392347
393- logp_mul_rv .name = f"__logp_{ var .name } "
394-
395- return logp_mul_rv
348+ logp_linear_rv .name = f"__logp_{ var .name } "
349+ return logp_linear_rv
396350
397351
398352def convert_indices (indices , entry ):
0 commit comments