Skip to content

Commit da9c1d1

Browse files
authored
Add typing to plot methods (#7052)
* add plot methods statically and add typing to plot tests * whats-new update * fix copy-paste typo * correct plot signatures * add *some* typing to plot methods * annotate darray in plot tests * correct typing of plot returns * fix plotting overloads * add correct overloads to dataset_plot * update whats-new * rename xr.plot.plot module since it shadows the xr.plot.plot method * move accessor to its own module * move DSPlotAccessor to accessor module * fix DSPlotAccessor import * add explanation to import statement * add breaking change to whats-new * remove unused `rtol` argument from plot * make most arguments of plotmethods kwargs only * fix wrong return types * add breaking kwarg change to whats-new * support for aspect='auto' or 'equal * typing support for Dataset FacetGrid * deprecate positional arguments for all plot methods * add deprecation to whats-new * add FacetGrid generic type * fix mypy 0.981 complaints * fix index errors in plots * add overloads to scatter * deprecate scatter args * add scatter to accessors and fix docstrings * undo some breaking changes * fix the docstrings and some typing * fix typing of scatter accessor funcs * align docstrings with signature and complete typing * add remaining typing * align more docstrings * re add ValueError for scatter plots with u, v * fix whats-new conflict * fix some typing errors * more typing fixes * fix last mypy complaints * try fixing facetgrid examples * fix py3.8 problems * update plotting.rst * update api * update plot docstring * add a tip about yincrease in imshow * set default for x/yincrease in docstring * simplify typing * add deprecation date as comment * update whats-new to new release * fix whats-new
1 parent 50301ac commit da9c1d1

21 files changed

+4992
-2181
lines changed

ci/requirements/min-all-deps.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ channels:
33
- conda-forge
44
- nodefaults
55
dependencies:
6-
# MINIMUM VERSIONS POLICY: see doc/installing.rst
6+
# MINIMUM VERSIONS POLICY: see doc/user-guide/installing.rst
77
# Run ci/min_deps_check.py to verify that this file respects the policy.
88
# When upgrading python, numpy, or pandas, must also change
9-
# doc/installing.rst and setup.py.
9+
# doc/user-guide/installing.rst, doc/user-guide/plotting.rst and setup.py.
1010
- python=3.8
1111
- boto3=1.18
1212
- bottleneck=1.3

doc/api-hidden.rst

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -330,11 +330,6 @@
330330
plot.scatter
331331
plot.surface
332332

333-
plot.FacetGrid.map_dataarray
334-
plot.FacetGrid.set_titles
335-
plot.FacetGrid.set_ticks
336-
plot.FacetGrid.map
337-
338333
CFTimeIndex.all
339334
CFTimeIndex.any
340335
CFTimeIndex.append

doc/api.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,7 @@ DataArray
703703
DataArray.plot.line
704704
DataArray.plot.pcolormesh
705705
DataArray.plot.step
706+
DataArray.plot.scatter
706707
DataArray.plot.surface
707708

708709

@@ -719,6 +720,7 @@ Faceting
719720
plot.FacetGrid.map_dataarray
720721
plot.FacetGrid.map_dataarray_line
721722
plot.FacetGrid.map_dataset
723+
plot.FacetGrid.map_plot1d
722724
plot.FacetGrid.set_axis_labels
723725
plot.FacetGrid.set_ticks
724726
plot.FacetGrid.set_titles

doc/user-guide/plotting.rst

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Matplotlib must be installed before xarray can plot.
2727

2828
To use xarray's plotting capabilities with time coordinates containing
2929
``cftime.datetime`` objects
30-
`nc-time-axis <https://github.com/SciTools/nc-time-axis>`_ v1.2.0 or later
30+
`nc-time-axis <https://github.com/SciTools/nc-time-axis>`_ v1.3.0 or later
3131
needs to be installed.
3232

3333
For more extensive plotting applications consider the following projects:
@@ -106,7 +106,13 @@ The simplest way to make a plot is to call the :py:func:`DataArray.plot()` metho
106106
@savefig plotting_1d_simple.png width=4in
107107
air1d.plot()
108108
109-
Xarray uses the coordinate name along with metadata ``attrs.long_name``, ``attrs.standard_name``, ``DataArray.name`` and ``attrs.units`` (if available) to label the axes. The names ``long_name``, ``standard_name`` and ``units`` are copied from the `CF-conventions spec <https://cfconventions.org/Data/cf-conventions/cf-conventions-1.7/build/ch03s03.html>`_. When choosing names, the order of precedence is ``long_name``, ``standard_name`` and finally ``DataArray.name``. The y-axis label in the above plot was constructed from the ``long_name`` and ``units`` attributes of ``air1d``.
109+
Xarray uses the coordinate name along with metadata ``attrs.long_name``,
110+
``attrs.standard_name``, ``DataArray.name`` and ``attrs.units`` (if available)
111+
to label the axes.
112+
The names ``long_name``, ``standard_name`` and ``units`` are copied from the
113+
`CF-conventions spec <https://cfconventions.org/Data/cf-conventions/cf-conventions-1.7/build/ch03s03.html>`_.
114+
When choosing names, the order of precedence is ``long_name``, ``standard_name`` and finally ``DataArray.name``.
115+
The y-axis label in the above plot was constructed from the ``long_name`` and ``units`` attributes of ``air1d``.
110116

111117
.. ipython:: python
112118
@@ -340,7 +346,10 @@ The keyword arguments ``xincrease`` and ``yincrease`` let you control the axes d
340346
y="lat", hue="lon", xincrease=False, yincrease=False
341347
)
342348
343-
In addition, one can use ``xscale, yscale`` to set axes scaling; ``xticks, yticks`` to set axes ticks and ``xlim, ylim`` to set axes limits. These accept the same values as the matplotlib methods ``Axes.set_(x,y)scale()``, ``Axes.set_(x,y)ticks()``, ``Axes.set_(x,y)lim()`` respectively.
349+
In addition, one can use ``xscale, yscale`` to set axes scaling;
350+
``xticks, yticks`` to set axes ticks and ``xlim, ylim`` to set axes limits.
351+
These accept the same values as the matplotlib methods ``Axes.set_(x,y)scale()``,
352+
``Axes.set_(x,y)ticks()``, ``Axes.set_(x,y)lim()`` respectively.
344353

345354

346355
Two Dimensions
@@ -350,7 +359,8 @@ Two Dimensions
350359
Simple Example
351360
================
352361

353-
The default method :py:meth:`DataArray.plot` calls :py:func:`xarray.plot.pcolormesh` by default when the data is two-dimensional.
362+
The default method :py:meth:`DataArray.plot` calls :py:func:`xarray.plot.pcolormesh`
363+
by default when the data is two-dimensional.
354364

355365
.. ipython:: python
356366
:okwarning:
@@ -585,7 +595,10 @@ Faceting here refers to splitting an array along one or two dimensions and
585595
plotting each group.
586596
Xarray's basic plotting is useful for plotting two dimensional arrays. What
587597
about three or four dimensional arrays? That's where facets become helpful.
588-
The general approach to plotting here is called “small multiples”, where the same kind of plot is repeated multiple times, and the specific use of small multiples to display the same relationship conditioned on one or more other variables is often called a “trellis plot”.
598+
The general approach to plotting here is called “small multiples”, where the
599+
same kind of plot is repeated multiple times, and the specific use of small
600+
multiples to display the same relationship conditioned on one or more other
601+
variables is often called a “trellis plot”.
589602

590603
Consider the temperature data set. There are 4 observations per day for two
591604
years which makes for 2920 values along the time dimension.
@@ -670,8 +683,8 @@ Faceted plotting supports other arguments common to xarray 2d plots.
670683
671684
@savefig plot_facet_robust.png
672685
g = hasoutliers.plot.pcolormesh(
673-
"lon",
674-
"lat",
686+
x="lon",
687+
y="lat",
675688
col="time",
676689
col_wrap=3,
677690
robust=True,
@@ -711,7 +724,7 @@ they have been plotted.
711724
.. ipython:: python
712725
:okwarning:
713726
714-
g = t.plot.imshow("lon", "lat", col="time", col_wrap=3, robust=True)
727+
g = t.plot.imshow(x="lon", y="lat", col="time", col_wrap=3, robust=True)
715728
716729
for i, ax in enumerate(g.axes.flat):
717730
ax.set_title("Air Temperature %d" % i)
@@ -727,7 +740,8 @@ they have been plotted.
727740
axis labels, axis ticks and plot titles. See :py:meth:`~xarray.plot.FacetGrid.set_titles`,
728741
:py:meth:`~xarray.plot.FacetGrid.set_xlabels`, :py:meth:`~xarray.plot.FacetGrid.set_ylabels` and
729742
:py:meth:`~xarray.plot.FacetGrid.set_ticks` for more information.
730-
Plotting functions can be applied to each subset of the data by calling :py:meth:`~xarray.plot.FacetGrid.map_dataarray` or to each subplot by calling :py:meth:`~xarray.plot.FacetGrid.map`.
743+
Plotting functions can be applied to each subset of the data by calling
744+
:py:meth:`~xarray.plot.FacetGrid.map_dataarray` or to each subplot by calling :py:meth:`~xarray.plot.FacetGrid.map`.
731745

732746
TODO: add an example of using the ``map`` method to plot dataset variables
733747
(e.g., with ``plt.quiver``).
@@ -777,7 +791,8 @@ Additionally, the boolean kwarg ``add_guide`` can be used to prevent the display
777791
@savefig ds_discrete_legend_hue_scatter.png
778792
ds.plot.scatter(x="A", y="B", hue="w", hue_style="discrete")
779793
780-
The ``markersize`` kwarg lets you vary the point's size by variable value. You can additionally pass ``size_norm`` to control how the variable's values are mapped to point sizes.
794+
The ``markersize`` kwarg lets you vary the point's size by variable value.
795+
You can additionally pass ``size_norm`` to control how the variable's values are mapped to point sizes.
781796

782797
.. ipython:: python
783798
:okwarning:
@@ -794,7 +809,8 @@ Faceting is also possible
794809
ds.plot.scatter(x="A", y="B", col="x", row="z", hue="w", hue_style="discrete")
795810
796811
797-
For more advanced scatter plots, we recommend converting the relevant data variables to a pandas DataFrame and using the extensive plotting capabilities of ``seaborn``.
812+
For more advanced scatter plots, we recommend converting the relevant data variables
813+
to a pandas DataFrame and using the extensive plotting capabilities of ``seaborn``.
798814

799815
Quiver
800816
~~~~~~
@@ -816,7 +832,8 @@ where ``u`` and ``v`` denote the x and y direction components of the arrow vecto
816832
@savefig ds_facet_quiver.png
817833
ds.plot.quiver(x="x", y="y", u="A", v="B", col="w", row="z", scale=4)
818834
819-
``scale`` is required for faceted quiver plots. The scale determines the number of data units per arrow length unit, i.e. a smaller scale parameter makes the arrow longer.
835+
``scale`` is required for faceted quiver plots.
836+
The scale determines the number of data units per arrow length unit, i.e. a smaller scale parameter makes the arrow longer.
820837

821838
Streamplot
822839
~~~~~~~~~~
@@ -830,7 +847,8 @@ Visualizing vector fields is also supported with streamline plots:
830847
ds.isel(w=1, z=1).plot.streamplot(x="x", y="y", u="A", v="B")
831848
832849
833-
where ``u`` and ``v`` denote the x and y direction components of the vectors tangent to the streamlines. Again, faceting is also possible:
850+
where ``u`` and ``v`` denote the x and y direction components of the vectors tangent to the streamlines.
851+
Again, faceting is also possible:
834852

835853
.. ipython:: python
836854
:okwarning:
@@ -983,7 +1001,7 @@ instead of the default ones:
9831001
)
9841002
9851003
@savefig plotting_example_2d_irreg.png width=4in
986-
da.plot.pcolormesh("lon", "lat")
1004+
da.plot.pcolormesh(x="lon", y="lat")
9871005
9881006
Note that in this case, xarray still follows the pixel centered convention.
9891007
This might be undesirable in some cases, for example when your data is defined
@@ -996,7 +1014,7 @@ this convention when plotting on a map:
9961014
import cartopy.crs as ccrs
9971015
9981016
ax = plt.subplot(projection=ccrs.PlateCarree())
999-
da.plot.pcolormesh("lon", "lat", ax=ax)
1017+
da.plot.pcolormesh(x="lon", y="lat", ax=ax)
10001018
ax.scatter(lon, lat, transform=ccrs.PlateCarree())
10011019
ax.coastlines()
10021020
@savefig plotting_example_2d_irreg_map.png width=4in
@@ -1009,7 +1027,7 @@ You can however decide to infer the cell boundaries and use the
10091027
:okwarning:
10101028
10111029
ax = plt.subplot(projection=ccrs.PlateCarree())
1012-
da.plot.pcolormesh("lon", "lat", ax=ax, infer_intervals=True)
1030+
da.plot.pcolormesh(x="lon", y="lat", ax=ax, infer_intervals=True)
10131031
ax.scatter(lon, lat, transform=ccrs.PlateCarree())
10141032
ax.coastlines()
10151033
@savefig plotting_example_2d_irreg_map_infer.png width=4in

doc/whats-new.rst

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,22 @@ v2022.10.1 (unreleased)
2323
New Features
2424
~~~~~~~~~~~~
2525

26+
- Add static typing to plot accessors (:issue:`6949`, :pull:`7052`).
27+
By `Michael Niklas <https://github.com/headtr1ck>`_.
2628

2729
Breaking changes
2830
~~~~~~~~~~~~~~~~
2931

32+
- Many arguments of plotmethods have been made keyword-only.
33+
- ``xarray.plot.plot`` module renamed to ``xarray.plot.dataarray_plot`` to prevent
34+
shadowing of the ``plot`` method. (:issue:`6949`, :pull:`7052`).
35+
By `Michael Niklas <https://github.com/headtr1ck>`_.
3036

3137
Deprecations
3238
~~~~~~~~~~~~
3339

40+
- Positional arguments for all plot methods have been deprecated (:issue:`6949`, :pull:`7052`).
41+
By `Michael Niklas <https://github.com/headtr1ck>`_.
3442

3543
Bug fixes
3644
~~~~~~~~~
@@ -64,8 +72,8 @@ New Features
6472
the z argument. (:pull:`6778`)
6573
By `Jimmy Westling <https://github.com/illviljan>`_.
6674
- Include the variable name in the error message when CF decoding fails to allow
67-
for easier identification of problematic variables (:issue:`7145`,
68-
:pull:`7147`). By `Spencer Clark <https://github.com/spencerkclark>`_.
75+
for easier identification of problematic variables (:issue:`7145`, :pull:`7147`).
76+
By `Spencer Clark <https://github.com/spencerkclark>`_.
6977

7078
Breaking changes
7179
~~~~~~~~~~~~~~~~

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ module = [
4848
"importlib_metadata.*",
4949
"iris.*",
5050
"matplotlib.*",
51+
"mpl_toolkits.*",
5152
"Nio.*",
5253
"nc_time_axis.*",
5354
"numbagg.*",

setup.cfg

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,6 @@ float_to_top = true
165165
default_section = THIRDPARTY
166166
known_first_party = xarray
167167

168-
169168
[aliases]
170169
test = pytest
171170

xarray/core/alignment.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
if TYPE_CHECKING:
3939
from .dataarray import DataArray
4040
from .dataset import Dataset
41-
from .types import JoinOptions, T_DataArray, T_DataArrayOrSet, T_Dataset
41+
from .types import JoinOptions, T_DataArray, T_Dataset, T_DataWithCoords
4242

4343
DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords)
4444

@@ -944,8 +944,8 @@ def _get_broadcast_dims_map_common_coords(args, exclude):
944944

945945

946946
def _broadcast_helper(
947-
arg: T_DataArrayOrSet, exclude, dims_map, common_coords
948-
) -> T_DataArrayOrSet:
947+
arg: T_DataWithCoords, exclude, dims_map, common_coords
948+
) -> T_DataWithCoords:
949949

950950
from .dataarray import DataArray
951951
from .dataset import Dataset
@@ -976,14 +976,16 @@ def _broadcast_dataset(ds: T_Dataset) -> T_Dataset:
976976

977977
# remove casts once https://github.com/python/mypy/issues/12800 is resolved
978978
if isinstance(arg, DataArray):
979-
return cast("T_DataArrayOrSet", _broadcast_array(arg))
979+
return cast("T_DataWithCoords", _broadcast_array(arg))
980980
elif isinstance(arg, Dataset):
981-
return cast("T_DataArrayOrSet", _broadcast_dataset(arg))
981+
return cast("T_DataWithCoords", _broadcast_dataset(arg))
982982
else:
983983
raise ValueError("all input must be Dataset or DataArray objects")
984984

985985

986-
def broadcast(*args, exclude=None):
986+
# TODO: this typing is too restrictive since it cannot deal with mixed
987+
# DataArray and Dataset types...? Is this a problem?
988+
def broadcast(*args: T_DataWithCoords, exclude=None) -> tuple[T_DataWithCoords, ...]:
987989
"""Explicitly broadcast any number of DataArray or Dataset objects against
988990
one another.
989991

xarray/core/dataarray.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from ..coding.calendar_ops import convert_calendar, interp_calendar
2424
from ..coding.cftimeindex import CFTimeIndex
25-
from ..plot.plot import _PlotMethods
25+
from ..plot.accessor import DataArrayPlotAccessor
2626
from ..plot.utils import _get_units_from_attrs
2727
from . import alignment, computation, dtypes, indexing, ops, utils
2828
from ._reductions import DataArrayReductions
@@ -4189,7 +4189,7 @@ def _inplace_binary_op(self: T_DataArray, other: Any, f: Callable) -> T_DataArra
41894189
def _copy_attrs_from(self, other: DataArray | Dataset | Variable) -> None:
41904190
self.attrs = other.attrs
41914191

4192-
plot = utils.UncachedAccessor(_PlotMethods)
4192+
plot = utils.UncachedAccessor(DataArrayPlotAccessor)
41934193

41944194
def _title_for_slice(self, truncate: int = 50) -> str:
41954195
"""

xarray/core/dataset.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
from ..coding.calendar_ops import convert_calendar, interp_calendar
3737
from ..coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings
38-
from ..plot.dataset_plot import _Dataset_PlotMethods
38+
from ..plot.accessor import DatasetPlotAccessor
3939
from . import alignment
4040
from . import dtypes as xrdtypes
4141
from . import duck_array_ops, formatting, formatting_html, ops, utils
@@ -7483,7 +7483,7 @@ def imag(self: T_Dataset) -> T_Dataset:
74837483
"""
74847484
return self.map(lambda x: x.imag, keep_attrs=True)
74857485

7486-
plot = utils.UncachedAccessor(_Dataset_PlotMethods)
7486+
plot = utils.UncachedAccessor(DatasetPlotAccessor)
74877487

74887488
def filter_by_attrs(self: T_Dataset, **kwargs) -> T_Dataset:
74897489
"""Returns a ``Dataset`` with variables that match specific conditions.
@@ -8575,7 +8575,9 @@ def curvefit(
85758575
or not isinstance(coords, Iterable)
85768576
):
85778577
coords = [coords]
8578-
coords_ = [self[coord] if isinstance(coord, str) else coord for coord in coords]
8578+
coords_: Sequence[DataArray] = [
8579+
self[coord] if isinstance(coord, str) else coord for coord in coords
8580+
]
85798581

85808582
# Determine whether any coords are dims on self
85818583
for coord in coords_:

xarray/core/pycompat.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from importlib import import_module
4+
from typing import Any, Literal
45

56
import numpy as np
67
from packaging.version import Version
@@ -9,6 +10,8 @@
910

1011
integer_types = (int, np.integer)
1112

13+
ModType = Literal["dask", "pint", "cupy", "sparse"]
14+
1215

1316
class DuckArrayModule:
1417
"""
@@ -18,7 +21,12 @@ class DuckArrayModule:
1821
https://github.com/pydata/xarray/pull/5561#discussion_r664815718
1922
"""
2023

21-
def __init__(self, mod):
24+
module: ModType | None
25+
version: Version
26+
type: tuple[type[Any]] # TODO: improve this? maybe Generic
27+
available: bool
28+
29+
def __init__(self, mod: ModType) -> None:
2230
try:
2331
duck_array_module = import_module(mod)
2432
duck_array_version = Version(duck_array_module.__version__)

xarray/core/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,10 @@ def dtype(self) -> np.dtype:
162162
CoarsenBoundaryOptions = Literal["exact", "trim", "pad"]
163163
SideOptions = Literal["left", "right"]
164164

165+
ScaleOptions = Literal["linear", "symlog", "log", "logit", None]
165166
HueStyleOptions = Literal["continuous", "discrete", None]
167+
AspectOptions = Union[Literal["auto", "equal"], float, None]
168+
ExtendOptions = Literal["neither", "both", "min", "max", None]
166169

167170
# TODO: Wait until mypy supports recursive objects in combination with typevars
168171
_T = TypeVar("_T")

0 commit comments

Comments
 (0)