Use zarr to validate attrs when writing to zarr #6636

Merged
merged 6 commits into from
Jun 3, 2022
Changes from 1 commit
3 changes: 1 addition & 2 deletions xarray/backends/api.py
@@ -1556,9 +1556,8 @@ def to_zarr(
f"'w-', 'a' and 'r+', but mode={mode!r}"
)

-# validate Dataset keys, DataArray names, and attr keys/values
+# validate Dataset keys, DataArray names
_validate_dataset_names(dataset)
-_validate_attrs(dataset)

if region is not None:
_validate_region(dataset, region)
13 changes: 11 additions & 2 deletions xarray/backends/zarr.py
@@ -320,6 +320,15 @@ def _validate_existing_dims(var_name, new_var, existing_var, region, append_dim)
)


+def _put_attrs(zarr_obj, attrs):
+    for key, value in attrs.items():
+        try:
+            zarr_obj.attrs[key] = value
+        except TypeError as e:
+            raise TypeError(f"Invalid attr {key!r}: {value!r}. {e!s}") from e
+    return zarr_obj
Contributor

@rabernat rabernat Jun 1, 2022

We should probably avoid looping over individual attrs. Imagine we have 100s of attributes (not unheard of for some NetCDF datasets) and are using a high-latency store like HTTP; then we would be doing 100s of sequential PUT operations. This could take a really long time. Much better to leverage the .update method on zarr_obj.attrs.

At that point, I would be curious what happens when Zarr gets invalid attrs. Do we even need to catch the error in Xarray? Or could we just let Zarr raise it and be done? In that case, this whole method would become:

Suggested change
-    for key, value in attrs.items():
-        try:
-            zarr_obj.attrs[key] = value
-        except TypeError as e:
-            raise TypeError(f"Invalid attr {key!r}: {value!r}. {e!s}") from e
-    return zarr_obj
+    zarr_obj.attrs.update(attrs)

...and at that point do we even need a function for it?
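
To make the latency argument concrete, here is a minimal sketch that counts store writes for the two approaches. `CountingStore` and `JsonAttrs` are hypothetical stand-ins written for this illustration, assuming an attribute container that rewrites a single JSON document on every mutation; they are not zarr's actual classes.

```python
import json


class CountingStore(dict):
    """Hypothetical in-memory store that counts writes (stand-in for a remote store)."""

    def __init__(self):
        super().__init__()
        self.puts = 0

    def __setitem__(self, key, value):
        # Every assignment models one PUT request against the store.
        self.puts += 1
        super().__setitem__(key, value)


class JsonAttrs:
    """Hypothetical attribute container backed by one JSON document in the store."""

    def __init__(self, store, key=".zattrs"):
        self.store = store
        self.key = key

    def _load(self):
        raw = self.store.get(self.key)
        return json.loads(raw) if raw is not None else {}

    def _save(self, data):
        self.store[self.key] = json.dumps(data)

    def __setitem__(self, name, value):
        # Read-modify-write: one store write per attribute.
        data = self._load()
        data[name] = value
        self._save(data)

    def update(self, new_attrs):
        # Read-modify-write: one store write for the whole batch.
        data = self._load()
        data.update(new_attrs)
        self._save(data)


attrs = {f"attr_{i}": i for i in range(200)}

loop_store = CountingStore()
loop_attrs = JsonAttrs(loop_store)
for name, value in attrs.items():
    loop_attrs[name] = value

bulk_store = CountingStore()
JsonAttrs(bulk_store).update(attrs)

print(loop_store.puts, bulk_store.puts)  # 200 1
```

Under these assumptions both stores end up holding the same document, but the loop issues one write per attribute while `update` issues a single write.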

Contributor Author

@malmans2 malmans2 Jun 1, 2022

Yes - I started with an error as close as possible to the one that was already implemented, but I also think that it's better to avoid the loop.

For example, if I assign a DataArray to an attribute, I get the following error from zarr:

TypeError: Object of type DataArray is not JSON serializable

I think it's OK if xarray doesn't explicitly say which attributes generate the problem (i.e., remove the loop), but maybe xarray should say that the issue is in the attributes (i.e., keep the try/except to raise a more informative error).

What do you think?
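
The error quoted above has the shape of the message that Python's standard `json` encoder produces for a value it cannot encode. The sketch below reproduces it with the standard library only; `NotJson` is a hypothetical stand-in for a `DataArray`, and the example does not go through zarr or xarray.

```python
import json


class NotJson:
    """Hypothetical stand-in for a DataArray: json has no encoder for it."""


attrs = {"good": {"key": "value"}, "bad": NotJson()}

try:
    json.dumps(attrs)
except TypeError as err:
    message = str(err)

print(message)  # Object of type NotJson is not JSON serializable
```

The message names the offending type but not the attribute key, which is why a reader of the bare error cannot tell that the problem lies in the attributes.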

Contributor

I think the most careful and helpful thing for us to do would be to catch the error, and reraise it with a little more context, which is exactly what your latest commits now do. 😊
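
A minimal sketch of that catch-and-reraise pattern combined with a single bulk write is shown below. The helper name `put_attrs_with_context`, the `JsonBackedAttrs` stand-in, and the error wording are all assumptions made for illustration; this is not the code that was merged.

```python
import json


class JsonBackedAttrs:
    """Hypothetical stand-in for a JSON-backed attribute mapping."""

    def __init__(self):
        self.document = "{}"

    def update(self, new_attrs):
        merged = {**json.loads(self.document), **new_attrs}
        # json.dumps raises TypeError for values it cannot encode,
        # in which case the stored document is left unchanged.
        self.document = json.dumps(merged)


def put_attrs_with_context(attrs_obj, attrs):
    """Write all attrs in one call; on failure, say the attributes are at fault."""
    try:
        attrs_obj.update(attrs)
    except TypeError as err:
        # Chain the original exception so its message and traceback are kept.
        raise TypeError(f"Could not write attributes: {err}") from err
    return attrs_obj


target = JsonBackedAttrs()
put_attrs_with_context(target, {"good": {"key": "value"}})

try:
    put_attrs_with_context(target, {"bad": object()})
except TypeError as err:
    reraised = err

print(reraised)
```

The reraised error tells the user the failure came from the attributes, while `from err` preserves the underlying encoder error as `__cause__`.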



class ZarrStore(AbstractWritableDataStore):
"""Store for reading and writing data via zarr"""

@@ -479,7 +488,7 @@ def set_dimensions(self, variables, unlimited_dims=None):
)

def set_attributes(self, attributes):
-self.zarr_group.attrs.put(attributes)
+_put_attrs(self.zarr_group, attributes)

def encode_variable(self, variable):
variable = encode_zarr_variable(variable)
@@ -618,7 +627,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
zarr_array = self.zarr_group.create(
name, shape=shape, dtype=dtype, fill_value=fill_value, **encoding
)
-zarr_array.attrs.put(encoded_attrs)
+zarr_array = _put_attrs(zarr_array, encoded_attrs)

write_region = self._write_region if self._write_region is not None else {}
write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}
16 changes: 16 additions & 0 deletions xarray/tests/test_backends.py
@@ -2443,6 +2443,22 @@ def test_write_read_select_write(self):
with self.create_zarr_target() as final_store:
ds_sel.to_zarr(final_store, mode="w")

+    @pytest.mark.parametrize("obj", [Dataset(), DataArray(name="foo")])
+    def test_attributes(self, obj):
+        obj = obj.copy()
+
+        obj.attrs["good"] = {"key": "value"}
+        ds = obj if isinstance(obj, Dataset) else obj.to_dataset()
+        with self.create_zarr_target() as store_target:
+            ds.to_zarr(store_target)
+            assert_identical(ds, xr.open_zarr(store_target))
+
+        obj.attrs["bad"] = DataArray()
+        ds = obj if isinstance(obj, Dataset) else obj.to_dataset()
+        with self.create_zarr_target() as store_target:
+            with pytest.raises(TypeError, match=r"Invalid attr 'bad'"):
+                ds.to_zarr(store_target)


@requires_zarr
class TestZarrDictStore(ZarrBase):