@@ -541,43 +541,35 @@ def sample_settings(self):
541
541
return stats
542
542
543
543
544
- def _logp_forw (point , out_vars , vars , shared ):
544
+ def _logp_forw (point , out_vars , in_vars , shared ):
545
545
"""Compile Aesara function of the model and the input and output variables.
546
546
547
547
Parameters
548
548
----------
549
549
out_vars: List
550
550
containing :class:`pymc.Distribution` for the output variables
551
- vars : List
551
+ in_vars : List
552
552
containing :class:`pymc.Distribution` for the input variables
553
553
shared: List
554
554
containing :class:`aesara.tensor.Tensor` for depended shared data
555
555
"""
556
556
557
- # Convert expected input of discrete variables to (rounded) floats
558
- if any (var .dtype in discrete_types for var in vars ):
559
- replace_int_to_float = {}
560
- replace_float_to_round = {}
561
- new_vars = []
562
- for var in vars :
563
- if var .dtype in discrete_types :
564
- float_var = at .TensorType ("floatX" , var .broadcastable )(var .name )
565
- replace_int_to_float [var ] = float_var
566
- new_vars .append (float_var )
567
-
568
- round_float_var = at .round (float_var )
569
- round_float_var .name = var .name
570
- replace_float_to_round [float_var ] = round_float_var
557
+ # Replace integer inputs with rounded float inputs
558
+ if any (var .dtype in discrete_types for var in in_vars ):
559
+ replace_int_input = {}
560
+ new_in_vars = []
561
+ for in_var in in_vars :
562
+ if in_var .dtype in discrete_types :
563
+ float_var = at .TensorType ("floatX" , in_var .broadcastable )(in_var .name )
564
+ new_in_vars .append (float_var )
565
+ replace_int_input [in_var ] = at .round (float_var )
571
566
else :
572
- new_vars .append (var )
567
+ new_in_vars .append (in_var )
573
568
574
- replace_int_to_float .update (shared )
575
- replace_float_to_round .update (shared )
576
- out_vars = clone_replace (out_vars , replace_int_to_float , strict = False )
577
- out_vars = clone_replace (out_vars , replace_float_to_round )
578
- vars = new_vars
569
+ out_vars = clone_replace (out_vars , replace_int_input , strict = False )
570
+ in_vars = new_in_vars
579
571
580
- out_list , inarray0 = join_nonshared_inputs (point , out_vars , vars , shared )
572
+ out_list , inarray0 = join_nonshared_inputs (point , out_vars , in_vars , shared )
581
573
f = compile_rv_inplace ([inarray0 ], out_list [0 ])
582
574
f .trust_input = True
583
575
return f
0 commit comments