Closed
Description
I tried to sample a model with a truncated normal likelihood using the JAX sampler (`pm.sampling_jax.sample_numpyro_nuts`), and it threw an error. The same model samples fine with the default (non-JAX) sampler, but much more slowly. The error is as follows:
Please provide a minimal, self-contained, and reproducible example.
# Minimal reproducer: a TruncatedNormal likelihood that fails to compile for
# JAX/NumPyro sampling with PyMC 4.0.1 / Aesara 2.7.2.
import numpy as np
import pymc as pm
import pymc.sampling_jax  # explicit import needed to expose pm.sampling_jax
from scipy.stats import truncnorm

# Fixed seed so the simulated data and the sampler run are reproducible.
RANDOM_SEED = 42
rng = np.random.default_rng(RANDOM_SEED)

# Truncation bounds for the simulated data. Named lower/upper (not a/b) so the
# model coefficient `b` below does not shadow them.
lower, upper = 0, 10

# Random variates from a standard normal truncated to [lower, upper].
R = truncnorm.rvs(lower, upper, size=1000, random_state=rng)
x = R * rng.standard_normal(1000)

with pm.Model() as example_model:
    b = pm.Normal("b")
    mu = b * x
    sigma = pm.HalfNormal("sigma")
    eaches = pm.TruncatedNormal(
        "predicted_eaches",
        mu=mu,
        sigma=sigma,
        lower=lower,
        observed=R,
    )

    # Must run inside the model context (no `model=` argument is passed).
    # Raises while JAX-ifying the model logp graph:
    #   AttributeError: module 'jax.scipy.special' has no attribute 'erfcx'
    idata = pm.sampling_jax.sample_numpyro_nuts(
        draws=500, tune=500, target_accept=0.95, random_seed=RANDOM_SEED
    )
Please provide the full traceback.
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
/tmp/ipykernel_12814/714985784.py in <module>
13 observed=R)
14
---> 15 idata = pm.sampling_jax.sample_numpyro_nuts(draws = 500, tune=500, target_accept = .95)
/opt/conda/lib/python3.7/site-packages/pymc/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
481 )
482
--> 483 logp_fn = get_jaxified_logp(model, negative_logp=False)
484
485 if nuts_kwargs is None:
/opt/conda/lib/python3.7/site-packages/pymc/sampling_jax.py in get_jaxified_logp(model, negative_logp)
104 if not negative_logp:
105 model_logp = -model_logp
--> 106 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
107
108 def logp_fn_wrap(x):
/opt/conda/lib/python3.7/site-packages/pymc/sampling_jax.py in get_jaxified_graph(inputs, outputs)
97
98 # We now jaxify the optimized fgraph
---> 99 return jax_funcify(fgraph)
100
101
/opt/conda/lib/python3.7/functools.py in wrapper(*args, **kw)
838 '1 positional argument')
839
--> 840 return dispatch(args[0].__class__)(*args, **kw)
841
842 funcname = getattr(func, '__name__', 'singledispatch function')
/opt/conda/lib/python3.7/site-packages/aesara/link/jax/dispatch.py in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
682 type_conversion_fn=jax_typify,
683 fgraph_name=fgraph_name,
--> 684 **kwargs,
685 )
686
/opt/conda/lib/python3.7/site-packages/aesara/link/utils.py in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, input_storage, output_storage, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
740 for node in order:
741 compiled_func = op_conversion_fn(
--> 742 node.op, node=node, storage_map=storage_map, **kwargs
743 )
744
/opt/conda/lib/python3.7/functools.py in wrapper(*args, **kw)
838 '1 positional argument')
839
--> 840 return dispatch(args[0].__class__)(*args, **kw)
841
842 funcname = getattr(func, '__name__', 'singledispatch function')
/opt/conda/lib/python3.7/site-packages/aesara/link/jax/dispatch.py in jax_funcify_Elemwise(op, **kwargs)
399 def jax_funcify_Elemwise(op, **kwargs):
400 scalar_op = op.scalar_op
--> 401 return jax_funcify(scalar_op, **kwargs)
402
403
/opt/conda/lib/python3.7/functools.py in wrapper(*args, **kw)
838 '1 positional argument')
839
--> 840 return dispatch(args[0].__class__)(*args, **kw)
841
842 funcname = getattr(func, '__name__', 'singledispatch function')
/opt/conda/lib/python3.7/site-packages/aesara/link/jax/dispatch.py in jax_funcify_Composite(op, vectorize, **kwargs)
404 @jax_funcify.register(Composite)
405 def jax_funcify_Composite(op, vectorize=True, **kwargs):
--> 406 jax_impl = jax_funcify(op.fgraph)
407
408 def composite(*args):
/opt/conda/lib/python3.7/functools.py in wrapper(*args, **kw)
838 '1 positional argument')
839
--> 840 return dispatch(args[0].__class__)(*args, **kw)
841
842 funcname = getattr(func, '__name__', 'singledispatch function')
/opt/conda/lib/python3.7/site-packages/aesara/link/jax/dispatch.py in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
682 type_conversion_fn=jax_typify,
683 fgraph_name=fgraph_name,
--> 684 **kwargs,
685 )
686
/opt/conda/lib/python3.7/site-packages/aesara/link/utils.py in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, input_storage, output_storage, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
740 for node in order:
741 compiled_func = op_conversion_fn(
--> 742 node.op, node=node, storage_map=storage_map, **kwargs
743 )
744
/opt/conda/lib/python3.7/functools.py in wrapper(*args, **kw)
838 '1 positional argument')
839
--> 840 return dispatch(args[0].__class__)(*args, **kw)
841
842 funcname = getattr(func, '__name__', 'singledispatch function')
/opt/conda/lib/python3.7/site-packages/aesara/link/jax/dispatch.py in jax_funcify_ScalarOp(op, **kwargs)
158
159 if "." in func_name:
--> 160 jnp_func = reduce(getattr, [jax] + func_name.split("."))
161 else:
162 jnp_func = getattr(jnp, func_name)
AttributeError: module 'jax.scipy.special' has no attribute 'erfcx'
Versions and main components
- PyMC/PyMC3 Version: 4.0.1
- Aesara/Theano Version: 2.7.2
- Python Version: 3.7.12
- Operating system: Linux
- How did you install PyMC/PyMC3: (conda/pip) conda