diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 4f9f733ec4..e3d2e1034d 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -435,6 +435,8 @@ def sample_numpyro_nuts( Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value for the ``log_likelihood`` key to indicate that the pointwise log likelihood should not be included in the returned object. + nuts_kwargs: dict, optional + Keyword arguments for :func:`numpyro.infer.NUTS`. Returns -------