diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 242507f9c20..3fc2cd40ebe 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -1218,8 +1218,53 @@ def open_datatree( zarr_version=None, **kwargs, ) -> DataTree: - from xarray.backends.api import open_dataset + from xarray.core.datatree import DataTree + + groups_dicts = self.open_groups_as_dict( + filename_or_obj=filename_or_obj, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + group=group, + mode=mode, + synchronizer=synchronizer, + consolidated=consolidated, + chunk_store=chunk_store, + storage_options=storage_options, + stacklevel=stacklevel, + zarr_version=zarr_version, + **kwargs, + ) + + return DataTree.from_dict(groups_dicts) + + def open_groups_as_dict( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + drop_variables: str | Iterable[str] | None = None, + use_cftime=None, + decode_timedelta=None, + group: str | Iterable[str] | Callable | None = None, + mode="r", + synchronizer=None, + consolidated=None, + chunk_store=None, + storage_options=None, + stacklevel=3, + zarr_version=None, + **kwargs, + ) -> dict[str, Dataset]: + from xarray.backends.api import open_dataset from xarray.core.treenode import NodePath filename_or_obj = _normalize_path(filename_or_obj) @@ -1241,7 +1286,7 @@ def open_datatree( ds = open_dataset( filename_or_obj, group=parent, engine="zarr", **kwargs ) - return DataTree.from_dict({str(parent): ds}) + return {str(parent): ds} else: parent = NodePath("/") stores = ZarrStore.open_store( @@ -1256,26 +1301,44 @@ def open_datatree( stacklevel=stacklevel + 1, zarr_version=zarr_version, ) - ds = open_dataset(filename_or_obj, group=parent, engine="zarr", **kwargs) - tree_root = DataTree.from_dict({str(parent): ds}) - for path_group, store in stores.items(): - ds = open_dataset( - filename_or_obj, store=store, group=path_group, engine="zarr", **kwargs - ) - new_node: DataTree = DataTree(name=NodePath(path_group).name, data=ds) - tree_root._set_item( - path_group, - new_node, - allow_overwrite=False, - new_nodes_along_path=True, + groups_dict = {} + for path_group, group_store in stores.items(): + group_ds = open_dataset( + filename_or_obj, + store=group_store, + group=path_group, + engine="zarr", + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + **kwargs, ) - return tree_root + # store_entrypoint = StoreBackendEntrypoint() # | + # with close_on_error(group_store): # | + # group_ds = store_entrypoint.open_dataset( # | + # group_store, # | + # mask_and_scale=mask_and_scale, # | + # decode_times=decode_times, # | + # concat_characters=concat_characters, # |---> This is slower than using `open_dataset` + # decode_coords=decode_coords, # | + # drop_variables=drop_variables, # | + # use_cftime=use_cftime, # | + # decode_timedelta=decode_timedelta, # | + # ) # | + group_name = str(NodePath(path_group)) # | + groups_dict[group_name] = group_ds + return groups_dict def _iter_zarr_groups(root, parent="/"): from xarray.core.treenode import NodePath parent = NodePath(parent) + yield parent for path, group in root.groups(): gpath = parent / path yield str(gpath)