Skip to content

Allow .attrs to support any dict-likes #5667

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 58 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
a89b766
add a dict-like checker
Illviljan Aug 3, 2021
8912de4
Update dataset.py
Illviljan Aug 3, 2021
70c39e0
Update merge.py
Illviljan Aug 3, 2021
9f46ea9
Update variable.py
Illviljan Aug 3, 2021
cf11be7
Update dataset.py
Illviljan Aug 3, 2021
4b5b519
Update utils.py
Illviljan Aug 3, 2021
043e7b6
Update utils.py
Illviljan Aug 3, 2021
1bb3327
Use a shallow .copy() like dict(result) did.
Illviljan Aug 3, 2021
e4fc989
use copy() instead
Illviljan Aug 4, 2021
de21491
Update utils.py
Illviljan Aug 4, 2021
68ac983
Update utils.py
Illviljan Aug 4, 2021
a4a1c3d
Initialize ds attrs the same way as in Variable
Illviljan Aug 5, 2021
93510c5
change typing
Illviljan Aug 5, 2021
0369166
more typing
Illviljan Aug 5, 2021
40891c6
Update utils.py
Illviljan Aug 6, 2021
475cb0c
shallow copy here just like with the dict
Illviljan Aug 6, 2021
100731b
Update utils.py
Illviljan Aug 6, 2021
252faf4
Update utils.py
Illviljan Aug 6, 2021
f8fb503
Update conventions.py
Illviljan Aug 7, 2021
6021e12
Update utils.py
Illviljan Aug 7, 2021
ccae97f
Update dataarray.py
Illviljan Aug 7, 2021
9391de6
Update merge.py
Illviljan Aug 7, 2021
bbf4916
typing
Illviljan Aug 7, 2021
1164f3d
Update dataset.py
Illviljan Aug 7, 2021
c5a842e
Update dataset.py
Illviljan Aug 7, 2021
badc69a
Update dataset.py
Illviljan Aug 7, 2021
a721f4d
Update dataarray.py
Illviljan Aug 7, 2021
2515a8e
Update dataarray.py
Illviljan Aug 7, 2021
653327c
Update dataset.py
Illviljan Aug 7, 2021
3ad4193
Update dataset.py
Illviljan Aug 7, 2021
a088b16
try coercing to dict
Illviljan Aug 7, 2021
a798f8e
Update dataset.py
Illviljan Aug 7, 2021
3391ac3
Update conventions.py
Illviljan Aug 7, 2021
68459de
issubclass seems slightly faster
Illviljan Aug 14, 2021
894b6ea
Update utils.py
Illviljan Aug 14, 2021
6f5b372
Try out TypeGuard to avoid the slow isinstance check
Illviljan Aug 20, 2021
e479c5c
Update utils.py
Illviljan Aug 20, 2021
2fb7f75
Merge branch 'main' into Illviljan-attrs_supports_dict_like
Illviljan Aug 20, 2021
41fd3b9
Update utils.py
Illviljan Aug 20, 2021
b10d098
Merge branch 'main' into Illviljan-attrs_supports_dict_like
Illviljan Aug 22, 2021
08887c1
Merge branch 'main' into Illviljan-attrs_supports_dict_like
Illviljan Sep 1, 2021
c728006
add typing in variable
Illviljan Sep 1, 2021
0558996
lint
Illviljan Sep 1, 2021
320debb
hashable -> any
Illviljan Sep 1, 2021
596c661
add typing to indexvariable
Illviljan Sep 1, 2021
4e5b262
Update variable.py
Illviljan Sep 1, 2021
26e4b66
Merge branch 'main' into Illviljan-attrs_supports_dict_like
Illviljan Oct 31, 2021
84b494f
Update utils.py
Illviljan Oct 31, 2021
8c09d4a
TypeGuard not in typing_extensions 3.7
Illviljan Oct 31, 2021
9758348
Update utils.py
Illviljan Oct 31, 2021
2192835
Update utils.py
Illviljan Oct 31, 2021
b995c4a
Update utils.py
Illviljan Oct 31, 2021
650ce22
Update utils.py
Illviljan Oct 31, 2021
1f0d41b
typing_extensions now a required package
Illviljan Nov 14, 2021
d46ba6f
Merge branch 'main' into Illviljan-attrs_supports_dict_like
Illviljan Nov 14, 2021
3661d43
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 14, 2021
a494fd7
TypeGuard still not in typing_extensions 3.7
Illviljan Nov 14, 2021
0e6036f
Update utils.py
Illviljan Nov 14, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions xarray/conventions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import warnings
from collections import defaultdict
from copy import copy

import numpy as np
import pandas as pd
Expand All @@ -9,6 +10,7 @@
from .core import duck_array_ops, indexing
from .core.common import contains_cftime_datetimes
from .core.pycompat import is_duck_dask_array
from .core.utils import maybe_coerce_to_dict
from .core.variable import IndexVariable, Variable, as_variable

CF_RELATED_DATA = (
Expand Down Expand Up @@ -95,7 +97,7 @@ def __getitem__(self, key):


def _var_as_tuple(var):
return var.dims, var.data, var.attrs.copy(), var.encoding.copy()
return var.dims, var.data, copy(var.attrs), var.encoding.copy()


def maybe_encode_nonstring_dtype(var, name=None):
Expand Down Expand Up @@ -562,7 +564,7 @@ def stackable(dim):
del var_attrs[attr_name]

if decode_coords and "coordinates" in attributes:
attributes = dict(attributes)
attributes = maybe_coerce_to_dict(attributes)
coord_names.update(attributes.pop("coordinates").split())

return new_vars, attributes, coord_names
Expand Down Expand Up @@ -786,7 +788,7 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names):
# http://mailman.cgd.ucar.edu/pipermail/cf-metadata/2014/007571.html
global_coordinates.difference_update(written_coords)
if global_coordinates:
attributes = dict(attributes)
attributes = copy(attributes)
if "coordinates" in attributes:
warnings.warn(
f"cannot serialize global coordinates {global_coordinates!r} because the global "
Expand Down
8 changes: 5 additions & 3 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import datetime
import warnings
from copy import copy
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -11,6 +12,7 @@
Iterable,
List,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
Expand Down Expand Up @@ -575,7 +577,7 @@ def to_dataset(
result = self._to_dataset_whole(name)

if promote_attrs:
result.attrs = dict(self.attrs)
result.attrs = copy(self.attrs)

return result

Expand Down Expand Up @@ -788,9 +790,9 @@ def loc(self) -> _LocIndexer:
"""Attribute for location based indexing like pandas."""
return _LocIndexer(self)

@property
# Key type needs to be `Any` because of mypy#4167
def attrs(self) -> Dict[Any, Any]:
@property
def attrs(self) -> MutableMapping[Any, Any]:
"""Dictionary storing arbitrary metadata with this array."""
return self.variable.attrs

Expand Down
23 changes: 13 additions & 10 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
infix_dims,
is_dict_like,
is_scalar,
maybe_coerce_to_dict,
maybe_wrap_array,
)
from .variable import (
Expand Down Expand Up @@ -697,7 +698,7 @@ class Dataset(DataWithCoords, DatasetArithmetic, Mapping):
description: Weather related data.
"""

_attrs: Optional[Dict[Hashable, Any]]
_attrs: Optional[MutableMapping[Any, Any]]
_cache: Dict[str, Any]
_coord_names: Set[Hashable]
_dims: Dict[Hashable, int]
Expand Down Expand Up @@ -752,7 +753,9 @@ def __init__(
data_vars, coords, compat="broadcast_equals"
)

self._attrs = dict(attrs) if attrs is not None else None
self._attrs = None
if attrs is not None:
self.attrs = attrs # type: ignore[assignment] # https://github.com/python/mypy/issues/3004
self._close = None
self._encoding = None
self._variables = variables
Expand Down Expand Up @@ -784,15 +787,15 @@ def variables(self) -> Mapping[Hashable, Variable]:
return Frozen(self._variables)

@property
def attrs(self) -> Dict[Hashable, Any]:
def attrs(self) -> MutableMapping[Any, Any]:
"""Dictionary of global attributes on this dataset"""
if self._attrs is None:
self._attrs = {}
return self._attrs

@attrs.setter
def attrs(self, value: Mapping[Any, Any]) -> None:
self._attrs = dict(value)
self._attrs = maybe_coerce_to_dict(value)

@property
def encoding(self) -> Dict:
Expand Down Expand Up @@ -1096,8 +1099,8 @@ def _replace(
variables: Dict[Hashable, Variable] = None,
coord_names: Set[Hashable] = None,
dims: Dict[Any, int] = None,
attrs: Union[Dict[Hashable, Any], None, Default] = _default,
indexes: Union[Dict[Hashable, Index], None, Default] = _default,
attrs: Union[MutableMapping[Any, Any], None, Default] = _default,
indexes: Union[Dict[Any, Index], None, Default] = _default,
encoding: Union[dict, None, Default] = _default,
inplace: bool = False,
) -> "Dataset":
Expand Down Expand Up @@ -1145,7 +1148,7 @@ def _replace_with_new_dims(
self,
variables: Dict[Hashable, Variable],
coord_names: set = None,
attrs: Union[Dict[Hashable, Any], None, Default] = _default,
attrs: Union[MutableMapping[Any, Any], None, Default] = _default,
indexes: Union[Dict[Hashable, Index], None, Default] = _default,
inplace: bool = False,
) -> "Dataset":
Expand All @@ -1160,7 +1163,7 @@ def _replace_vars_and_dims(
variables: Dict[Hashable, Variable],
coord_names: set = None,
dims: Dict[Hashable, int] = None,
attrs: Union[Dict[Hashable, Any], None, Default] = _default,
attrs: Union[MutableMapping[Any, Any], None, Default] = _default,
inplace: bool = False,
) -> "Dataset":
"""Deprecated version of _replace_with_new_dims().
Expand Down Expand Up @@ -6996,7 +6999,7 @@ def polyfit(
covariance = xr.DataArray(Vbase, dims=("cov_i", "cov_j")) * fac
variables[name + "polyfit_covariance"] = covariance

return Dataset(data_vars=variables, attrs=self.attrs.copy())
return Dataset(data_vars=variables, attrs=copy.copy(self.attrs))

def pad(
self,
Expand Down Expand Up @@ -7726,6 +7729,6 @@ def _wrapper(Y, *coords_, **kwargs):
result = result.assign_coords(
{"param": params, "cov_i": params, "cov_j": params}
)
result.attrs = self.attrs.copy()
result.attrs = copy.copy(self.attrs)

return result
7 changes: 4 additions & 3 deletions xarray/core/merge.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from copy import copy
from typing import (
TYPE_CHECKING,
AbstractSet,
Expand Down Expand Up @@ -524,9 +525,9 @@ def merge_attrs(variable_attrs, combine_attrs, context=None):
elif combine_attrs == "drop":
return {}
elif combine_attrs == "override":
return dict(variable_attrs[0])
return copy(variable_attrs[0])
elif combine_attrs == "no_conflicts":
result = dict(variable_attrs[0])
result = copy(variable_attrs[0])
for attrs in variable_attrs[1:]:
try:
result = compat_dict_union(result, attrs)
Expand Down Expand Up @@ -555,7 +556,7 @@ def merge_attrs(variable_attrs, combine_attrs, context=None):
dropped_keys |= {key for key in attrs if key not in result}
return result
elif combine_attrs == "identical":
result = dict(variable_attrs[0])
result = copy(variable_attrs[0])
for attrs in variable_attrs[1:]:
if not dict_equiv(result, attrs):
raise MergeError(
Expand Down
44 changes: 42 additions & 2 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import re
import sys
import warnings
from copy import copy
from enum import Enum
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -95,6 +96,45 @@ def maybe_coerce_to_str(index, original_coords):
return index


# See GH5624, this is a convoluted way to allow type-checking to use `TypeGuard` without
# requiring typing_extensions as a required dependency to _run_ the code (it is required
# to type-check).
try:
if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard
except (ImportError, NameError) as e:
if TYPE_CHECKING:
raise e
else:

def _is_MutableMapping(obj: Mapping[Any, Any]) -> bool:
"""Check if the object is a mutable mapping."""
return hasattr(obj, "__setitem__")


else:

def _is_MutableMapping(
obj: Mapping[Any, Any]
) -> TypeGuard[MutableMapping[Any, Any]]:
"""Check if the object is a mutable mapping."""
return hasattr(obj, "__setitem__")


def maybe_coerce_to_dict(obj: Mapping[Any, Any]) -> MutableMapping[Any, Any]:
"""Convert to dict if the object is not a valid dict-like."""
# if isinstance(obj, MutableMapping):
if _is_MutableMapping(obj):
# if hasattr(obj, "update"):
# return obj.copy()
return copy(obj)
# return obj
else:
return dict(obj)


def safe_cast_to_index(array: Any) -> pd.Index:
"""Given an array, safely cast it to a pandas.Index.

Expand Down Expand Up @@ -417,7 +457,7 @@ def compat_dict_intersection(


def compat_dict_union(
first_dict: Mapping[K, V],
first_dict: MutableMapping[K, V],
second_dict: Mapping[K, V],
compat: Callable[[V, V], bool] = equivalent,
) -> MutableMapping[K, V]:
Expand All @@ -439,7 +479,7 @@ def compat_dict_union(
union : dict
union of the contents.
"""
new_dict = dict(first_dict)
new_dict = copy(first_dict)
update_safety_check(first_dict, second_dict, compat)
new_dict.update(second_dict)
return new_dict
Expand Down
30 changes: 25 additions & 5 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Hashable,
List,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
Expand Down Expand Up @@ -55,6 +56,7 @@
ensure_us_time_resolution,
infix_dims,
is_duck_array,
maybe_coerce_to_dict,
maybe_coerce_to_str,
)

Expand Down Expand Up @@ -286,9 +288,18 @@ class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic):
they can use more complete metadata in context of coordinate labels.
"""

_attrs: Optional[MutableMapping[Any, Any]]

__slots__ = ("_dims", "_data", "_attrs", "_encoding")

def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
def __init__(
self,
dims,
data,
attrs: Optional[Mapping[Any, Any]] = None,
encoding=None,
fastpath=False,
):
"""
Parameters
----------
Expand All @@ -313,7 +324,7 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
self._attrs = None
self._encoding = None
if attrs is not None:
self.attrs = attrs
self.attrs = attrs # type: ignore[assignment] # https://github.com/python/mypy/issues/3004
if encoding is not None:
self.encoding = encoding

Expand Down Expand Up @@ -863,15 +874,15 @@ def __setitem__(self, key, value):
indexable[index_tuple] = value

@property
def attrs(self) -> Dict[Hashable, Any]:
def attrs(self) -> MutableMapping[Any, Any]:
"""Dictionary of local attributes on this variable."""
if self._attrs is None:
self._attrs = {}
return self._attrs

@attrs.setter
def attrs(self, value: Mapping[Any, Any]) -> None:
self._attrs = dict(value)
self._attrs = maybe_coerce_to_dict(value)

@property
def encoding(self):
Expand Down Expand Up @@ -2602,9 +2613,18 @@ class IndexVariable(Variable):
unless another name is given.
"""

_attrs: Optional[MutableMapping[Any, Any]]

__slots__ = ()

def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
def __init__(
self,
dims,
data,
attrs: Optional[Mapping[Any, Any]] = None,
encoding=None,
fastpath=False,
):
super().__init__(dims, data, attrs, encoding, fastpath)
if self.ndim != 1:
raise ValueError(f"{type(self).__name__} objects must be 1-dimensional")
Expand Down