Skip to content

Commit 834d4c4

Browse files
Allow passing axis kwargs to plot (#4020)
* fix facecolor plot * temp version * finish fix facecolor + solves #3169 * black formatting * add testing * allow cartopy projection to be a kwarg * fix PEP8 comment * black formatting * fix testing, plt not in parameterize * fix testing, allows for no matplotlib * black formating * fix tests without matplotlib * fix some mistakes * isort, mypy * fix mypy * remove empty line * correction from review * correction from 2nd review * updated tests * updated tests * black formatting * follow up correction from review * fix tests * fix tests again * fix bug in tests * fix pb in tests * remove useless line * clean up tests * fix * Add whats-new Co-authored-by: dcherian <[email protected]>
1 parent 329cefb commit 834d4c4

File tree

6 files changed

+68
-14
lines changed

6 files changed

+68
-14
lines changed

doc/plotting.rst

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -743,12 +743,13 @@ This script will plot the air temperature on a map.
743743
744744
air = xr.tutorial.open_dataset("air_temperature").air
745745
746-
ax = plt.axes(projection=ccrs.Orthographic(-80, 35))
747-
air.isel(time=0).plot.contourf(ax=ax, transform=ccrs.PlateCarree())
748-
ax.set_global()
746+
p = air.isel(time=0).plot(
747+
subplot_kws=dict(projection=ccrs.Orthographic(-80, 35), facecolor="gray"),
748+
transform=ccrs.PlateCarree())
749+
p.axes.set_global()
749750
750751
@savefig plotting_maps_cartopy.png width=100%
751-
ax.coastlines()
752+
p.axes.coastlines()
752753
753754
When faceting on maps, the projection can be transferred to the ``plot``
754755
function using the ``subplot_kws`` keyword. The axes for the subplots created

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ Enhancements
5454
By `Stephan Hoyer <https://github.com/shoyer>`_.
5555
- :py:meth:`DataArray.reset_index` and :py:meth:`Dataset.reset_index` now keep
5656
coordinate attributes (:pull:`4103`). By `Oriol Abril <https://github.com/OriolAbril>`_.
57+
- Axes kwargs such as ``facecolor`` can now be passed to :py:meth:`DataArray.plot` in ``subplot_kws``.
58+
This works for both single axes plots and FacetGrid plots.
59+
By `Raphael Dussin <https://github.com/raphaeldussin`_.
5760

5861
New Features
5962
~~~~~~~~~~~~

xarray/plot/plot.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,7 @@ def plot(
155155
Relative tolerance used to determine if the indexes
156156
are uniformly spaced. Usually a small positive number.
157157
subplot_kws : dict, optional
158-
Dictionary of keyword arguments for matplotlib subplots. Only applies
159-
to FacetGrid plotting.
158+
Dictionary of keyword arguments for matplotlib subplots.
160159
**kwargs : optional
161160
Additional keyword arguments to matplotlib
162161
@@ -177,10 +176,10 @@ def plot(
177176

178177
if ndims in [1, 2]:
179178
if row or col:
179+
kwargs["subplot_kws"] = subplot_kws
180180
kwargs["row"] = row
181181
kwargs["col"] = col
182182
kwargs["col_wrap"] = col_wrap
183-
kwargs["subplot_kws"] = subplot_kws
184183
if ndims == 1:
185184
plotfunc = line
186185
kwargs["hue"] = hue
@@ -190,6 +189,7 @@ def plot(
190189
kwargs["hue"] = hue
191190
else:
192191
plotfunc = pcolormesh
192+
kwargs["subplot_kws"] = subplot_kws
193193
else:
194194
if row or col or hue:
195195
raise ValueError(error_msg)
@@ -553,8 +553,8 @@ def _plot2d(plotfunc):
553553
always infer intervals, unless the mesh is irregular and plotted on
554554
a map projection.
555555
subplot_kws : dict, optional
556-
Dictionary of keyword arguments for matplotlib subplots. Only applies
557-
to FacetGrid plotting.
556+
Dictionary of keyword arguments for matplotlib subplots. Only used
557+
for 2D and FacetGrid plots.
558558
cbar_ax : matplotlib Axes, optional
559559
Axes in which to draw the colorbar.
560560
cbar_kwargs : dict, optional
@@ -724,7 +724,10 @@ def newplotfunc(
724724
"plt.imshow's `aspect` kwarg is not available " "in xarray"
725725
)
726726

727-
ax = get_axis(figsize, size, aspect, ax)
727+
if subplot_kws is None:
728+
subplot_kws = dict()
729+
ax = get_axis(figsize, size, aspect, ax, **subplot_kws)
730+
728731
primitive = plotfunc(
729732
xplt,
730733
yplt,

xarray/plot/utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -406,9 +406,12 @@ def _assert_valid_xy(darray, xy, name):
406406
raise ValueError(f"{name} must be one of None, '{valid_xy_str}'")
407407

408408

409-
def get_axis(figsize, size, aspect, ax):
410-
import matplotlib as mpl
411-
import matplotlib.pyplot as plt
409+
def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
410+
try:
411+
import matplotlib as mpl
412+
import matplotlib.pyplot as plt
413+
except ImportError:
414+
raise ImportError("matplotlib is required for plot.utils.get_axis")
412415

413416
if figsize is not None:
414417
if ax is not None:
@@ -427,8 +430,11 @@ def get_axis(figsize, size, aspect, ax):
427430
elif aspect is not None:
428431
raise ValueError("cannot provide `aspect` argument without `size`")
429432

433+
if kwargs and ax is not None:
434+
raise ValueError("cannot use subplot_kws with existing ax")
435+
430436
if ax is None:
431-
ax = plt.gca()
437+
ax = plt.gca(**kwargs)
432438

433439
return ax
434440

xarray/tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def LooseVersion(vstring):
7777
has_numbagg, requires_numbagg = _importorskip("numbagg")
7878
has_seaborn, requires_seaborn = _importorskip("seaborn")
7979
has_sparse, requires_sparse = _importorskip("sparse")
80+
has_cartopy, requires_cartopy = _importorskip("cartopy")
8081

8182
# some special cases
8283
has_scipy_or_netCDF4 = has_scipy or has_netCDF4

xarray/tests/test_plot.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
_build_discrete_cmap,
1616
_color_palette,
1717
_determine_cmap_params,
18+
get_axis,
1819
label_from_attrs,
1920
)
2021

@@ -23,6 +24,7 @@
2324
assert_equal,
2425
has_nc_time_axis,
2526
raises_regex,
27+
requires_cartopy,
2628
requires_cftime,
2729
requires_matplotlib,
2830
requires_nc_time_axis,
@@ -36,6 +38,11 @@
3638
except ImportError:
3739
pass
3840

41+
try:
42+
import cartopy as ctpy # type: ignore
43+
except ImportError:
44+
ctpy = None
45+
3946

4047
@pytest.mark.flaky
4148
@pytest.mark.skip(reason="maybe flaky")
@@ -2393,3 +2400,36 @@ def test_facetgrid_single_contour():
23932400
ds["time"] = [0, 1]
23942401

23952402
ds.plot.contour(col="time", levels=[4], colors=["k"])
2403+
2404+
2405+
@requires_matplotlib
2406+
def test_get_axis():
2407+
# test get_axis works with different args combinations
2408+
# and return the right type
2409+
2410+
# cannot provide both ax and figsize
2411+
with pytest.raises(ValueError, match="both `figsize` and `ax`"):
2412+
get_axis(figsize=[4, 4], size=None, aspect=None, ax="something")
2413+
2414+
# cannot provide both ax and size
2415+
with pytest.raises(ValueError, match="both `size` and `ax`"):
2416+
get_axis(figsize=None, size=200, aspect=4 / 3, ax="something")
2417+
2418+
# cannot provide both size and figsize
2419+
with pytest.raises(ValueError, match="both `figsize` and `size`"):
2420+
get_axis(figsize=[4, 4], size=200, aspect=None, ax=None)
2421+
2422+
# cannot provide aspect and size
2423+
with pytest.raises(ValueError, match="`aspect` argument without `size`"):
2424+
get_axis(figsize=None, size=None, aspect=4 / 3, ax=None)
2425+
2426+
ax = get_axis()
2427+
assert isinstance(ax, mpl.axes.Axes)
2428+
2429+
2430+
@requires_cartopy
2431+
def test_get_axis_cartopy():
2432+
2433+
kwargs = {"projection": ctpy.crs.PlateCarree()}
2434+
ax = get_axis(**kwargs)
2435+
assert isinstance(ax, ctpy.mpl.geoaxes.GeoAxesSubplot)

0 commit comments

Comments
 (0)