Skip to content

Use optional tensorflow implementation for missing JAX Ops #256

Closed
@jbh1128d1

Description

@jbh1128d1

I tried to sample a truncated normal distribution with the JAX sampler, which threw an error. This worked fine with the non-JAX sampler, but it was extremely slow. The error is as follows:

Please provide a minimal, self-contained, and reproducible example.

from scipy.stats import truncnorm 
numargs = truncnorm .numargs 
a, b = 0,10
import numpy as np 
quantile = np.arange (0.01, 1, 0.1) 

# Random Variates 
R = truncnorm .rvs(a, b, size = 1000) 
x = R*np.random.randn(1000)
with pm.Model() as example_model:
    
    b = pm.Normal('b')
    
    mu = b*x
    
    sigma = pm.HalfNormal('sigma')
    
    eaches = pm.TruncatedNormal('predicted_eaches',
                                    mu=mu,
                                    sigma=sigma,
                                    lower=0,
                                    observed=R)

    idata = pm.sampling_jax.sample_numpyro_nuts(draws = 500, tune=500, target_accept = .95)

Please provide the full traceback.

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_12814/714985784.py in <module>
     13                                     observed=R)
     14 
---> 15     idata = pm.sampling_jax.sample_numpyro_nuts(draws = 500, tune=500, target_accept = .95)

/opt/conda/lib/python3.7/site-packages/pymc/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
    481     )
    482 
--> 483     logp_fn = get_jaxified_logp(model, negative_logp=False)
    484 
    485     if nuts_kwargs is None:

/opt/conda/lib/python3.7/site-packages/pymc/sampling_jax.py in get_jaxified_logp(model, negative_logp)
    104     if not negative_logp:
    105         model_logp = -model_logp
--> 106     logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
    107 
    108     def logp_fn_wrap(x):

/opt/conda/lib/python3.7/site-packages/pymc/sampling_jax.py in get_jaxified_graph(inputs, outputs)
     97 
     98     # We now jaxify the optimized fgraph
---> 99     return jax_funcify(fgraph)
    100 
    101 

/opt/conda/lib/python3.7/functools.py in wrapper(*args, **kw)
    838                             '1 positional argument')
    839 
--> 840         return dispatch(args[0].__class__)(*args, **kw)
    841 
    842     funcname = getattr(func, '__name__', 'singledispatch function')

/opt/conda/lib/python3.7/site-packages/aesara/link/jax/dispatch.py in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
    682         type_conversion_fn=jax_typify,
    683         fgraph_name=fgraph_name,
--> 684         **kwargs,
    685     )
    686 

/opt/conda/lib/python3.7/site-packages/aesara/link/utils.py in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, input_storage, output_storage, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
    740     for node in order:
    741         compiled_func = op_conversion_fn(
--> 742             node.op, node=node, storage_map=storage_map, **kwargs
    743         )
    744 

/opt/conda/lib/python3.7/functools.py in wrapper(*args, **kw)
    838                             '1 positional argument')
    839 
--> 840         return dispatch(args[0].__class__)(*args, **kw)
    841 
    842     funcname = getattr(func, '__name__', 'singledispatch function')

/opt/conda/lib/python3.7/site-packages/aesara/link/jax/dispatch.py in jax_funcify_Elemwise(op, **kwargs)
    399 def jax_funcify_Elemwise(op, **kwargs):
    400     scalar_op = op.scalar_op
--> 401     return jax_funcify(scalar_op, **kwargs)
    402 
    403 

/opt/conda/lib/python3.7/functools.py in wrapper(*args, **kw)
    838                             '1 positional argument')
    839 
--> 840         return dispatch(args[0].__class__)(*args, **kw)
    841 
    842     funcname = getattr(func, '__name__', 'singledispatch function')

/opt/conda/lib/python3.7/site-packages/aesara/link/jax/dispatch.py in jax_funcify_Composite(op, vectorize, **kwargs)
    404 @jax_funcify.register(Composite)
    405 def jax_funcify_Composite(op, vectorize=True, **kwargs):
--> 406     jax_impl = jax_funcify(op.fgraph)
    407 
    408     def composite(*args):

/opt/conda/lib/python3.7/functools.py in wrapper(*args, **kw)
    838                             '1 positional argument')
    839 
--> 840         return dispatch(args[0].__class__)(*args, **kw)
    841 
    842     funcname = getattr(func, '__name__', 'singledispatch function')

/opt/conda/lib/python3.7/site-packages/aesara/link/jax/dispatch.py in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
    682         type_conversion_fn=jax_typify,
    683         fgraph_name=fgraph_name,
--> 684         **kwargs,
    685     )
    686 

/opt/conda/lib/python3.7/site-packages/aesara/link/utils.py in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, input_storage, output_storage, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
    740     for node in order:
    741         compiled_func = op_conversion_fn(
--> 742             node.op, node=node, storage_map=storage_map, **kwargs
    743         )
    744 

/opt/conda/lib/python3.7/functools.py in wrapper(*args, **kw)
    838                             '1 positional argument')
    839 
--> 840         return dispatch(args[0].__class__)(*args, **kw)
    841 
    842     funcname = getattr(func, '__name__', 'singledispatch function')

/opt/conda/lib/python3.7/site-packages/aesara/link/jax/dispatch.py in jax_funcify_ScalarOp(op, **kwargs)
    158 
    159     if "." in func_name:
--> 160         jnp_func = reduce(getattr, [jax] + func_name.split("."))
    161     else:
    162         jnp_func = getattr(jnp, func_name)

AttributeError: module 'jax.scipy.special' has no attribute 'erfcx'

Versions and main components

  • PyMC/PyMC3 Version: 4.0.1
  • Aesara/Theano Version: 2.7.2
  • Python Version: 3.7.12
  • Operating system: Linux
  • How did you install PyMC/PyMC3: (conda/pip) conda

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions