diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f60941033f4..04b06fff221 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -40,6 +40,9 @@ Breaking changes extracts and add the indexes from another :py:class:`Coordinates` object passed via ``coords`` (:pull:`8107`). By `BenoƮt Bovy `_. +- Static typing of ``xlim`` and ``ylim`` arguments in plotting functions now must + be ``tuple[float, float]`` to align with matplotlib requirements. (:issue:`7802`, :pull:`8030`). + By `Michael Niklas `_. Deprecations ~~~~~~~~~~~~ diff --git a/xarray/core/options.py b/xarray/core/options.py index eb0c56c7ee0..a197cb4da10 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -6,10 +6,8 @@ from xarray.core.utils import FrozenDict if TYPE_CHECKING: - try: - from matplotlib.colors import Colormap - except ImportError: - Colormap = str + from matplotlib.colors import Colormap + Options = Literal[ "arithmetic_join", "cmap_divergent", @@ -164,11 +162,11 @@ class set_options: cmap_divergent : str or matplotlib.colors.Colormap, default: "RdBu_r" Colormap to use for divergent data plots. If string, must be matplotlib built-in colormap. Can also be a Colormap object - (e.g. mpl.cm.magma) + (e.g. mpl.colormaps["magma"]) cmap_sequential : str or matplotlib.colors.Colormap, default: "viridis" Colormap to use for nondivergent data plots. If string, must be matplotlib built-in colormap. Can also be a Colormap object - (e.g. mpl.cm.magma) + (e.g. mpl.colormaps["magma"]) display_expand_attrs : {"default", True, False} Whether to expand the attributes section for display of ``DataArray`` or ``Dataset`` objects. Can be diff --git a/xarray/plot/accessor.py b/xarray/plot/accessor.py index ff707602545..203bae2691f 100644 --- a/xarray/plot/accessor.py +++ b/xarray/plot/accessor.py @@ -16,6 +16,7 @@ from matplotlib.container import BarContainer from matplotlib.contour import QuadContourSet from matplotlib.image import AxesImage + from matplotlib.patches import Polygon from matplotlib.quiver import Quiver from mpl_toolkits.mplot3d.art3d import Line3D, Poly3DCollection from numpy.typing import ArrayLike @@ -47,11 +48,13 @@ def __call__(self, **kwargs) -> Any: return dataarray_plot.plot(self._da, **kwargs) @functools.wraps(dataarray_plot.hist) - def hist(self, *args, **kwargs) -> tuple[np.ndarray, np.ndarray, BarContainer]: + def hist( + self, *args, **kwargs + ) -> tuple[np.ndarray, np.ndarray, BarContainer | Polygon]: return dataarray_plot.hist(self._da, *args, **kwargs) @overload - def line( # type: ignore[misc] # None is hashable :( + def line( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, row: None = None, # no wrap -> primitive @@ -69,8 +72,8 @@ def line( # type: ignore[misc] # None is hashable :( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, @@ -96,8 +99,8 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, @@ -123,20 +126,20 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.line) + @functools.wraps(dataarray_plot.line, assigned=("__doc__",)) def line(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: return dataarray_plot.line(self._da, *args, **kwargs) @overload - def step( # type: ignore[misc] # None is hashable :( + def step( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, where: Literal["pre", "post", "mid"] = "pre", @@ -174,12 +177,12 @@ def step( ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.step) + @functools.wraps(dataarray_plot.step, assigned=("__doc__",)) def step(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: return dataarray_plot.step(self._da, *args, **kwargs) @overload - def scatter( + def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -207,8 +210,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -248,8 +251,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -289,8 +292,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -301,12 +304,12 @@ def scatter( ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.scatter) - def scatter(self, *args, **kwargs): + @functools.wraps(dataarray_plot.scatter, assigned=("__doc__",)) + def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[DataArray]: return dataarray_plot.scatter(self._da, *args, **kwargs) @overload - def imshow( + def imshow( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -338,8 +341,8 @@ def imshow( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> AxesImage: @@ -378,8 +381,8 @@ def imshow( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid[DataArray]: @@ -418,19 +421,19 @@ def imshow( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.imshow) - def imshow(self, *args, **kwargs) -> AxesImage: + @functools.wraps(dataarray_plot.imshow, assigned=("__doc__",)) + def imshow(self, *args, **kwargs) -> AxesImage | FacetGrid[DataArray]: return dataarray_plot.imshow(self._da, *args, **kwargs) @overload - def contour( + def contour( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -462,8 +465,8 @@ def contour( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> QuadContourSet: @@ -502,8 +505,8 @@ def contour( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid[DataArray]: @@ -542,19 +545,19 @@ def contour( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.contour) - def contour(self, *args, **kwargs) -> QuadContourSet: + @functools.wraps(dataarray_plot.contour, assigned=("__doc__",)) + def contour(self, *args, **kwargs) -> QuadContourSet | FacetGrid[DataArray]: return dataarray_plot.contour(self._da, *args, **kwargs) @overload - def contourf( + def contourf( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -586,8 +589,8 @@ def contourf( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> QuadContourSet: @@ -626,8 +629,8 @@ def contourf( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid[DataArray]: @@ -666,19 +669,19 @@ def contourf( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid: ... - @functools.wraps(dataarray_plot.contourf) - def contourf(self, *args, **kwargs) -> QuadContourSet: + @functools.wraps(dataarray_plot.contourf, assigned=("__doc__",)) + def contourf(self, *args, **kwargs) -> QuadContourSet | FacetGrid[DataArray]: return dataarray_plot.contourf(self._da, *args, **kwargs) @overload - def pcolormesh( + def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -710,8 +713,8 @@ def pcolormesh( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> QuadMesh: @@ -750,11 +753,11 @@ def pcolormesh( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> FacetGrid: + ) -> FacetGrid[DataArray]: ... @overload @@ -790,15 +793,15 @@ def pcolormesh( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> FacetGrid: + ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.pcolormesh) - def pcolormesh(self, *args, **kwargs) -> QuadMesh: + @functools.wraps(dataarray_plot.pcolormesh, assigned=("__doc__",)) + def pcolormesh(self, *args, **kwargs) -> QuadMesh | FacetGrid[DataArray]: return dataarray_plot.pcolormesh(self._da, *args, **kwargs) @overload @@ -834,8 +837,8 @@ def surface( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> Poly3DCollection: @@ -874,8 +877,8 @@ def surface( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid: @@ -914,14 +917,14 @@ def surface( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid: ... - @functools.wraps(dataarray_plot.surface) + @functools.wraps(dataarray_plot.surface, assigned=("__doc__",)) def surface(self, *args, **kwargs) -> Poly3DCollection: return dataarray_plot.surface(self._da, *args, **kwargs) @@ -945,7 +948,7 @@ def __call__(self, *args, **kwargs) -> NoReturn: ) @overload - def scatter( + def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -973,8 +976,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -1014,8 +1017,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -1023,7 +1026,7 @@ def scatter( extend=None, levels=None, **kwargs: Any, - ) -> FacetGrid[DataArray]: + ) -> FacetGrid[Dataset]: ... @overload @@ -1055,8 +1058,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -1064,15 +1067,15 @@ def scatter( extend=None, levels=None, **kwargs: Any, - ) -> FacetGrid[DataArray]: + ) -> FacetGrid[Dataset]: ... - @functools.wraps(dataset_plot.scatter) - def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[DataArray]: + @functools.wraps(dataset_plot.scatter, assigned=("__doc__",)) + def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[Dataset]: return dataset_plot.scatter(self._ds, *args, **kwargs) @overload - def quiver( + def quiver( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -1142,7 +1145,7 @@ def quiver( extend=None, cmap=None, **kwargs: Any, - ) -> FacetGrid: + ) -> FacetGrid[Dataset]: ... @overload @@ -1179,15 +1182,15 @@ def quiver( extend=None, cmap=None, **kwargs: Any, - ) -> FacetGrid: + ) -> FacetGrid[Dataset]: ... - @functools.wraps(dataset_plot.quiver) - def quiver(self, *args, **kwargs) -> Quiver | FacetGrid: + @functools.wraps(dataset_plot.quiver, assigned=("__doc__",)) + def quiver(self, *args, **kwargs) -> Quiver | FacetGrid[Dataset]: return dataset_plot.quiver(self._ds, *args, **kwargs) @overload - def streamplot( + def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -1257,7 +1260,7 @@ def streamplot( extend=None, cmap=None, **kwargs: Any, - ) -> FacetGrid: + ) -> FacetGrid[Dataset]: ... @overload @@ -1294,9 +1297,9 @@ def streamplot( extend=None, cmap=None, **kwargs: Any, - ) -> FacetGrid: + ) -> FacetGrid[Dataset]: ... - @functools.wraps(dataset_plot.streamplot) - def streamplot(self, *args, **kwargs) -> LineCollection | FacetGrid: + @functools.wraps(dataset_plot.streamplot, assigned=("__doc__",)) + def streamplot(self, *args, **kwargs) -> LineCollection | FacetGrid[Dataset]: return dataset_plot.streamplot(self._ds, *args, **kwargs) diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index 1e9324791da..8afd87ea64a 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -3,7 +3,7 @@ import functools import warnings from collections.abc import Hashable, Iterable, MutableMapping -from typing import TYPE_CHECKING, Any, Callable, Literal, cast, overload +from typing import TYPE_CHECKING, Any, Callable, Literal, Union, cast, overload import numpy as np import pandas as pd @@ -29,7 +29,6 @@ _resolve_intervals_2dplot, _update_axes, get_axis, - import_matplotlib_pyplot, label_from_attrs, ) @@ -40,6 +39,7 @@ from matplotlib.container import BarContainer from matplotlib.contour import QuadContourSet from matplotlib.image import AxesImage + from matplotlib.patches import Polygon from mpl_toolkits.mplot3d.art3d import Line3D, Poly3DCollection from numpy.typing import ArrayLike @@ -53,7 +53,7 @@ ) from xarray.plot.facetgrid import FacetGrid -_styles: MutableMapping[str, Any] = { +_styles: dict[str, Any] = { # Add a white border to make it easier seeing overlapping markers: "scatter.edgecolors": "w", } @@ -307,7 +307,7 @@ def plot( @overload -def line( # type: ignore[misc] # None is hashable :( +def line( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, *args: Any, row: None = None, # no wrap -> primitive @@ -325,8 +325,8 @@ def line( # type: ignore[misc] # None is hashable :( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, @@ -336,7 +336,7 @@ def line( # type: ignore[misc] # None is hashable :( @overload def line( - darray, + darray: T_DataArray, *args: Any, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, @@ -353,18 +353,18 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def line( - darray, + darray: T_DataArray, *args: Any, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid @@ -381,19 +381,19 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... # This function signature should not change so that it can use # matplotlib format strings def line( - darray: DataArray, + darray: T_DataArray, *args: Any, row: Hashable | None = None, col: Hashable | None = None, @@ -410,12 +410,12 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, -) -> list[Line3D] | FacetGrid[DataArray]: +) -> list[Line3D] | FacetGrid[T_DataArray]: """ Line plot of DataArray values. @@ -459,7 +459,7 @@ def line( Specifies scaling for the *x*- and *y*-axis, respectively. xticks, yticks : array-like, optional Specify tick locations for *x*- and *y*-axis. - xlim, ylim : array-like, optional + xlim, ylim : tuple[float, float], optional Specify *x*- and *y*-axis limits. add_legend : bool, default: True Add legend with *y* axis coordinates (2D inputs only). @@ -538,7 +538,7 @@ def line( @overload -def step( # type: ignore[misc] # None is hashable :( +def step( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, *args: Any, where: Literal["pre", "post", "mid"] = "pre", @@ -654,10 +654,10 @@ def hist( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, **kwargs: Any, -) -> tuple[np.ndarray, np.ndarray, BarContainer]: +) -> tuple[np.ndarray, np.ndarray, BarContainer | Polygon]: """ Histogram of DataArray. @@ -691,7 +691,7 @@ def hist( Specifies scaling for the *x*- and *y*-axis, respectively. xticks, yticks : array-like, optional Specify tick locations for *x*- and *y*-axis. - xlim, ylim : array-like, optional + xlim, ylim : tuple[float, float], optional Specify *x*- and *y*-axis limits. **kwargs : optional Additional keyword arguments to :py:func:`matplotlib:matplotlib.pyplot.hist`. @@ -708,14 +708,17 @@ def hist( no_nan = np.ravel(darray.to_numpy()) no_nan = no_nan[pd.notnull(no_nan)] - primitive = ax.hist(no_nan, **kwargs) + n, bins, patches = cast( + tuple[np.ndarray, np.ndarray, Union["BarContainer", "Polygon"]], + ax.hist(no_nan, **kwargs), + ) ax.set_title(darray._title_for_slice()) ax.set_xlabel(label_from_attrs(darray)) _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) - return primitive + return n, bins, patches def _plot1d(plotfunc): @@ -779,9 +782,9 @@ def _plot1d(plotfunc): Specify tick locations for x-axes. yticks : ArrayLike or None, optional Specify tick locations for y-axes. - xlim : ArrayLike or None, optional + xlim : tuple[float, float] or None, optional Specify x-axes limits. - ylim : ArrayLike or None, optional + ylim : tuple[float, float] or None, optional Specify y-axes limits. cmap : matplotlib colormap name or colormap, optional The mapping from data values to color space. Either a @@ -866,8 +869,8 @@ def newplotfunc( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, @@ -879,7 +882,7 @@ def newplotfunc( # All 1d plots in xarray share this function signature. # Method signature below should be consistent. - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt if subplot_kws is None: subplot_kws = dict() @@ -992,9 +995,13 @@ def newplotfunc( with plt.rc_context(_styles): if z is not None: + import mpl_toolkits + if ax is None: subplot_kws.update(projection="3d") ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) + # Using 30, 30 minimizes rotation of the plot. Making it easier to # build on your intuition from 2D plots: ax.view_init(azim=30, elev=30, vertical_axis="y") @@ -1078,12 +1085,12 @@ def newplotfunc( def _add_labels( add_labels: bool | Iterable[bool], - darrays: Iterable[DataArray], + darrays: Iterable[DataArray | None], suffixes: Iterable[str], rotate_labels: Iterable[bool], ax: Axes, ) -> None: - # Set x, y, z labels: + """Set x, y, z labels.""" add_labels = [add_labels] * 3 if isinstance(add_labels, bool) else add_labels for axis, add_label, darray, suffix, rotate_label in zip( ("x", "y", "z"), add_labels, darrays, suffixes, rotate_labels @@ -1107,7 +1114,7 @@ def _add_labels( @overload -def scatter( +def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, *args: Any, x: Hashable | None = None, @@ -1150,7 +1157,7 @@ def scatter( @overload def scatter( - darray: DataArray, + darray: T_DataArray, *args: Any, x: Hashable | None = None, y: Hashable | None = None, @@ -1186,13 +1193,13 @@ def scatter( extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def scatter( - darray: DataArray, + darray: T_DataArray, *args: Any, x: Hashable | None = None, y: Hashable | None = None, @@ -1228,7 +1235,7 @@ def scatter( extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @@ -1257,15 +1264,24 @@ def scatter( if sizeplt is not None: kwargs.update(s=sizeplt.to_numpy().ravel()) - axis_order = ["x", "y", "z"] + plts_or_none = (xplt, yplt, zplt) + _add_labels(add_labels, plts_or_none, ("", "", ""), (True, False, False), ax) - plts_dict: dict[str, DataArray | None] = dict(x=xplt, y=yplt, z=zplt) - plts_or_none = [plts_dict[v] for v in axis_order] - plts = [p for p in plts_or_none if p is not None] - primitive = ax.scatter(*[p.to_numpy().ravel() for p in plts], **kwargs) - _add_labels(add_labels, plts, ("", "", ""), (True, False, False), ax) + xplt_np = None if xplt is None else xplt.to_numpy().ravel() + yplt_np = None if yplt is None else yplt.to_numpy().ravel() + zplt_np = None if zplt is None else zplt.to_numpy().ravel() + plts_np = tuple(p for p in (xplt_np, yplt_np, zplt_np) if p is not None) - return primitive + if len(plts_np) == 3: + import mpl_toolkits + + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) + return ax.scatter(xplt_np, yplt_np, zplt_np, **kwargs) + + if len(plts_np) == 2: + return ax.scatter(plts_np[0], plts_np[1], **kwargs) + + raise ValueError("At least two variables required for a scatter plot.") def _plot2d(plotfunc): @@ -1374,9 +1390,9 @@ def _plot2d(plotfunc): Specify tick locations for x-axes. yticks : ArrayLike or None, optional Specify tick locations for y-axes. - xlim : ArrayLike or None, optional + xlim : tuple[float, float] or None, optional Specify x-axes limits. - ylim : ArrayLike or None, optional + ylim : tuple[float, float] or None, optional Specify y-axes limits. norm : matplotlib.colors.Normalize, optional If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding @@ -1429,8 +1445,8 @@ def newplotfunc( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> Any: @@ -1502,8 +1518,6 @@ def newplotfunc( # TypeError to be consistent with pandas raise TypeError("No numeric data to plot.") - plt = import_matplotlib_pyplot() - if ( plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False) @@ -1616,6 +1630,9 @@ def newplotfunc( ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra)) ax.set_title(darray._title_for_slice()) if plotfunc.__name__ == "surface": + import mpl_toolkits + + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) ax.set_zlabel(label_from_attrs(darray)) if add_colorbar: @@ -1656,7 +1673,7 @@ def newplotfunc( @overload -def imshow( +def imshow( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, @@ -1698,7 +1715,7 @@ def imshow( @overload def imshow( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -1733,13 +1750,13 @@ def imshow( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def imshow( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -1774,7 +1791,7 @@ def imshow( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @@ -1875,7 +1892,7 @@ def _center_pixels(x): @overload -def contour( +def contour( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, @@ -1917,7 +1934,7 @@ def contour( @overload def contour( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -1952,13 +1969,13 @@ def contour( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def contour( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -1993,7 +2010,7 @@ def contour( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @@ -2011,7 +2028,7 @@ def contour( @overload -def contourf( +def contourf( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, @@ -2053,7 +2070,7 @@ def contourf( @overload def contourf( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2088,13 +2105,13 @@ def contourf( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def contourf( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2129,7 +2146,7 @@ def contourf( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @@ -2147,7 +2164,7 @@ def contourf( @overload -def pcolormesh( +def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, @@ -2189,7 +2206,7 @@ def pcolormesh( @overload def pcolormesh( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2224,13 +2241,13 @@ def pcolormesh( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def pcolormesh( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2265,7 +2282,7 @@ def pcolormesh( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @@ -2376,7 +2393,7 @@ def surface( @overload def surface( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2411,13 +2428,13 @@ def surface( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def surface( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2452,7 +2469,7 @@ def surface( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @@ -2465,5 +2482,8 @@ def surface( Wraps :py:meth:`matplotlib:mpl_toolkits.mplot3d.axes3d.Axes3D.plot_surface`. """ + import mpl_toolkits + + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) primitive = ax.plot_surface(x, y, z, **kwargs) return primitive diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index edb32b73d98..a3ca201eec4 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -321,7 +321,7 @@ def newplotfunc( @overload -def quiver( +def quiver( # type: ignore[misc,unused-ignore] # None is hashable :( ds: Dataset, *args: Any, x: Hashable | None = None, @@ -475,7 +475,7 @@ def quiver( @overload -def streamplot( +def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :( ds: Dataset, *args: Any, x: Hashable | None = None, @@ -632,7 +632,6 @@ def streamplot( du = du.transpose(ydim, xdim) dv = dv.transpose(ydim, xdim) - args = [dx.values, dy.values, du.values, dv.values] hue = kwargs.pop("hue") cmap_params = kwargs.pop("cmap_params") @@ -646,7 +645,9 @@ def streamplot( ) kwargs.pop("hue_style") - hdl = ax.streamplot(*args, **kwargs, **cmap_params) + hdl = ax.streamplot( + dx.values, dy.values, du.values, dv.values, **kwargs, **cmap_params + ) # Return .lines so colorbar creation works properly return hdl.lines @@ -748,7 +749,7 @@ def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataAr @overload -def scatter( +def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( ds: Dataset, *args: Any, x: Hashable | None = None, diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 2b348d8bedd..faf809a8a74 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -9,7 +9,7 @@ import numpy as np from xarray.core.formatting import format_item -from xarray.core.types import HueStyleOptions, T_Xarray +from xarray.core.types import HueStyleOptions, T_DataArrayOrSet from xarray.plot.utils import ( _LINEWIDTH_RANGE, _MARKERSIZE_RANGE, @@ -21,7 +21,6 @@ _Normalize, _parse_size, _process_cmap_cbar_kwargs, - import_matplotlib_pyplot, label_from_attrs, ) @@ -60,7 +59,7 @@ def _nicetitle(coord, value, maxchar, template): T_FacetGrid = TypeVar("T_FacetGrid", bound="FacetGrid") -class FacetGrid(Generic[T_Xarray]): +class FacetGrid(Generic[T_DataArrayOrSet]): """ Initialize the Matplotlib figure and FacetGrid object. @@ -101,7 +100,7 @@ class FacetGrid(Generic[T_Xarray]): sometimes the rightmost grid positions in the bottom row. """ - data: T_Xarray + data: T_DataArrayOrSet name_dicts: np.ndarray fig: Figure axs: np.ndarray @@ -126,7 +125,7 @@ class FacetGrid(Generic[T_Xarray]): def __init__( self, - data: T_Xarray, + data: T_DataArrayOrSet, col: Hashable | None = None, row: Hashable | None = None, col_wrap: int | None = None, @@ -166,7 +165,7 @@ def __init__( """ - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt # Handle corner case of nonunique coordinates rep_col = col is not None and not data[col].to_index().is_unique @@ -681,7 +680,10 @@ def _finalize_grid(self, *axlabels: Hashable) -> None: def _adjust_fig_for_guide(self, guide) -> None: # Draw the plot to set the bounding boxes correctly - renderer = self.fig.canvas.get_renderer() + if hasattr(self.fig.canvas, "get_renderer"): + renderer = self.fig.canvas.get_renderer() + else: + raise RuntimeError("MPL backend has no renderer") self.fig.draw(renderer) # Calculate and set the new width of the figure so the legend fits @@ -731,6 +733,9 @@ def add_colorbar(self, **kwargs: Any) -> None: if hasattr(self._mappables[-1], "extend"): kwargs.pop("extend", None) if "label" not in kwargs: + from xarray import DataArray + + assert isinstance(self.data, DataArray) kwargs.setdefault("label", label_from_attrs(self.data)) self.cbar = self.fig.colorbar( self._mappables[-1], ax=list(self.axs.flat), **kwargs @@ -985,7 +990,7 @@ def map( self : FacetGrid object """ - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt for ax, namedict in zip(self.axs.flat, self.name_dicts.flat): if namedict is not None: @@ -1004,7 +1009,7 @@ def map( def _easy_facetgrid( - data: T_Xarray, + data: T_DataArrayOrSet, plotfunc: Callable, kind: Literal["line", "dataarray", "dataset", "plot1d"], x: Hashable | None = None, @@ -1020,7 +1025,7 @@ def _easy_facetgrid( ax: Axes | None = None, figsize: Iterable[float] | None = None, **kwargs: Any, -) -> FacetGrid[T_Xarray]: +) -> FacetGrid[T_DataArrayOrSet]: """ Convenience method to call xarray.plot.FacetGrid from 2d plotting methods diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index b8cc4ff7349..5694acc06e8 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -47,14 +47,6 @@ _LINEWIDTH_RANGE = (1.5, 1.5, 6.0) -def import_matplotlib_pyplot(): - """import pyplot""" - # TODO: This function doesn't do anything (after #6109), remove it? - import matplotlib.pyplot as plt - - return plt - - def _determine_extend(calc_data, vmin, vmax): extend_min = calc_data.min() < vmin extend_max = calc_data.max() > vmax @@ -505,28 +497,29 @@ def _maybe_gca(**subplot_kws: Any) -> Axes: return plt.axes(**subplot_kws) -def _get_units_from_attrs(da) -> str: +def _get_units_from_attrs(da: DataArray) -> str: """Extracts and formats the unit/units from a attributes.""" pint_array_type = DuckArrayModule("pint").type units = " [{}]" if isinstance(da.data, pint_array_type): - units = units.format(str(da.data.units)) - elif da.attrs.get("units"): - units = units.format(da.attrs["units"]) - elif da.attrs.get("unit"): - units = units.format(da.attrs["unit"]) - else: - units = "" - return units + return units.format(str(da.data.units)) + if "units" in da.attrs: + return units.format(da.attrs["units"]) + if "unit" in da.attrs: + return units.format(da.attrs["unit"]) + return "" -def label_from_attrs(da, extra: str = "") -> str: +def label_from_attrs(da: DataArray | None, extra: str = "") -> str: """Makes informative labels if variable metadata (attrs) follows CF conventions.""" + if da is None: + return "" + name: str = "{}" - if da.attrs.get("long_name"): + if "long_name" in da.attrs: name = name.format(da.attrs["long_name"]) - elif da.attrs.get("standard_name"): + elif "standard_name" in da.attrs: name = name.format(da.attrs["standard_name"]) elif da.name is not None: name = name.format(da.name) @@ -774,8 +767,8 @@ def _update_axes( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, ) -> None: """ Update axes with provided parameters @@ -1166,7 +1159,7 @@ def _get_color_and_size(value): def _legend_add_subtitle(handles, labels, text): """Add a subtitle to legend handles.""" - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt if text and len(handles) > 1: # Create a blank handle that's not visible, the @@ -1184,7 +1177,7 @@ def _legend_add_subtitle(handles, labels, text): def _adjust_legend_subtitles(legend): """Make invisible-handle "subtitles" entries look more like titles.""" - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt # Legend title not in rcParams until 3.0 font_size = plt.rcParams.get("legend.title_fontsize", None) @@ -1640,7 +1633,7 @@ def format(self) -> FuncFormatter: >>> aa.format(1) '3.0' """ - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt def _func(x: Any, pos: None | Any = None): return f"{self._lookup_arr([x])[0]}" diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 8b2dfbdec41..b0e6ff90bc7 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -43,6 +43,7 @@ # import mpl and change the backend before other mpl imports try: import matplotlib as mpl + import matplotlib.dates import matplotlib.pyplot as plt import mpl_toolkits except ImportError: @@ -421,6 +422,7 @@ def test2d_1d_2d_coordinates_pcolormesh(self) -> None: ]: p = a.plot.pcolormesh(x=x, y=y) v = p.get_paths()[0].vertices + assert isinstance(v, np.ndarray) # Check all vertices are different, except last vertex which should be the # same as the first @@ -440,7 +442,7 @@ def test_str_coordinates_pcolormesh(self) -> None: def test_contourf_cmap_set(self) -> None: a = DataArray(easy_array((4, 4)), dims=["z", "time"]) - cmap = mpl.cm.viridis + cmap_expected = mpl.colormaps["viridis"] # use copy to ensure cmap is not changed by contourf() # Set vmin and vmax so that _build_discrete_colormap is called with @@ -450,55 +452,59 @@ def test_contourf_cmap_set(self) -> None: # extend='neither' (but if extend='neither' the under and over values # would not be used because the data would all be within the plotted # range) - pl = a.plot.contourf(cmap=copy(cmap), vmin=0.1, vmax=0.9) + pl = a.plot.contourf(cmap=copy(cmap_expected), vmin=0.1, vmax=0.9) # check the set_bad color + cmap = pl.cmap + assert cmap is not None assert_array_equal( - pl.cmap(np.ma.masked_invalid([np.nan]))[0], cmap(np.ma.masked_invalid([np.nan]))[0], + cmap_expected(np.ma.masked_invalid([np.nan]))[0], ) # check the set_under color - assert pl.cmap(-np.inf) == cmap(-np.inf) + assert cmap(-np.inf) == cmap_expected(-np.inf) # check the set_over color - assert pl.cmap(np.inf) == cmap(np.inf) + assert cmap(np.inf) == cmap_expected(np.inf) def test_contourf_cmap_set_with_bad_under_over(self) -> None: a = DataArray(easy_array((4, 4)), dims=["z", "time"]) # make a copy here because we want a local cmap that we will modify. - cmap = copy(mpl.cm.viridis) + cmap_expected = copy(mpl.colormaps["viridis"]) - cmap.set_bad("w") + cmap_expected.set_bad("w") # check we actually changed the set_bad color assert np.all( - cmap(np.ma.masked_invalid([np.nan]))[0] - != mpl.cm.viridis(np.ma.masked_invalid([np.nan]))[0] + cmap_expected(np.ma.masked_invalid([np.nan]))[0] + != mpl.colormaps["viridis"](np.ma.masked_invalid([np.nan]))[0] ) - cmap.set_under("r") + cmap_expected.set_under("r") # check we actually changed the set_under color - assert cmap(-np.inf) != mpl.cm.viridis(-np.inf) + assert cmap_expected(-np.inf) != mpl.colormaps["viridis"](-np.inf) - cmap.set_over("g") + cmap_expected.set_over("g") # check we actually changed the set_over color - assert cmap(np.inf) != mpl.cm.viridis(-np.inf) + assert cmap_expected(np.inf) != mpl.colormaps["viridis"](-np.inf) # copy to ensure cmap is not changed by contourf() - pl = a.plot.contourf(cmap=copy(cmap)) + pl = a.plot.contourf(cmap=copy(cmap_expected)) + cmap = pl.cmap + assert cmap is not None # check the set_bad color has been kept assert_array_equal( - pl.cmap(np.ma.masked_invalid([np.nan]))[0], cmap(np.ma.masked_invalid([np.nan]))[0], + cmap_expected(np.ma.masked_invalid([np.nan]))[0], ) # check the set_under color has been kept - assert pl.cmap(-np.inf) == cmap(-np.inf) + assert cmap(-np.inf) == cmap_expected(-np.inf) # check the set_over color has been kept - assert pl.cmap(np.inf) == cmap(np.inf) + assert cmap(np.inf) == cmap_expected(np.inf) def test3d(self) -> None: self.darray.plot() @@ -831,19 +837,25 @@ def test_coord_with_interval_step(self) -> None: """Test step plot with intervals.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot.step() - assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + line = plt.gca().lines[0] + assert isinstance(line, mpl.lines.Line2D) + assert len(np.asarray(line.get_xdata())) == ((len(bins) - 1) * 2) def test_coord_with_interval_step_x(self) -> None: """Test step plot with intervals explicitly on x axis.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(x="dim_0_bins") - assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + line = plt.gca().lines[0] + assert isinstance(line, mpl.lines.Line2D) + assert len(np.asarray(line.get_xdata())) == ((len(bins) - 1) * 2) def test_coord_with_interval_step_y(self) -> None: """Test step plot with intervals explicitly on y axis.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(y="dim_0_bins") - assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + line = plt.gca().lines[0] + assert isinstance(line, mpl.lines.Line2D) + assert len(np.asarray(line.get_xdata())) == ((len(bins) - 1) * 2) def test_coord_with_interval_step_x_and_y_raises_valueeerror(self) -> None: """Test that step plot with intervals both on x and y axes raises an error.""" @@ -883,8 +895,11 @@ def test_can_pass_in_axis(self) -> None: self.pass_in_axis(self.darray.plot.hist) def test_primitive_returned(self) -> None: - h = self.darray.plot.hist() - assert isinstance(h[-1][0], mpl.patches.Rectangle) + n, bins, patches = self.darray.plot.hist() + assert isinstance(n, np.ndarray) + assert isinstance(bins, np.ndarray) + assert isinstance(patches, mpl.container.BarContainer) + assert isinstance(patches[0], mpl.patches.Rectangle) @pytest.mark.slow def test_plot_nans(self) -> None: @@ -928,9 +943,9 @@ def test_cmap_sequential_option(self) -> None: assert cmap_params["cmap"] == "magma" def test_cmap_sequential_explicit_option(self) -> None: - with xr.set_options(cmap_sequential=mpl.cm.magma): + with xr.set_options(cmap_sequential=mpl.colormaps["magma"]): cmap_params = _determine_cmap_params(self.data) - assert cmap_params["cmap"] == mpl.cm.magma + assert cmap_params["cmap"] == mpl.colormaps["magma"] def test_cmap_divergent_option(self) -> None: with xr.set_options(cmap_divergent="magma"): @@ -1170,7 +1185,7 @@ def test_discrete_colormap_list_of_levels(self) -> None: def test_discrete_colormap_int_levels(self) -> None: for extend, levels, vmin, vmax, cmap in [ ("neither", 7, None, None, None), - ("neither", 7, None, 20, mpl.cm.RdBu), + ("neither", 7, None, 20, mpl.colormaps["RdBu"]), ("both", 7, 4, 8, None), ("min", 10, 4, 15, None), ]: @@ -1720,8 +1735,8 @@ class TestContour(Common2dMixin, PlotTestCase): # matplotlib cmap.colors gives an rgbA ndarray # when seaborn is used, instead we get an rgb tuple @staticmethod - def _color_as_tuple(c): - return tuple(c[:3]) + def _color_as_tuple(c: Any) -> tuple[Any, Any, Any]: + return c[0], c[1], c[2] def test_colors(self) -> None: # with single color, we don't want rgb array @@ -1743,10 +1758,16 @@ def test_colors_np_levels(self) -> None: # https://github.com/pydata/xarray/issues/3284 levels = np.array([-0.5, 0.0, 0.5, 1.0]) artist = self.darray.plot.contour(levels=levels, colors=["k", "r", "w", "b"]) - assert self._color_as_tuple(artist.cmap.colors[1]) == (1.0, 0.0, 0.0) - assert self._color_as_tuple(artist.cmap.colors[2]) == (1.0, 1.0, 1.0) + cmap = artist.cmap + assert isinstance(cmap, mpl.colors.ListedColormap) + colors = cmap.colors + assert isinstance(colors, list) + + assert self._color_as_tuple(colors[1]) == (1.0, 0.0, 0.0) + assert self._color_as_tuple(colors[2]) == (1.0, 1.0, 1.0) # the last color is now under "over" - assert self._color_as_tuple(artist.cmap._rgba_over) == (0.0, 0.0, 1.0) + assert hasattr(cmap, "_rgba_over") + assert self._color_as_tuple(cmap._rgba_over) == (0.0, 0.0, 1.0) def test_cmap_and_color_both(self) -> None: with pytest.raises(ValueError): @@ -1798,7 +1819,9 @@ def test_dont_infer_interval_breaks_for_cartopy(self) -> None: artist = self.plotmethod(x="x2d", y="y2d", ax=ax) assert isinstance(artist, mpl.collections.QuadMesh) # Let cartopy handle the axis limits and artist size - assert artist.get_array().size <= self.darray.size + arr = artist.get_array() + assert arr is not None + assert arr.size <= self.darray.size class TestPcolormeshLogscale(PlotTestCase): @@ -1949,6 +1972,7 @@ def test_normalize_rgb_imshow( ) -> None: da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4)) arr = da.plot.imshow(vmin=vmin, vmax=vmax, robust=robust).get_array() + assert arr is not None assert 0 <= arr.min() <= arr.max() <= 1 def test_normalize_rgb_one_arg_error(self) -> None: @@ -1965,7 +1989,10 @@ def test_imshow_rgb_values_in_valid_range(self) -> None: da = DataArray(np.arange(75, dtype="uint8").reshape((5, 5, 3))) _, ax = plt.subplots() out = da.plot.imshow(ax=ax).get_array() - assert out.dtype == np.uint8 + assert out is not None + dtype = out.dtype + assert dtype is not None + assert dtype == np.uint8 assert (out[..., :3] == da.values).all() # Compare without added alpha @pytest.mark.filterwarnings("ignore:Several dimensions of this array") @@ -2000,6 +2027,7 @@ def test_2d_coord_names(self) -> None: self.plotmethod(x="x2d", y="y2d") # make sure labels came out ok ax = plt.gca() + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) assert "x2d" == ax.get_xlabel() assert "y2d" == ax.get_ylabel() assert f"{self.darray.long_name} [{self.darray.units}]" == ax.get_zlabel() @@ -2122,6 +2150,7 @@ def test_colorbar(self) -> None: self.g.map_dataarray(xplt.imshow, "x", "y") for image in plt.gcf().findobj(mpl.image.AxesImage): + assert isinstance(image, mpl.image.AxesImage) clim = np.array(image.get_clim()) assert np.allclose(expected, clim) @@ -2132,7 +2161,9 @@ def test_colorbar_scatter(self) -> None: fg: xplt.FacetGrid = ds.plot.scatter(x="a", y="a", row="x", hue="a") cbar = fg.cbar assert cbar is not None + assert hasattr(cbar, "vmin") assert cbar.vmin == 0 + assert hasattr(cbar, "vmax") assert cbar.vmax == 3 @pytest.mark.slow @@ -2199,6 +2230,7 @@ def test_can_set_vmin_vmax(self) -> None: self.g.map_dataarray(xplt.imshow, "x", "y", vmin=vmin, vmax=vmax) for image in plt.gcf().findobj(mpl.image.AxesImage): + assert isinstance(image, mpl.image.AxesImage) clim = np.array(image.get_clim()) assert np.allclose(expected, clim) @@ -2215,6 +2247,7 @@ def test_can_set_norm(self) -> None: norm = mpl.colors.SymLogNorm(0.1) self.g.map_dataarray(xplt.imshow, "x", "y", norm=norm) for image in plt.gcf().findobj(mpl.image.AxesImage): + assert isinstance(image, mpl.image.AxesImage) assert image.norm is norm @pytest.mark.slow @@ -2752,15 +2785,20 @@ def test_non_numeric_legend(self) -> None: ds2 = self.ds.copy() ds2["hue"] = ["a", "b", "c", "d"] pc = ds2.plot.scatter(x="A", y="B", markersize="hue") + axes = pc.axes + assert axes is not None # should make a discrete legend - assert pc.axes.legend_ is not None + assert hasattr(axes, "legend_") + assert axes.legend_ is not None def test_legend_labels(self) -> None: # regression test for #4126: incorrect legend labels ds2 = self.ds.copy() ds2["hue"] = ["a", "a", "b", "b"] pc = ds2.plot.scatter(x="A", y="B", markersize="hue") - actual = [t.get_text() for t in pc.axes.get_legend().texts] + axes = pc.axes + assert axes is not None + actual = [t.get_text() for t in axes.get_legend().texts] expected = ["hue", "a", "b"] assert actual == expected @@ -2781,7 +2819,9 @@ def test_legend_labels_facetgrid(self) -> None: def test_add_legend_by_default(self) -> None: sc = self.ds.plot.scatter(x="A", y="B", hue="hue") - assert len(sc.figure.axes) == 2 + fig = sc.figure + assert fig is not None + assert len(fig.axes) == 2 class TestDatetimePlot(PlotTestCase): @@ -2834,6 +2874,7 @@ def test_datetime_plot2d(self) -> None: p = da.plot.pcolormesh() ax = p.axes + assert ax is not None # Make sure only mpl converters are used, use type() so only # mpl.dates.AutoDateLocator passes and no other subclasses: