diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 1aed0b6c4e..4c710f0787 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -131,9 +131,9 @@ def dict_to_dataset( for name, vals in data.items(): vals = np.atleast_1d(vals) val_dims = dims.get(name) - val_dims, coords = generate_dims_coords(vals.shape, name, dims=val_dims, coords=coords) - coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in val_dims} - out_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords) + val_dims, crds = generate_dims_coords(vals.shape, name, dims=val_dims, coords=coords) + crds = {key: xr.IndexVariable((key,), data=crds[key]) for key in val_dims} + out_data[name] = xr.DataArray(vals, dims=val_dims, coords=crds) return xr.Dataset(data_vars=out_data, attrs=make_attrs(attrs=attrs, library=library)) diff --git a/pymc/tests/test_idata_conversion.py b/pymc/tests/test_idata_conversion.py index 93b38c0ccf..06c520ec19 100644 --- a/pymc/tests/test_idata_conversion.py +++ b/pymc/tests/test_idata_conversion.py @@ -568,6 +568,36 @@ def test_multivariate_observations(self): assert "direction" not in idata.log_likelihood.dims assert "direction" in idata.observed_data.dims + def test_constant_data_coords_issue_5046(self): + """This is a regression test against a bug where a local coords variable was overwritten.""" + dims = { + "alpha": ["backwards"], + "bravo": ["letters", "yesno"], + } + coords = { + "backwards": np.arange(17)[::-1], + "letters": list("ABCDEFGHIJK"), + "yesno": ["yes", "no"], + } + data = { + name: np.random.uniform(size=[len(coords[dn]) for dn in dnames]) + for name, dnames in dims.items() + } + + for k in data: + assert len(data[k].shape) == len(dims[k]) + + ds = pm.backends.arviz.dict_to_dataset( + data=data, + library=pm, + coords=coords, + dims=dims, + default_dims=[], + index_origin=0, + ) + for dname, cvals in coords.items(): + np.testing.assert_array_equal(ds[dname].values, cvals) + class TestPyMCWarmupHandling: @pytest.mark.parametrize("save_warmup", [False, True])