Skip to content

Commit 378b902

Browse files
authored
Change annotations to allow str keys (#5690)
* Change typing to allow str keys * Change all incoming Mapping types * Add in some annotated tests * whatsnew
1 parent 22548f8 commit 378b902

12 files changed

+70
-59
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ Internal Changes
5353
By `Deepak Cherian <https://github.com/dcherian>`_.
5454
- Explicit indexes refactor: decouple ``xarray.Index``` from ``xarray.Variable`` (:pull:`5636`).
5555
By `Benoit Bovy <https://github.com/benbovy>`_.
56+
- Fix ``Mapping`` argument typing to allow mypy to pass on ``str`` keys (:pull:`5690`).
57+
By `Maximilian Roos <https://github.com/max-sixty>`_.
5658
- Improve the performance of reprs for large datasets or dataarrays. (:pull:`5661`)
5759
By `Jimmy Westling <https://github.com/illviljan>`_.
5860
- Use isort's `float_to_top` config. (:pull:`5695`).

properties/test_pandas_roundtrip.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929

3030
@st.composite
31-
def datasets_1d_vars(draw):
31+
def datasets_1d_vars(draw) -> xr.Dataset:
3232
"""Generate datasets with only 1D variables
3333
3434
Suitable for converting to pandas dataframes.
@@ -49,7 +49,7 @@ def datasets_1d_vars(draw):
4949

5050

5151
@given(st.data(), an_array)
52-
def test_roundtrip_dataarray(data, arr):
52+
def test_roundtrip_dataarray(data, arr) -> None:
5353
names = data.draw(
5454
st.lists(st.text(), min_size=arr.ndim, max_size=arr.ndim, unique=True).map(
5555
tuple
@@ -62,15 +62,15 @@ def test_roundtrip_dataarray(data, arr):
6262

6363

6464
@given(datasets_1d_vars())
65-
def test_roundtrip_dataset(dataset):
65+
def test_roundtrip_dataset(dataset) -> None:
6666
df = dataset.to_dataframe()
6767
assert isinstance(df, pd.DataFrame)
6868
roundtripped = xr.Dataset(df)
6969
xr.testing.assert_identical(dataset, roundtripped)
7070

7171

7272
@given(numeric_series, st.text())
73-
def test_roundtrip_pandas_series(ser, ix_name):
73+
def test_roundtrip_pandas_series(ser, ix_name) -> None:
7474
# Need to name the index, otherwise Xarray calls it 'dim_0'.
7575
ser.index.name = ix_name
7676
arr = xr.DataArray(ser)
@@ -87,7 +87,7 @@ def test_roundtrip_pandas_series(ser, ix_name):
8787

8888
@pytest.mark.xfail
8989
@given(numeric_homogeneous_dataframe)
90-
def test_roundtrip_pandas_dataframe(df):
90+
def test_roundtrip_pandas_dataframe(df) -> None:
9191
# Need to name the indexes, otherwise Xarray names them 'dim_0', 'dim_1'.
9292
df.index.name = "rows"
9393
df.columns.name = "cols"

xarray/core/accessor_str.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def _apply_str_ufunc(
114114
obj: Any,
115115
dtype: Union[str, np.dtype, Type] = None,
116116
output_core_dims: Union[list, tuple] = ((),),
117-
output_sizes: Mapping[Hashable, int] = None,
117+
output_sizes: Mapping[Any, int] = None,
118118
func_args: Tuple = (),
119119
func_kwargs: Mapping = {},
120120
) -> Any:
@@ -227,7 +227,7 @@ def _apply(
227227
func: Callable,
228228
dtype: Union[str, np.dtype, Type] = None,
229229
output_core_dims: Union[list, tuple] = ((),),
230-
output_sizes: Mapping[Hashable, int] = None,
230+
output_sizes: Mapping[Any, int] = None,
231231
func_args: Tuple = (),
232232
func_kwargs: Mapping = {},
233233
) -> Any:

xarray/core/common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,7 @@ def weighted(
818818

819819
def rolling(
820820
self,
821-
dim: Mapping[Hashable, int] = None,
821+
dim: Mapping[Any, int] = None,
822822
min_periods: int = None,
823823
center: Union[bool, Mapping[Hashable, bool]] = False,
824824
**window_kwargs: int,
@@ -892,7 +892,7 @@ def rolling(
892892

893893
def rolling_exp(
894894
self,
895-
window: Mapping[Hashable, int] = None,
895+
window: Mapping[Any, int] = None,
896896
window_type: str = "span",
897897
**window_kwargs,
898898
):
@@ -933,7 +933,7 @@ def rolling_exp(
933933

934934
def coarsen(
935935
self,
936-
dim: Mapping[Hashable, int] = None,
936+
dim: Mapping[Any, int] = None,
937937
boundary: str = "exact",
938938
side: Union[str, Mapping[Hashable, str]] = "left",
939939
coord_func: str = "mean",
@@ -1009,7 +1009,7 @@ def coarsen(
10091009

10101010
def resample(
10111011
self,
1012-
indexer: Mapping[Hashable, str] = None,
1012+
indexer: Mapping[Any, str] = None,
10131013
skipna=None,
10141014
closed: str = None,
10151015
label: str = None,

xarray/core/computation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ def apply_dict_of_variables_vfunc(
400400

401401

402402
def _fast_dataset(
403-
variables: Dict[Hashable, Variable], coord_variables: Mapping[Hashable, Variable]
403+
variables: Dict[Hashable, Variable], coord_variables: Mapping[Any, Variable]
404404
) -> "Dataset":
405405
"""Create a dataset as quickly as possible.
406406

xarray/core/coordinates.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index:
158158

159159
return pd.MultiIndex(level_list, code_list, names=names)
160160

161-
def update(self, other: Mapping[Hashable, Any]) -> None:
161+
def update(self, other: Mapping[Any, Any]) -> None:
162162
other_vars = getattr(other, "variables", other)
163163
coords, indexes = merge_coords(
164164
[self.variables, other_vars], priority_arg=1, indexes=self.xindexes
@@ -270,7 +270,7 @@ def to_dataset(self) -> "Dataset":
270270
return self._data._copy_listed(names)
271271

272272
def _update_coords(
273-
self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, Index]
273+
self, coords: Dict[Hashable, Variable], indexes: Mapping[Any, Index]
274274
) -> None:
275275
from .dataset import calculate_dimensions
276276

@@ -333,7 +333,7 @@ def __getitem__(self, key: Hashable) -> "DataArray":
333333
return self._data._getitem_coord(key)
334334

335335
def _update_coords(
336-
self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, Index]
336+
self, coords: Dict[Hashable, Variable], indexes: Mapping[Any, Index]
337337
) -> None:
338338
from .dataset import calculate_dimensions
339339

@@ -376,7 +376,7 @@ def _ipython_key_completions_(self):
376376

377377

378378
def assert_coordinate_consistent(
379-
obj: Union["DataArray", "Dataset"], coords: Mapping[Hashable, Variable]
379+
obj: Union["DataArray", "Dataset"], coords: Mapping[Any, Variable]
380380
) -> None:
381381
"""Make sure the dimension coordinate of obj is consistent with coords.
382382
@@ -394,7 +394,7 @@ def assert_coordinate_consistent(
394394

395395
def remap_label_indexers(
396396
obj: Union["DataArray", "Dataset"],
397-
indexers: Mapping[Hashable, Any] = None,
397+
indexers: Mapping[Any, Any] = None,
398398
method: str = None,
399399
tolerance=None,
400400
**indexers_kwargs: Any,

xarray/core/dataarray.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ def _replace_maybe_drop_dims(
462462
)
463463
return self._replace(variable, coords, name, indexes=indexes)
464464

465-
def _overwrite_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray":
465+
def _overwrite_indexes(self, indexes: Mapping[Any, Any]) -> "DataArray":
466466
if not len(indexes):
467467
return self
468468
coords = self._coords.copy()
@@ -792,7 +792,7 @@ def attrs(self) -> Dict[Hashable, Any]:
792792
return self.variable.attrs
793793

794794
@attrs.setter
795-
def attrs(self, value: Mapping[Hashable, Any]) -> None:
795+
def attrs(self, value: Mapping[Any, Any]) -> None:
796796
# Disable type checking to work around mypy bug - see mypy#4167
797797
self.variable.attrs = value # type: ignore[assignment]
798798

@@ -803,7 +803,7 @@ def encoding(self) -> Dict[Hashable, Any]:
803803
return self.variable.encoding
804804

805805
@encoding.setter
806-
def encoding(self, value: Mapping[Hashable, Any]) -> None:
806+
def encoding(self, value: Mapping[Any, Any]) -> None:
807807
self.variable.encoding = value
808808

809809
@property
@@ -1110,7 +1110,7 @@ def chunk(
11101110

11111111
def isel(
11121112
self,
1113-
indexers: Mapping[Hashable, Any] = None,
1113+
indexers: Mapping[Any, Any] = None,
11141114
drop: bool = False,
11151115
missing_dims: str = "raise",
11161116
**indexers_kwargs: Any,
@@ -1193,7 +1193,7 @@ def isel(
11931193

11941194
def sel(
11951195
self,
1196-
indexers: Mapping[Hashable, Any] = None,
1196+
indexers: Mapping[Any, Any] = None,
11971197
method: str = None,
11981198
tolerance=None,
11991199
drop: bool = False,
@@ -1498,7 +1498,7 @@ def reindex_like(
14981498

14991499
def reindex(
15001500
self,
1501-
indexers: Mapping[Hashable, Any] = None,
1501+
indexers: Mapping[Any, Any] = None,
15021502
method: str = None,
15031503
tolerance=None,
15041504
copy: bool = True,
@@ -1591,7 +1591,7 @@ def reindex(
15911591

15921592
def interp(
15931593
self,
1594-
coords: Mapping[Hashable, Any] = None,
1594+
coords: Mapping[Any, Any] = None,
15951595
method: str = "linear",
15961596
assume_sorted: bool = False,
15971597
kwargs: Mapping[str, Any] = None,
@@ -1815,7 +1815,7 @@ def rename(
18151815
return self._replace(name=new_name_or_name_dict)
18161816

18171817
def swap_dims(
1818-
self, dims_dict: Mapping[Hashable, Hashable] = None, **dims_kwargs
1818+
self, dims_dict: Mapping[Any, Hashable] = None, **dims_kwargs
18191819
) -> "DataArray":
18201820
"""Returns a new DataArray with swapped dimensions.
18211821
@@ -2333,7 +2333,7 @@ def drop(
23332333

23342334
def drop_sel(
23352335
self,
2336-
labels: Mapping[Hashable, Any] = None,
2336+
labels: Mapping[Any, Any] = None,
23372337
*,
23382338
errors: str = "raise",
23392339
**labels_kwargs,
@@ -3163,7 +3163,7 @@ def diff(self, dim: Hashable, n: int = 1, label: Hashable = "upper") -> "DataArr
31633163

31643164
def shift(
31653165
self,
3166-
shifts: Mapping[Hashable, int] = None,
3166+
shifts: Mapping[Any, int] = None,
31673167
fill_value: Any = dtypes.NA,
31683168
**shifts_kwargs: int,
31693169
) -> "DataArray":
@@ -3210,7 +3210,7 @@ def shift(
32103210

32113211
def roll(
32123212
self,
3213-
shifts: Mapping[Hashable, int] = None,
3213+
shifts: Mapping[Any, int] = None,
32143214
roll_coords: bool = None,
32153215
**shifts_kwargs: int,
32163216
) -> "DataArray":
@@ -4433,7 +4433,7 @@ def argmax(
44334433

44344434
def query(
44354435
self,
4436-
queries: Mapping[Hashable, Any] = None,
4436+
queries: Mapping[Any, Any] = None,
44374437
parser: str = "pandas",
44384438
engine: str = None,
44394439
missing_dims: str = "raise",

0 commit comments

Comments
 (0)