@@ -74,7 +74,7 @@ def get_jaxified_logp(model: Model) -> Callable:
     logpt = replace_shared_variables([model.logpt()])[0]

-    logpt_fgraph = FunctionGraph(outputs=[logpt], clone=False)
+    logpt_fgraph = FunctionGraph(outputs=[logpt], clone=True)
     optimize_graph(logpt_fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])

     # We now jaxify the optimized fgraph
@@ -123,7 +123,7 @@ def _get_log_likelihood(model, samples):
     data = {}
     for v in model.observed_RVs:
         logp_v = replace_shared_variables([model.logpt(v, sum=False)[0]])
-        fgraph = FunctionGraph(model.value_vars, logp_v, clone=False)
+        fgraph = FunctionGraph(model.value_vars, logp_v, clone=True)
         optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
         jax_fn = jax_funcify(fgraph)
         result = jax.jit(jax.vmap(jax.vmap(jax_fn)))(*samples)[0]
@@ -229,7 +229,7 @@ def sample_numpyro_nuts(
     print("Transforming variables...", file=sys.stdout)
     mcmc_samples = {}
     for v in vars_to_sample:
-        fgraph = FunctionGraph(model.value_vars, [v], clone=False)
+        fgraph = FunctionGraph(model.value_vars, [v], clone=True)
         optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
         jax_fn = jax_funcify(fgraph)
         result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
0 commit comments