Skip to content

Commit c06c2f4

Browse files
Fix coords in constant_data issue #5046 (#5062)
* Add test case for constant_data coords issue #5046 * Don't replace local coords variable inside iterator Closes #5046
1 parent a3cc81c commit c06c2f4

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

pymc/backends/arviz.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,9 @@ def dict_to_dataset(
131131
for name, vals in data.items():
132132
vals = np.atleast_1d(vals)
133133
val_dims = dims.get(name)
134-
val_dims, coords = generate_dims_coords(vals.shape, name, dims=val_dims, coords=coords)
135-
coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in val_dims}
136-
out_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords)
134+
val_dims, crds = generate_dims_coords(vals.shape, name, dims=val_dims, coords=coords)
135+
crds = {key: xr.IndexVariable((key,), data=crds[key]) for key in val_dims}
136+
out_data[name] = xr.DataArray(vals, dims=val_dims, coords=crds)
137137
return xr.Dataset(data_vars=out_data, attrs=make_attrs(attrs=attrs, library=library))
138138

139139

pymc/tests/test_idata_conversion.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,36 @@ def test_multivariate_observations(self):
568568
assert "direction" not in idata.log_likelihood.dims
569569
assert "direction" in idata.observed_data.dims
570570

571+
def test_constant_data_coords_issue_5046(self):
572+
"""This is a regression test against a bug where a local coords variable was overwritten."""
573+
dims = {
574+
"alpha": ["backwards"],
575+
"bravo": ["letters", "yesno"],
576+
}
577+
coords = {
578+
"backwards": np.arange(17)[::-1],
579+
"letters": list("ABCDEFGHIJK"),
580+
"yesno": ["yes", "no"],
581+
}
582+
data = {
583+
name: np.random.uniform(size=[len(coords[dn]) for dn in dnames])
584+
for name, dnames in dims.items()
585+
}
586+
587+
for k in data:
588+
assert len(data[k].shape) == len(dims[k])
589+
590+
ds = pm.backends.arviz.dict_to_dataset(
591+
data=data,
592+
library=pm,
593+
coords=coords,
594+
dims=dims,
595+
default_dims=[],
596+
index_origin=0,
597+
)
598+
for dname, cvals in coords.items():
599+
np.testing.assert_array_equal(ds[dname].values, cvals)
600+
571601

572602
class TestPyMCWarmupHandling:
573603
@pytest.mark.parametrize("save_warmup", [False, True])

0 commit comments

Comments
 (0)