Commit 226c23b

Add Ellipsis typehint to reductions (#7048)
1 parent e678a1d commit 226c23b

14 files changed: +448 additions, -340 deletions

xarray/core/_reductions.py

Lines changed: 238 additions & 193 deletions
Large diffs are not rendered by default.
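The generated reduction methods in this file are collapsed above, but judging from the hand-written signatures in dataarray.py and dataset.py below, the pattern is presumably along these lines. This is an illustrative sketch only, with an assumed `Dims` alias; it is not the rendered diff:

from __future__ import annotations  # keeps `ellipsis` usable as an annotation

from collections.abc import Hashable, Iterable
from typing import Any, Union

# Assumed shape of the alias imported from xarray/core/types.py (not shown in this diff).
Dims = Union[str, Iterable[Hashable], None]


class ExampleReductions:
    def sum(self, dim: Dims | ellipsis = None, **kwargs: Any) -> Any:
        # dim=None and dim=... both mean "reduce over every dimension";
        # a str or an iterable of hashables selects specific dimensions.
        ...

Note that `ellipsis` (the type of the `...` object) is not a name available at runtime, which is presumably why these annotations rely on postponed evaluation (`from __future__ import annotations`) or type-checking-only contexts.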

xarray/core/computation.py

Lines changed: 2 additions & 2 deletions

@@ -40,7 +40,7 @@
     from .coordinates import Coordinates
     from .dataarray import DataArray
     from .dataset import Dataset
-    from .types import CombineAttrsOptions, Ellipsis, JoinOptions
+    from .types import CombineAttrsOptions, JoinOptions

 _NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
 _DEFAULT_NAME = utils.ReprObject("<default-name>")
@@ -1624,7 +1624,7 @@ def cross(

 def dot(
     *arrays,
-    dims: str | Iterable[Hashable] | Ellipsis | None = None,
+    dims: str | Iterable[Hashable] | ellipsis | None = None,
     **kwargs: Any,
 ):
     """Generalized dot product for xarray objects. Like np.einsum, but

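The widened `dims` annotation on `dot` covers a call style the docstring already documents: passing `...` to sum over every dimension. A minimal usage sketch (illustrative arrays; the `dims` keyword name follows the signature above):

import numpy as np
import xarray as xr

da_a = xr.DataArray(np.arange(6).reshape(2, 3), dims=("x", "y"))
da_b = xr.DataArray(np.arange(6).reshape(2, 3), dims=("x", "y"))

# dims=... sums over all dimensions, so the result is a 0-d DataArray
total = xr.dot(da_a, da_b, dims=...)
print(total.values)  # 0*0 + 1*1 + ... + 5*5 = 55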
xarray/core/dataarray.py

Lines changed: 21 additions & 20 deletions

@@ -78,7 +78,7 @@
     from .types import (
         CoarsenBoundaryOptions,
         DatetimeUnitOptions,
-        Ellipsis,
+        Dims,
         ErrorOptions,
         ErrorOptionsWithWarn,
         InterpOptions,
@@ -900,30 +900,30 @@ def coords(self) -> DataArrayCoordinates:
     @overload
     def reset_coords(
         self: T_DataArray,
-        names: Hashable | Iterable[Hashable] | None = None,
+        names: Dims = None,
         drop: Literal[False] = False,
     ) -> Dataset:
         ...

     @overload
     def reset_coords(
         self: T_DataArray,
-        names: Hashable | Iterable[Hashable] | None = None,
+        names: Dims = None,
         *,
         drop: Literal[True],
     ) -> T_DataArray:
         ...

     def reset_coords(
         self: T_DataArray,
-        names: Hashable | Iterable[Hashable] | None = None,
+        names: Dims = None,
         drop: bool = False,
     ) -> T_DataArray | Dataset:
         """Given names of coordinates, reset them to become variables.

         Parameters
         ----------
-        names : Hashable or iterable of Hashable, optional
+        names : str, Iterable of Hashable or None, optional
             Name(s) of non-index coordinates in this dataset to reset into
             variables. By default, all non-index coordinates are reset.
         drop : bool, default: False
@@ -2574,7 +2574,7 @@ def stack(
     # https://github.com/python/mypy/issues/12846 is resolved
     def unstack(
         self,
-        dim: Hashable | Sequence[Hashable] | None = None,
+        dim: Dims = None,
         fill_value: Any = dtypes.NA,
         sparse: bool = False,
     ) -> DataArray:
@@ -2586,7 +2586,7 @@

         Parameters
         ----------
-        dim : Hashable or sequence of Hashable, optional
+        dim : str, Iterable of Hashable or None, optional
             Dimension(s) over which to unstack. By default unstacks all
             MultiIndexes.
         fill_value : scalar or dict-like, default: nan
@@ -3400,9 +3400,9 @@ def combine_first(self: T_DataArray, other: T_DataArray) -> T_DataArray:
     def reduce(
         self: T_DataArray,
         func: Callable[..., Any],
-        dim: None | Hashable | Iterable[Hashable] = None,
+        dim: Dims | ellipsis = None,
         *,
-        axis: None | int | Sequence[int] = None,
+        axis: int | Sequence[int] | None = None,
         keep_attrs: bool | None = None,
         keepdims: bool = False,
         **kwargs: Any,
@@ -3415,8 +3415,9 @@ def reduce(
             Function which can be called in the form
             `f(x, axis=axis, **kwargs)` to return the result of reducing an
             np.ndarray over an integer valued axis.
-        dim : Hashable or Iterable of Hashable, optional
-            Dimension(s) over which to apply `func`.
+        dim : "...", str, Iterable of Hashable or None, optional
+            Dimension(s) over which to apply `func`. By default `func` is
+            applied over all dimensions.
         axis : int or sequence of int, optional
             Axis(es) over which to repeatedly apply `func`. Only one of the
             'dim' and 'axis' arguments can be supplied. If neither are
@@ -4386,7 +4387,7 @@ def imag(self: T_DataArray) -> T_DataArray:
     def dot(
         self: T_DataArray,
         other: T_DataArray,
-        dims: str | Iterable[Hashable] | Ellipsis | None = None,
+        dims: Dims | ellipsis = None,
     ) -> T_DataArray:
         """Perform dot product of two DataArrays along their shared dims.

@@ -4396,7 +4397,7 @@
         ----------
         other : DataArray
             The other array with which the dot product is performed.
-        dims : ..., str or Iterable of Hashable, optional
+        dims : ..., str, Iterable of Hashable or None, optional
             Which dimensions to sum over. Ellipsis (`...`) sums over all dimensions.
             If not specified, then all the common dimensions are summed over.

@@ -4506,7 +4507,7 @@ def sortby(
     def quantile(
         self: T_DataArray,
         q: ArrayLike,
-        dim: str | Iterable[Hashable] | None = None,
+        dim: Dims = None,
         method: QUANTILE_METHODS = "linear",
         keep_attrs: bool | None = None,
         skipna: bool | None = None,
@@ -5390,7 +5391,7 @@ def idxmax(
     # https://github.com/python/mypy/issues/12846 is resolved
     def argmin(
         self,
-        dim: Hashable | Sequence[Hashable] | Ellipsis | None = None,
+        dim: Dims | ellipsis = None,
         axis: int | None = None,
         keep_attrs: bool | None = None,
         skipna: bool | None = None,
@@ -5406,7 +5407,7 @@ def argmin(

         Parameters
         ----------
-        dim : Hashable, sequence of Hashable, None or ..., optional
+        dim : "...", str, Iterable of Hashable or None, optional
             The dimensions over which to find the minimum. By default, finds minimum over
             all dimensions - for now returning an int for backward compatibility, but
             this is deprecated, in future will return a dict with indices for all
@@ -5495,7 +5496,7 @@ def argmin(
     # https://github.com/python/mypy/issues/12846 is resolved
     def argmax(
         self,
-        dim: Hashable | Sequence[Hashable] | Ellipsis | None = None,
+        dim: Dims | ellipsis = None,
         axis: int | None = None,
         keep_attrs: bool | None = None,
         skipna: bool | None = None,
@@ -5511,7 +5512,7 @@ def argmax(

         Parameters
         ----------
-        dim : Hashable, sequence of Hashable, None or ..., optional
+        dim : "...", str, Iterable of Hashable or None, optional
             The dimensions over which to find the maximum. By default, finds maximum over
             all dimensions - for now returning an int for backward compatibility, but
             this is deprecated, in future will return a dict with indices for all
@@ -5679,7 +5680,7 @@ def curvefit(
         self,
         coords: str | DataArray | Iterable[str | DataArray],
         func: Callable[..., Any],
-        reduce_dims: Hashable | Iterable[Hashable] | None = None,
+        reduce_dims: Dims = None,
         skipna: bool = True,
         p0: dict[str, Any] | None = None,
         bounds: dict[str, Any] | None = None,
@@ -5704,7 +5705,7 @@ def curvefit(
             array of length `len(x)`. `params` are the fittable parameters which are optimized
             by scipy curve_fit. `x` can also be specified as a sequence containing multiple
             coordinates, e.g. `f((x0, x1), *params)`.
-        reduce_dims : Hashable or sequence of Hashable
+        reduce_dims : str, Iterable of Hashable or None, optional
             Additional dimension(s) over which to aggregate while fitting. For example,
             calling `ds.curvefit(coords='time', reduce_dims=['lat', 'lon'], ...)` will
             aggregate all lat and lon points and fit the specified function along the
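The DataArray signatures above widen `dim`/`dims` to accept `...` alongside strings, iterables of hashables and None. A small usage sketch of the `...` form, following the updated docstrings (illustrative data; expected results noted in comments):

import numpy as np
import xarray as xr

da = xr.DataArray([[3.0, 1.0], [2.0, 5.0]], dims=("x", "y"))

# dim=... reduces over every dimension (here equivalent to the dim=None default)
print(da.reduce(np.sum, dim=...).values)  # 3 + 1 + 2 + 5 = 11.0

# With dim=..., argmin returns a dict with one index per dimension;
# the minimum value 1.0 sits at x=0, y=1.
print(da.argmin(dim=...))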

xarray/core/dataset.py

Lines changed: 35 additions & 35 deletions

@@ -107,7 +107,7 @@
         CombineAttrsOptions,
         CompatOptions,
         DatetimeUnitOptions,
-        Ellipsis,
+        Dims,
         ErrorOptions,
         ErrorOptionsWithWarn,
         InterpOptions,
@@ -1698,14 +1698,14 @@ def set_coords(self: T_Dataset, names: Hashable | Iterable[Hashable]) -> T_Dataset:

     def reset_coords(
         self: T_Dataset,
-        names: Hashable | Iterable[Hashable] | None = None,
+        names: Dims = None,
         drop: bool = False,
     ) -> T_Dataset:
         """Given names of coordinates, reset them to become variables

         Parameters
         ----------
-        names : hashable or iterable of hashable, optional
+        names : str, Iterable of Hashable or None, optional
             Name(s) of non-index coordinates in this dataset to reset into
             variables. By default, all non-index coordinates are reset.
         drop : bool, default: False
@@ -4457,7 +4457,7 @@ def _get_stack_index(

     def _stack_once(
         self: T_Dataset,
-        dims: Sequence[Hashable | Ellipsis],
+        dims: Sequence[Hashable | ellipsis],
         new_dim: Hashable,
         index_cls: type[Index],
         create_index: bool | None = True,
@@ -4516,10 +4516,10 @@ def _stack_once(

     def stack(
         self: T_Dataset,
-        dimensions: Mapping[Any, Sequence[Hashable | Ellipsis]] | None = None,
+        dimensions: Mapping[Any, Sequence[Hashable | ellipsis]] | None = None,
         create_index: bool | None = True,
         index_cls: type[Index] = PandasMultiIndex,
-        **dimensions_kwargs: Sequence[Hashable | Ellipsis],
+        **dimensions_kwargs: Sequence[Hashable | ellipsis],
     ) -> T_Dataset:
         """
         Stack any number of existing dimensions into a single new dimension.
@@ -4770,7 +4770,7 @@ def _unstack_full_reindex(

     def unstack(
         self: T_Dataset,
-        dim: Hashable | Iterable[Hashable] | None = None,
+        dim: Dims = None,
         fill_value: Any = xrdtypes.NA,
         sparse: bool = False,
     ) -> T_Dataset:
@@ -4782,7 +4782,7 @@

         Parameters
         ----------
-        dim : hashable or iterable of hashable, optional
+        dim : str, Iterable of Hashable or None, optional
             Dimension(s) over which to unstack. By default unstacks all
             MultiIndexes.
         fill_value : scalar or dict-like, default: nan
@@ -4860,15 +4860,13 @@ def unstack(
             for v in nonindexes
         )

-        for dim in dims:
+        for d in dims:
             if needs_full_reindex:
                 result = result._unstack_full_reindex(
-                    dim, stacked_indexes[dim], fill_value, sparse
+                    d, stacked_indexes[d], fill_value, sparse
                 )
             else:
-                result = result._unstack_once(
-                    dim, stacked_indexes[dim], fill_value, sparse
-                )
+                result = result._unstack_once(d, stacked_indexes[d], fill_value, sparse)
         return result

     def update(self: T_Dataset, other: CoercibleMapping) -> T_Dataset:
@@ -5324,15 +5322,15 @@ def drop_isel(self: T_Dataset, indexers=None, **indexers_kwargs) -> T_Dataset:

     def drop_dims(
         self: T_Dataset,
-        drop_dims: Hashable | Iterable[Hashable],
+        drop_dims: str | Iterable[Hashable],
         *,
         errors: ErrorOptions = "raise",
     ) -> T_Dataset:
         """Drop dimensions and associated variables from this dataset.

         Parameters
         ----------
-        drop_dims : hashable or iterable of hashable
+        drop_dims : str or Iterable of Hashable
             Dimension or dimensions to drop.
         errors : {"raise", "ignore"}, default: "raise"
             If 'raise', raises a ValueError error if any of the
@@ -5763,7 +5761,7 @@ def combine_first(self: T_Dataset, other: T_Dataset) -> T_Dataset:
     def reduce(
         self: T_Dataset,
         func: Callable,
-        dim: Hashable | Iterable[Hashable] = None,
+        dim: Dims | ellipsis = None,
         *,
         keep_attrs: bool | None = None,
         keepdims: bool = False,
@@ -5778,8 +5776,8 @@ def reduce(
             Function which can be called in the form
             `f(x, axis=axis, **kwargs)` to return the result of reducing an
             np.ndarray over an integer valued axis.
-        dim : str or sequence of str, optional
-            Dimension(s) over which to apply `func`. By default `func` is
+        dim : str, Iterable of Hashable or None, optional
+            Dimension(s) over which to apply `func`. By default `func` is
             applied over all dimensions.
         keep_attrs : bool or None, optional
             If True, the dataset's attributes (`attrs`) will be copied from
@@ -5837,18 +5835,15 @@ def reduce(
                 or np.issubdtype(var.dtype, np.number)
                 or (var.dtype == np.bool_)
             ):
-                reduce_maybe_single: Hashable | None | list[Hashable]
-                if len(reduce_dims) == 1:
-                    # unpack dimensions for the benefit of functions
-                    # like np.argmin which can't handle tuple arguments
-                    (reduce_maybe_single,) = reduce_dims
-                elif len(reduce_dims) == var.ndim:
-                    # prefer to aggregate over axis=None rather than
-                    # axis=(0, 1) if they will be equivalent, because
-                    # the former is often more efficient
-                    reduce_maybe_single = None
-                else:
-                    reduce_maybe_single = reduce_dims
+                # prefer to aggregate over axis=None rather than
+                # axis=(0, 1) if they will be equivalent, because
+                # the former is often more efficient
+                # keep single-element dims as list, to support Hashables
+                reduce_maybe_single = (
+                    None
+                    if len(reduce_dims) == var.ndim and var.ndim != 1
+                    else reduce_dims
+                )
                 variables[name] = var.reduce(
                     func,
                     dim=reduce_maybe_single,
@@ -6957,7 +6952,7 @@ def sortby(
     def quantile(
         self: T_Dataset,
         q: ArrayLike,
-        dim: str | Iterable[Hashable] | None = None,
+        dim: Dims = None,
         method: QUANTILE_METHODS = "linear",
         numeric_only: bool = False,
         keep_attrs: bool = None,
@@ -8303,7 +8298,9 @@ def argmin(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset:
             # Return int index if single dimension is passed, and is not part of a
             # sequence
             argmin_func = getattr(duck_array_ops, "argmin")
-            return self.reduce(argmin_func, dim=dim, **kwargs)
+            return self.reduce(
+                argmin_func, dim=None if dim is None else [dim], **kwargs
+            )
         else:
             raise ValueError(
                 "When dim is a sequence or ..., DataArray.argmin() returns a dict. "
@@ -8361,7 +8358,9 @@ def argmax(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset:
             # Return int index if single dimension is passed, and is not part of a
             # sequence
             argmax_func = getattr(duck_array_ops, "argmax")
-            return self.reduce(argmax_func, dim=dim, **kwargs)
+            return self.reduce(
+                argmax_func, dim=None if dim is None else [dim], **kwargs
+            )
         else:
             raise ValueError(
                 "When dim is a sequence or ..., DataArray.argmin() returns a dict. "
@@ -8469,7 +8468,7 @@ def curvefit(
         self: T_Dataset,
         coords: str | DataArray | Iterable[str | DataArray],
         func: Callable[..., Any],
-        reduce_dims: Hashable | Iterable[Hashable] | None = None,
+        reduce_dims: Dims = None,
         skipna: bool = True,
         p0: dict[str, Any] | None = None,
         bounds: dict[str, Any] | None = None,
@@ -8494,7 +8493,7 @@ def curvefit(
             array of length `len(x)`. `params` are the fittable parameters which are optimized
             by scipy curve_fit. `x` can also be specified as a sequence containing multiple
             coordinates, e.g. `f((x0, x1), *params)`.
-        reduce_dims : hashable or sequence of hashable
+        reduce_dims : str, Iterable of Hashable or None, optional
             Additional dimension(s) over which to aggregate while fitting. For example,
             calling `ds.curvefit(coords='time', reduce_dims=['lat', 'lon'], ...)` will
             aggregate all lat and lon points and fit the specified function along the
@@ -8545,6 +8544,7 @@ def curvefit(
         if kwargs is None:
             kwargs = {}

+        reduce_dims_: list[Hashable]
         if not reduce_dims:
             reduce_dims_ = []
         elif isinstance(reduce_dims, str) or not isinstance(reduce_dims, Iterable):
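Beyond the type hints, the substantive change in this file is the simplified dim handling inside Dataset.reduce: when a variable is reduced over all of its axes (and has more than one dimension), the dim list collapses to None so the underlying reduction can run with axis=None, while single-element lists are kept as lists so non-string hashable dimension names survive. A standalone sketch of that decision logic, using an illustrative helper name that is not part of xarray:

from __future__ import annotations

from collections.abc import Hashable


def choose_reduce_dims(reduce_dims: list[Hashable], var_ndim: int) -> list[Hashable] | None:
    """Mirror of the reduce_maybe_single logic in the diff above (illustrative only)."""
    # Collapsing to None lets the reduction run with axis=None, which is often
    # faster than axis=(0, 1, ...); 1-d variables keep the explicit list.
    if len(reduce_dims) == var_ndim and var_ndim != 1:
        return None
    return reduce_dims


print(choose_reduce_dims(["x", "y"], var_ndim=2))  # None -> aggregate over everything
print(choose_reduce_dims(["x"], var_ndim=2))       # ['x'] -> reduce one of two dims
print(choose_reduce_dims(["x"], var_ndim=1))       # ['x'] -> kept as a list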

0 commit comments