Remove workaround in InferenceData conversion; require ArviZ >=0.11.4 #5060

Merged: 2 commits, Oct 11, 2021
2 changes: 1 addition & 1 deletion conda-envs/environment-dev-py37.yml
@@ -4,7 +4,7 @@ channels:
 - defaults
 dependencies:
 - aesara>=2.1.0
-- arviz>=0.11.2
+- arviz>=0.11.4
 - cachetools>=4.2.1
 - cloudpickle
 - fastprogress>=0.2.0
2 changes: 1 addition & 1 deletion conda-envs/environment-dev-py38.yml
@@ -4,7 +4,7 @@ channels:
 - defaults
 dependencies:
 - aesara>=2.1.0
-- arviz>=0.11.2
+- arviz>=0.11.4
 - cachetools>=4.2.1
 - cloudpickle
 - fastprogress>=0.2.0
2 changes: 1 addition & 1 deletion conda-envs/environment-dev-py39.yml
@@ -4,7 +4,7 @@ channels:
 - defaults
 dependencies:
 - aesara>=2.1.0
-- arviz>=0.11.2
+- arviz>=0.11.4
 - cachetools>=4.2.1
 - cloudpickle
 - fastprogress>=0.2.0
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev-py38.yml
@@ -5,7 +5,7 @@ channels:
 dependencies:
 # base dependencies (see install guide for Windows)
 - aesara>=2.1.0
-- arviz>=0.11.2
+- arviz>=0.11.4
 - cachetools>=4.2.1
 - cloudpickle
 - fastprogress>=0.2.0
57 changes: 2 additions & 55 deletions pymc/backends/arviz.py
@@ -7,8 +7,6 @@
     Any,
     Dict,
     Iterable,
-    List,
-    Mapping,
     Optional,
     Tuple,
     Union,
@@ -21,9 +19,7 @@
 from aesara.tensor.sharedvar import SharedVariable
 from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
 from arviz import InferenceData, concat, rcParams
-from arviz.data.base import CoordSpec, DimSpec
-from arviz.data.base import dict_to_dataset as _dict_to_dataset
-from arviz.data.base import generate_dims_coords, make_attrs, requires
+from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires

 import pymc

@@ -101,42 +97,6 @@ def insert(self, k: str, v, idx: int):
         self.trace_dict[k][idx, :] = v


-def dict_to_dataset(
-    data,
-    library=None,
-    coords=None,
-    dims=None,
-    attrs=None,
-    default_dims=None,
-    skip_event_dims=None,
-    index_origin=None,
-):
-    """Temporal workaround for dict_to_dataset.
-
-    Once ArviZ>0.11.2 release is available, only two changes are needed for everything to work.
-    1) this should be deleted, 2) dict_to_dataset should be imported as is from arviz, no underscore,
-    also remove unnecessary imports
-    """
-    if default_dims is None:
-        return _dict_to_dataset(
-            data,
-            attrs=attrs,
-            library=library,
-            coords=coords,
-            dims=dims,
-            skip_event_dims=skip_event_dims,
-        )
-    else:
-        out_data = {}
-        for name, vals in data.items():
-            vals = np.atleast_1d(vals)
-            val_dims = dims.get(name)
-            val_dims, crds = generate_dims_coords(vals.shape, name, dims=val_dims, coords=coords)
-            crds = {key: xr.IndexVariable((key,), data=crds[key]) for key in val_dims}
-            out_data[name] = xr.DataArray(vals, dims=val_dims, coords=crds)
-        return xr.Dataset(data_vars=out_data, attrs=make_attrs(attrs=attrs, library=library))
-
-
 class InferenceDataConverter: # pylint: disable=too-many-instance-attributes
     """Encapsulate InferenceData specific logic."""

@@ -160,7 +120,6 @@ def __init__(
         model=None,
         save_warmup: Optional[bool] = None,
         density_dist_obs: bool = True,
-        index_origin: Optional[int] = None,
     ):

         self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup

Review comment (Member), on the density_dist_obs argument: doesn't need to be in this PR, but we can remove this argument. The DensityDist API will be fixed in v4, so this is no longer necessary.

@@ -196,7 +155,6 @@ def __init__(
         self.posterior_predictive = posterior_predictive
         self.log_likelihood = log_likelihood
         self.predictions = predictions
-        self.index_origin = rcParams["data.index_origin"] if index_origin is None else index_origin

         def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
             return next(iter(dct.values()))
@@ -344,15 +302,13 @@ def posterior_to_xarray(self):
                 coords=self.coords,
                 dims=self.dims,
                 attrs=self.attrs,
-                index_origin=self.index_origin,
             ),
             dict_to_dataset(
                 data_warmup,
                 library=pymc,
                 coords=self.coords,
                 dims=self.dims,
                 attrs=self.attrs,
-                index_origin=self.index_origin,
             ),
         )
@@ -386,15 +342,13 @@ def sample_stats_to_xarray(self):
                 dims=None,
                 coords=self.coords,
                 attrs=self.attrs,
-                index_origin=self.index_origin,
             ),
             dict_to_dataset(
                 data_warmup,
                 library=pymc,
                 dims=None,
                 coords=self.coords,
                 attrs=self.attrs,
-                index_origin=self.index_origin,
             ),
         )
@@ -427,15 +381,13 @@ def log_likelihood_to_xarray(self):
                 dims=self.dims,
                 coords=self.coords,
                 skip_event_dims=True,
-                index_origin=self.index_origin,
             ),
             dict_to_dataset(
                 data_warmup,
                 library=pymc,
                 dims=self.dims,
                 coords=self.coords,
                 skip_event_dims=True,
-                index_origin=self.index_origin,
             ),
         )
@@ -456,9 +408,7 @@ def translate_posterior_predictive_dict_to_xarray(self, dct) -> xr.Dataset:
                 "This can mean that some draws or even whole chains are not represented.",
                 k,
             )
-        return dict_to_dataset(
-            data, library=pymc, coords=self.coords, dims=self.dims, index_origin=self.index_origin
-        )
+        return dict_to_dataset(data, library=pymc, coords=self.coords, dims=self.dims)

     @requires(["posterior_predictive"])
     def posterior_predictive_to_xarray(self):
@@ -493,7 +443,6 @@ def priors_to_xarray(self):
                     library=pymc,
                     coords=self.coords,
                     dims=self.dims,
-                    index_origin=self.index_origin,
                 )
             )
         return priors_dict
@@ -510,7 +459,6 @@ def observed_data_to_xarray(self):
             coords=self.coords,
             dims=self.dims,
             default_dims=[],
-            index_origin=self.index_origin,
         )

     @requires(["trace", "predictions"])
@@ -557,7 +505,6 @@ def is_data(name, var) -> bool:
             coords=self.coords,
             dims=self.dims,
             default_dims=[],
-            index_origin=self.index_origin,
         )

     def to_inference_data(self):
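
For orientation (not part of the diff): with ArviZ >= 0.11.4, dict_to_dataset is imported straight from arviz.data.base, understands default_dims, and falls back to rcParams["data.index_origin"] on its own, which is what makes the deleted wrapper and the index_origin plumbing above unnecessary. A minimal sketch, assuming pymc and ArviZ >= 0.11.4 are installed; the toy variable name is illustrative:

import numpy as np

import pymc
from arviz.data.base import dict_to_dataset

# Toy posterior samples shaped (chain, draw): 2 chains x 5 draws of "mu".
samples = {"mu": np.random.default_rng(0).normal(size=(2, 5))}

# No local wrapper and no explicit index_origin needed: ArviZ adds the
# chain/draw dims and reads rcParams["data.index_origin"] internally.
# Passing default_dims=[] instead would skip the chain/draw dims, as the
# converter does for observed_data and constant_data.
ds = dict_to_dataset(samples, library=pymc, coords=None, dims=None)
print(ds)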
8 changes: 4 additions & 4 deletions pymc/tests/test_distributions.py
@@ -1114,10 +1114,6 @@ def test_wald_logp(self):
             decimal=select_by_precision(float64=6, float32=1),
         )

-    @pytest.mark.xfail(
-        condition=(aesara.config.floatX == "float32"),
-        reason="Poor CDF in SciPy. See scipy/scipy#869 for details.",
-    )
     def test_wald_logcdf(self):
         self.check_logcdf(
             Wald,
@@ -1273,6 +1269,10 @@ def modified_scipy_hypergeom_logcdf(value, N, k, n):
             {"N": NatSmall, "k": NatSmall, "n": NatSmall},
         )

+    @pytest.mark.xfail(
+        condition=(aesara.config.floatX == "float32"),
+        reason="SciPy log CDF stopped working after un-pinning NumPy version.",
+    )
     def test_negative_binomial(self):
         def scipy_mu_alpha_logpmf(value, mu, alpha):
             return sp.nbinom.logpmf(value, alpha, 1 - mu / (mu + alpha))
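
A side note on the pattern in these two hunks: pytest.mark.xfail(condition=..., reason=...) only expects a failure when the condition evaluates true at collection time; otherwise the test runs normally. A self-contained sketch with a hypothetical test body (not from pymc):

import aesara
import numpy as np
import pytest


@pytest.mark.xfail(
    condition=(aesara.config.floatX == "float32"),
    reason="1 + 1e-8 rounds to exactly 1.0 in float32 (illustrative).",
)
def test_float32_roundoff():
    # Passes under float64; under float32 the tiny addition is absorbed by
    # rounding, the assertion fails, and pytest reports XFAIL instead of FAILED.
    x = np.asarray(1.0, dtype=aesara.config.floatX) + 1e-8
    assert x > 1.0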
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -2,7 +2,7 @@
 # See that file for comments about the need/usage of each dependency.

 aesara>=2.1.0
-arviz>=0.11.2
+arviz>=0.11.4
 cachetools>=4.2.1
 cloudpickle
 fastprogress>=0.2.0
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
 aesara>=2.1.0
-arviz>=0.11.2
+arviz>=0.11.4
 cachetools>=4.2.1
 cloudpickle
 fastprogress>=0.2.0
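
To check that an existing environment already satisfies the new floor, a quick sketch (not part of the PR; assumes the packaging distribution is installed):

import arviz
from packaging import version

# The converter code merged here relies on dict_to_dataset behavior that
# first shipped in ArviZ 0.11.4, hence the bumped pin across all env files.
installed = version.parse(arviz.__version__)
if installed < version.parse("0.11.4"):
    raise RuntimeError(f"ArviZ {installed} is too old; pymc now requires >= 0.11.4")
print(f"OK: ArviZ {installed} satisfies >= 0.11.4")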