@@ -124,8 +124,9 @@ def _get_log_likelihood(model, samples):
124
124
for v in model .observed_RVs :
125
125
logp_v = replace_shared_variables ([logpt (v )])
126
126
fgraph = FunctionGraph (model .value_vars , logp_v , clone = False )
127
+ optimize_graph (fgraph , include = ["fast_run" ], exclude = ["cxx_only" , "BlasOpt" ])
127
128
jax_fn = jax_funcify (fgraph )
128
- result = jax .vmap (jax .vmap (jax_fn ))(* samples )[0 ]
129
+ result = jax .jit ( jax . vmap (jax .vmap (jax_fn ) ))(* samples )[0 ]
129
130
data [v .name ] = result
130
131
return data
131
132
@@ -150,6 +151,20 @@ def sample_numpyro_nuts(
150
151
151
152
vars_to_sample = list (get_default_varnames (var_names , include_transformed = keep_untransformed ))
152
153
154
+ coords = {
155
+ cname : np .array (cvals ) if isinstance (cvals , tuple ) else cvals
156
+ for cname , cvals in model .coords .items ()
157
+ if cvals is not None
158
+ }
159
+
160
+ if hasattr (model , "RV_dims" ):
161
+ dims = {
162
+ var_name : [dim for dim in dims if dim is not None ]
163
+ for var_name , dims in model .RV_dims .items ()
164
+ }
165
+ else :
166
+ dims = {}
167
+
153
168
tic1 = pd .Timestamp .now ()
154
169
print ("Compiling..." , file = sys .stdout )
155
170
@@ -213,6 +228,7 @@ def sample_numpyro_nuts(
213
228
mcmc_samples = {}
214
229
for v in vars_to_sample :
215
230
fgraph = FunctionGraph (model .value_vars , [v ], clone = False )
231
+ optimize_graph (fgraph , include = ["fast_run" ], exclude = ["cxx_only" , "BlasOpt" ])
216
232
jax_fn = jax_funcify (fgraph )
217
233
result = jax .vmap (jax .vmap (jax_fn ))(* raw_mcmc_samples )[0 ]
218
234
mcmc_samples [v .name ] = result
@@ -221,11 +237,13 @@ def sample_numpyro_nuts(
221
237
print ("Transformation time = " , tic4 - tic3 , file = sys .stdout )
222
238
223
239
posterior = mcmc_samples
224
- az_posterior = az .from_dict (posterior = posterior )
225
-
226
- az_obs = az .from_dict (observed_data = find_observations (model ))
227
- az_stats = az .from_dict (sample_stats = _sample_stats_to_xarray (pmap_numpyro ))
228
- az_ll = az .from_dict (log_likelihood = _get_log_likelihood (model , raw_mcmc_samples ))
229
- az_trace = az .concat (az_posterior , az_ll , az_obs , az_stats )
240
+ az_trace = az .from_dict (
241
+ posterior = posterior ,
242
+ log_likelihood = _get_log_likelihood (model , raw_mcmc_samples ),
243
+ observed_data = find_observations (model ),
244
+ sample_stats = _sample_stats_to_xarray (pmap_numpyro ),
245
+ coords = coords ,
246
+ dims = dims ,
247
+ )
230
248
231
249
return az_trace
0 commit comments