Skip to content

Commit 7ec5ca1

Browse files
committed
Avoid unnecessary double clone_replace in smc._logp_forw
1 parent fd7fc60 commit 7ec5ca1

File tree

1 file changed

+15
-23
lines changed

1 file changed

+15
-23
lines changed

pymc/smc/smc.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -541,43 +541,35 @@ def sample_settings(self):
541541
return stats
542542

543543

544-
def _logp_forw(point, out_vars, vars, shared):
544+
def _logp_forw(point, out_vars, in_vars, shared):
545545
"""Compile Aesara function of the model and the input and output variables.
546546
547547
Parameters
548548
----------
549549
out_vars: List
550550
containing :class:`pymc.Distribution` for the output variables
551-
vars: List
551+
in_vars: List
552552
containing :class:`pymc.Distribution` for the input variables
553553
shared: List
554554
containing :class:`aesara.tensor.Tensor` for depended shared data
555555
"""
556556

557-
# Convert expected input of discrete variables to (rounded) floats
558-
if any(var.dtype in discrete_types for var in vars):
559-
replace_int_to_float = {}
560-
replace_float_to_round = {}
561-
new_vars = []
562-
for var in vars:
563-
if var.dtype in discrete_types:
564-
float_var = at.TensorType("floatX", var.broadcastable)(var.name)
565-
replace_int_to_float[var] = float_var
566-
new_vars.append(float_var)
567-
568-
round_float_var = at.round(float_var)
569-
round_float_var.name = var.name
570-
replace_float_to_round[float_var] = round_float_var
557+
# Replace integer inputs with rounded float inputs
558+
if any(var.dtype in discrete_types for var in in_vars):
559+
replace_int_input = {}
560+
new_in_vars = []
561+
for in_var in in_vars:
562+
if in_var.dtype in discrete_types:
563+
float_var = at.TensorType("floatX", in_var.broadcastable)(in_var.name)
564+
new_in_vars.append(float_var)
565+
replace_int_input[in_var] = at.round(float_var)
571566
else:
572-
new_vars.append(var)
567+
new_in_vars.append(in_var)
573568

574-
replace_int_to_float.update(shared)
575-
replace_float_to_round.update(shared)
576-
out_vars = clone_replace(out_vars, replace_int_to_float, strict=False)
577-
out_vars = clone_replace(out_vars, replace_float_to_round)
578-
vars = new_vars
569+
out_vars = clone_replace(out_vars, replace_int_input, strict=False)
570+
in_vars = new_in_vars
579571

580-
out_list, inarray0 = join_nonshared_inputs(point, out_vars, vars, shared)
572+
out_list, inarray0 = join_nonshared_inputs(point, out_vars, in_vars, shared)
581573
f = compile_rv_inplace([inarray0], out_list[0])
582574
f.trust_input = True
583575
return f

0 commit comments

Comments
 (0)