Skip to content

Commit 9a3f95e

Browse files
committed
More precisely type pipe methods.
In addition, enhance mypy job configuration to support running it locally via `act`. Fixes #9997
1 parent 56f9e4c commit 9a3f95e

File tree

6 files changed

+676
-75
lines changed

6 files changed

+676
-75
lines changed

.github/workflows/ci-additional.yaml

+26-54
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@ jobs:
3333
with:
3434
keyword: "[skip-ci]"
3535

36+
detect-act:
37+
name: Detect 'act' runner
38+
runs-on: ubuntu-latest
39+
outputs:
40+
running: ${{ steps.detect-act.outputs.running }}
41+
steps:
42+
- id: detect-act
43+
run: echo "running=${{ env.ACT || 'false' }}" >> "$GITHUB_OUTPUT"
44+
3645
doctest:
3746
name: Doctests
3847
runs-on: "ubuntu-latest"
@@ -81,15 +90,23 @@ jobs:
8190
python -m pytest --doctest-modules xarray --ignore xarray/tests -Werror
8291
8392
mypy:
84-
name: Mypy
93+
strategy:
94+
matrix:
95+
include:
96+
- python-version: "3.10"
97+
codecov-flags: mypy-min
98+
- python-version: "3.12"
99+
codecov-flags: mypy
100+
name: Mypy ${{ matrix.python-version }}
85101
runs-on: "ubuntu-latest"
86-
needs: detect-ci-trigger
102+
needs: [detect-ci-trigger, detect-act]
103+
if: always() && (needs.detect-ci-trigger.outputs.triggered == 'true' || needs.detect-act.outputs.running == 'true')
87104
defaults:
88105
run:
89106
shell: bash -l {0}
90107
env:
91108
CONDA_ENV_FILE: ci/requirements/environment.yml
92-
PYTHON_VERSION: "3.12"
109+
PYTHON_VERSION: ${{ matrix.python-version }}
93110

94111
steps:
95112
- uses: actions/checkout@v4
@@ -116,68 +133,23 @@ jobs:
116133
python xarray/util/print_versions.py
117134
- name: Install mypy
118135
run: |
119-
python -m pip install mypy
136+
python -m pip install mypy pytest-mypy-plugins
120137
121138
- name: Run mypy
122139
run: |
123140
python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report
124141
125-
- name: Upload mypy coverage to Codecov
126-
uses: codecov/[email protected]
127-
with:
128-
files: mypy_report/cobertura.xml
129-
flags: mypy
130-
env_vars: PYTHON_VERSION
131-
name: codecov-umbrella
132-
fail_ci_if_error: false
133-
134-
mypy-min:
135-
name: Mypy 3.10
136-
runs-on: "ubuntu-latest"
137-
needs: detect-ci-trigger
138-
defaults:
139-
run:
140-
shell: bash -l {0}
141-
env:
142-
CONDA_ENV_FILE: ci/requirements/environment.yml
143-
PYTHON_VERSION: "3.10"
144-
145-
steps:
146-
- uses: actions/checkout@v4
147-
with:
148-
fetch-depth: 0 # Fetch all history for all branches and tags.
149-
150-
- name: set environment variables
142+
- name: Run mypy tests
143+
# Run pytest with mypy plugin even if mypy analysis in previous step fails.
144+
if: ${{ always() }}
151145
run: |
152-
echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
153-
- name: Setup micromamba
154-
uses: mamba-org/setup-micromamba@v2
155-
with:
156-
environment-file: ${{env.CONDA_ENV_FILE}}
157-
environment-name: xarray-tests
158-
create-args: >-
159-
python=${{env.PYTHON_VERSION}}
160-
cache-environment: true
161-
cache-environment-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}"
162-
- name: Install xarray
163-
run: |
164-
python -m pip install --no-deps -e .
165-
- name: Version info
166-
run: |
167-
python xarray/util/print_versions.py
168-
- name: Install mypy
169-
run: |
170-
python -m pip install mypy
171-
172-
- name: Run mypy
173-
run: |
174-
python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report
146+
python -m pytest -v --mypy-only-local-stub --mypy-pyproject-toml-file=pyproject.toml xarray/**/test_*.yml
175147
176148
- name: Upload mypy coverage to Codecov
177149
uses: codecov/[email protected]
178150
with:
179151
files: mypy_report/cobertura.xml
180-
flags: mypy-min
152+
flags: ${{ matrix.codecov-flags }}
181153
env_vars: PYTHON_VERSION
182154
name: codecov-umbrella
183155
fail_ci_if_error: false

xarray/core/common.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from contextlib import suppress
77
from html import escape
88
from textwrap import dedent
9-
from typing import TYPE_CHECKING, Any, TypeVar, Union, overload
9+
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, Union, overload
1010

1111
import numpy as np
1212
import pandas as pd
@@ -60,6 +60,7 @@
6060
T_Resample = TypeVar("T_Resample", bound="Resample")
6161
C = TypeVar("C")
6262
T = TypeVar("T")
63+
P = ParamSpec("P")
6364

6465

6566
class ImplementsArrayReduce:
@@ -718,11 +719,27 @@ def assign_attrs(self, *args: Any, **kwargs: Any) -> Self:
718719
out.attrs.update(*args, **kwargs)
719720
return out
720721

722+
@overload
723+
def pipe(
724+
self,
725+
func: Callable[Concatenate[Self, P], T],
726+
*args: P.args,
727+
**kwargs: P.kwargs,
728+
) -> T: ...
729+
730+
@overload
721731
def pipe(
722732
self,
723-
func: Callable[..., T] | tuple[Callable[..., T], str],
733+
func: tuple[Callable[..., T], str],
724734
*args: Any,
725735
**kwargs: Any,
736+
) -> T: ...
737+
738+
def pipe(
739+
self,
740+
func: Callable[Concatenate[Self, P], T] | tuple[Callable[P, T], str],
741+
*args: P.args,
742+
**kwargs: P.kwargs,
726743
) -> T:
727744
"""
728745
Apply ``func(self, *args, **kwargs)``
@@ -840,15 +857,19 @@ def pipe(
840857
pandas.DataFrame.pipe
841858
"""
842859
if isinstance(func, tuple):
843-
func, target = func
860+
# Use different var when unpacking function from tuple because the type
861+
# signature of the unpacked function differs from the expected type
862+
# signature in the case where only a function is given, rather than a tuple.
863+
# This makes type checkers happy at both call sites below.
864+
f, target = func
844865
if target in kwargs:
845866
raise ValueError(
846867
f"{target} is both the pipe target and a keyword argument"
847868
)
848869
kwargs[target] = self
849-
return func(*args, **kwargs)
850-
else:
851-
return func(self, *args, **kwargs)
870+
return f(*args, **kwargs)
871+
872+
return func(self, *args, **kwargs)
852873

853874
def rolling_exp(
854875
self: T_DataWithCoords,

xarray/core/datatree.py

+53-15
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,17 @@
1212
Mapping,
1313
)
1414
from html import escape
15-
from typing import TYPE_CHECKING, Any, Literal, NoReturn, Union, overload
15+
from typing import (
16+
TYPE_CHECKING,
17+
Any,
18+
Concatenate,
19+
Literal,
20+
NoReturn,
21+
ParamSpec,
22+
TypeVar,
23+
Union,
24+
overload,
25+
)
1626

1727
from xarray.core import utils
1828
from xarray.core._aggregations import DataTreeAggregations
@@ -79,18 +89,23 @@
7989
# """
8090
# DEVELOPERS' NOTE
8191
# ----------------
82-
# The idea of this module is to create a `DataTree` class which inherits the tree structure from TreeNode, and also copies
83-
# the entire API of `xarray.Dataset`, but with certain methods decorated to instead map the dataset function over every
84-
# node in the tree. As this API is copied without directly subclassing `xarray.Dataset` we instead create various Mixin
85-
# classes (in ops.py) which each define part of `xarray.Dataset`'s extensive API.
92+
# The idea of this module is to create a `DataTree` class which inherits the tree
93+
# structure from TreeNode, and also copies the entire API of `xarray.Dataset`, but with
94+
# certain methods decorated to instead map the dataset function over every node in the
95+
# tree. As this API is copied without directly subclassing `xarray.Dataset` we instead
96+
# create various Mixin classes (in ops.py) which each define part of `xarray.Dataset`'s
97+
# extensive API.
8698
#
87-
# Some of these methods must be wrapped to map over all nodes in the subtree. Others are fine to inherit unaltered
88-
# (normally because they (a) only call dataset properties and (b) don't return a dataset that should be nested into a new
89-
# tree) and some will get overridden by the class definition of DataTree.
99+
# Some of these methods must be wrapped to map over all nodes in the subtree. Others are
100+
# fine to inherit unaltered (normally because they (a) only call dataset properties and
101+
# (b) don't return a dataset that should be nested into a new tree) and some will get
102+
# overridden by the class definition of DataTree.
90103
# """
91104

92105

93106
T_Path = Union[str, NodePath]
107+
T = TypeVar("T")
108+
P = ParamSpec("P")
94109

95110

96111
def _collect_data_and_coord_variables(
@@ -1460,9 +1475,28 @@ def map_over_datasets(
14601475
# TODO fix this typing error
14611476
return map_over_datasets(func, self, *args)
14621477

1478+
@overload
1479+
def pipe(
1480+
self,
1481+
func: Callable[Concatenate[Self, P], T],
1482+
*args: P.args,
1483+
**kwargs: P.kwargs,
1484+
) -> T: ...
1485+
1486+
@overload
1487+
def pipe(
1488+
self,
1489+
func: tuple[Callable[..., T], str],
1490+
*args: Any,
1491+
**kwargs: Any,
1492+
) -> T: ...
1493+
14631494
def pipe(
1464-
self, func: Callable | tuple[Callable, str], *args: Any, **kwargs: Any
1465-
) -> Any:
1495+
self,
1496+
func: Callable[Concatenate[Self, P], T] | tuple[Callable[..., T], str],
1497+
*args: Any,
1498+
**kwargs: Any,
1499+
) -> T:
14661500
"""Apply ``func(self, *args, **kwargs)``
14671501
14681502
This method replicates the pandas method of the same name.
@@ -1482,7 +1516,7 @@ def pipe(
14821516
14831517
Returns
14841518
-------
1485-
object : Any
1519+
object : T
14861520
the return type of ``func``.
14871521
14881522
Notes
@@ -1510,15 +1544,19 @@ def pipe(
15101544
15111545
"""
15121546
if isinstance(func, tuple):
1513-
func, target = func
1547+
# Use different var when unpacking function from tuple because the type
1548+
# signature of the unpacked function differs from the expected type
1549+
# signature in the case where only a function is given, rather than a tuple.
1550+
# This makes type checkers happy at both call sites below.
1551+
f, target = func
15141552
if target in kwargs:
15151553
raise ValueError(
15161554
f"{target} is both the pipe target and a keyword argument"
15171555
)
15181556
kwargs[target] = self
1519-
else:
1520-
args = (self,) + args
1521-
return func(*args, **kwargs)
1557+
return f(*args, **kwargs)
1558+
1559+
return func(self, *args, **kwargs)
15221560

15231561
# TODO some kind of .collapse() or .flatten() method to merge a subtree
15241562

0 commit comments

Comments
 (0)