Skip to content

Commit f270b9f

Browse files
author
Joe Hamman
committed
Merge pull request #608 from jhamman/feature/select_coords_for_plotting
allow passing coordinate names as x and y to plot methods
2 parents 7edd526 + e25ac4b commit f270b9f

File tree

5 files changed

+142
-101
lines changed

5 files changed

+142
-101
lines changed

doc/plotting.rst

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,7 @@ Simple Example
141141
~~~~~~~~~~~~~~
142142

143143
The default method :py:meth:`xray.DataArray.plot` sees that the data is
144-
2 dimensional. If the coordinates are uniformly spaced then it
145-
calls :py:func:`xray.plot.imshow`.
144+
2 dimensional and calls :py:func:`xray.plot.pcolormesh`.
146145

147146
.. ipython:: python
148147
@@ -159,6 +158,14 @@ and ``xincrease``.
159158
@savefig 2d_simple_yincrease.png width=4in
160159
air2d.plot(yincrease=False)
161160
161+
.. note::
162+
163+
We use :py:func:`xray.plot.pcolormesh` as the default two-dimensional plot
164+
method because it is more flexible than :py:func:`xray.plot.imshow`.
165+
However, for large arrays, ``imshow`` can be much faster than ``pcolormesh``.
166+
If speed is important to you and you are plotting a regular mesh, consider
167+
using ``imshow``.
168+
162169
Missing Values
163170
~~~~~~~~~~~~~~
164171

@@ -176,9 +183,9 @@ Xray plots data with :ref:`missing_values`.
176183
Nonuniform Coordinates
177184
~~~~~~~~~~~~~~~~~~~~~~
178185

179-
It's not necessary for the coordinates to be evenly spaced. If not, then
180-
:py:meth:`xray.DataArray.plot` produces a filled contour plot by calling
181-
:py:func:`xray.plot.contourf`.
186+
It's not necessary for the coordinates to be evenly spaced. Both
187+
:py:func:`xray.plot.pcolormesh` (default) and :py:func:`xray.plot.contourf` can
188+
produce plots with nonuniform coordinates.
182189

183190
.. ipython:: python
184191
@@ -201,6 +208,7 @@ matplotlib is available.
201208
plt.title('These colors prove North America\nhas fallen in the ocean')
202209
plt.ylabel('latitude')
203210
plt.xlabel('longitude')
211+
plt.tight_layout()
204212
205213
@savefig plotting_2d_call_matplotlib.png width=4in
206214
plt.show()
@@ -376,8 +384,8 @@ Faceted plotting supports other arguments common to xray 2d plots.
376384
hasoutliers[-1, -1, -1] = 400
377385
378386
@savefig plot_facet_robust.png height=12in
379-
g = hasoutliers.plot.imshow('lon', 'lat', col='time', col_wrap=3,
380-
robust=True, cmap='viridis')
387+
g = hasoutliers.plot.pcolormesh('lon', 'lat', col='time', col_wrap=3,
388+
robust=True, cmap='viridis')
381389
382390
FacetGrid Objects
383391
~~~~~~~~~~~~~~~~~
@@ -473,14 +481,13 @@ plotting function based on the dimensions of the ``DataArray`` and whether
473481
the coordinates are sorted and uniformly spaced. This table
474482
describes what gets plotted:
475483

476-
=============== =========== ===========================
477-
Dimensions Coordinates Plotting function
478-
--------------- ----------- ---------------------------
479-
1 :py:func:`xray.plot.line`
480-
2 Uniform :py:func:`xray.plot.imshow`
481-
2 Irregular :py:func:`xray.plot.contourf`
482-
Anything else :py:func:`xray.plot.hist`
483-
=============== =========== ===========================
484+
=============== ===========================
485+
Dimensions Plotting function
486+
--------------- ---------------------------
487+
1 :py:func:`xray.plot.line`
488+
2 :py:func:`xray.plot.pcolormesh`
489+
Anything else :py:func:`xray.plot.hist`
490+
=============== ===========================
484491

485492
Coordinates
486493
~~~~~~~~~~~

xray/plot/facetgrid.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88

99
from ..core.formatting import format_item
10-
from .utils import _determine_cmap_params
10+
from .utils import _determine_cmap_params, _infer_xy_labels
1111

1212

1313
# Overrides axes.labelsize, xtick.major.size, ytick.major.size
@@ -242,6 +242,10 @@ def map_dataarray(self, func, x, y, **kwargs):
242242
defaults.update(cmap_params)
243243
defaults.update(kwargs)
244244

245+
# Get x, y labels for the first subplot
246+
x, y = _infer_xy_labels(darray=self.data.loc[self.name_dicts.flat[0]],
247+
x=x, y=y)
248+
245249
for d, ax in zip(self.name_dicts.flat, self.axes.flat):
246250
# None is the sentinel value
247251
if d is not None:
@@ -270,7 +274,7 @@ def map_dataarray(self, func, x, y, **kwargs):
270274
extend=cmap_params['extend'])
271275

272276
if self.data.name:
273-
cbar.set_label(self.data.name, rotation=270,
277+
cbar.set_label(self.data.name, rotation=90,
274278
verticalalignment='bottom')
275279

276280
self._x_var = x

xray/plot/plot.py

Lines changed: 21 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@
1313
import numpy as np
1414
import pandas as pd
1515

16-
from .utils import _determine_cmap_params
16+
from .utils import _determine_cmap_params, _infer_xy_labels
1717
from .facetgrid import FacetGrid
18-
from ..core.utils import is_uniform_spaced
1918

2019

2120
# Maybe more appropriate to keep this in .utils
@@ -39,47 +38,6 @@ def _ensure_plottable(*args):
3938
'or dates.')
4039

4140

42-
def _infer_xy_labels(plotfunc, darray, x, y):
43-
"""
44-
Determine x and y labels when some are missing. For use in _plot2d
45-
46-
darray is a 2 dimensional data array.
47-
"""
48-
dims = list(darray.dims)
49-
50-
if len(dims) != 2:
51-
raise ValueError('{type} plots are for 2 dimensional DataArrays. '
52-
'Passed DataArray has {ndim} dimensions'
53-
.format(type=plotfunc.__name__, ndim=len(dims)))
54-
55-
if x and x not in dims:
56-
raise KeyError('{0} is not a dimension of this DataArray. Use '
57-
'{1} or {2} for x'
58-
.format(x, *dims))
59-
60-
if y and y not in dims:
61-
raise KeyError('{0} is not a dimension of this DataArray. Use '
62-
'{1} or {2} for y'
63-
.format(y, *dims))
64-
65-
# Get label names
66-
if x and y:
67-
xlab = x
68-
ylab = y
69-
elif x and not y:
70-
xlab = x
71-
del dims[dims.index(x)]
72-
ylab = dims.pop()
73-
elif y and not x:
74-
ylab = y
75-
del dims[dims.index(y)]
76-
xlab = dims.pop()
77-
else:
78-
ylab, xlab = dims
79-
80-
return xlab, ylab
81-
82-
8341
def _easy_facetgrid(darray, plotfunc, x, y, row=None, col=None, col_wrap=None,
8442
aspect=1, size=3, subplot_kws=None, **kwargs):
8543
"""
@@ -99,19 +57,18 @@ def _easy_facetgrid(darray, plotfunc, x, y, row=None, col=None, col_wrap=None,
9957
def plot(darray, row=None, col=None, col_wrap=None, ax=None, rtol=0.01,
10058
subplot_kws=None, **kwargs):
10159
"""
102-
Default plot of DataArray using matplotlib / pylab.
60+
Default plot of DataArray using matplotlib.pyplot.
10361
10462
Calls xray plotting function based on the dimensions of
10563
darray.squeeze()
10664
107-
=============== =========== ===========================
108-
Dimensions Coordinates Plotting function
109-
--------------- ----------- ---------------------------
110-
1 :py:func:`xray.plot.line`
111-
2 Uniform :py:func:`xray.plot.imshow`
112-
2 Irregular :py:func:`xray.plot.contourf`
113-
Anything else :py:func:`xray.plot.hist`
114-
=============== =========== ===========================
65+
=============== ===========================
66+
Dimensions Plotting function
67+
--------------- ---------------------------
68+
1 :py:func:`xray.plot.line`
69+
2 :py:func:`xray.plot.pcolormesh`
70+
Anything else :py:func:`xray.plot.hist`
71+
=============== ===========================
11572
11673
Parameters
11774
----------
@@ -156,9 +113,7 @@ def plot(darray, row=None, col=None, col_wrap=None, ax=None, rtol=0.01,
156113
kwargs['col_wrap'] = col_wrap
157114
kwargs['subplot_kws'] = subplot_kws
158115

159-
indexes = (darray.indexes[dim].values for dim in plot_dims)
160-
uniform = all(is_uniform_spaced(i, rtol=rtol) for i in indexes)
161-
plotfunc = imshow if uniform else contourf
116+
plotfunc = pcolormesh
162117
else:
163118
if row or col:
164119
raise ValueError(error_msg)
@@ -376,7 +331,7 @@ def _plot2d(plotfunc):
376331

377332
@functools.wraps(plotfunc)
378333
def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None,
379-
col_wrap=None, xincrease=None, yincrease=None,
334+
col_wrap=None, xincrease=True, yincrease=True,
380335
add_colorbar=True, add_labels=True, vmin=None, vmax=None,
381336
cmap=None, center=None, robust=False, extend=None,
382337
levels=None, colors=None, subplot_kws=None, **kwargs):
@@ -416,8 +371,7 @@ def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None,
416371
if ax is None:
417372
ax = plt.gca()
418373

419-
xlab, ylab = _infer_xy_labels(plotfunc=plotfunc, darray=darray,
420-
x=x, y=y)
374+
xlab, ylab = _infer_xy_labels(darray=darray, x=x, y=y)
421375

422376
# better to pass the ndarrays directly to plotting functions
423377
xval = darray[xlab].values
@@ -471,7 +425,7 @@ def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None,
471425
if add_colorbar:
472426
cbar = plt.colorbar(primitive, ax=ax, extend=cmap_params['extend'])
473427
if darray.name and add_labels:
474-
cbar.set_label(darray.name)
428+
cbar.set_label(darray.name, rotation=90)
475429

476430
_update_axes_limits(ax, xincrease, yincrease)
477431

@@ -480,7 +434,7 @@ def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None,
480434
# For use as DataArray.plot.plotmethod
481435
@functools.wraps(newplotfunc)
482436
def plotmethod(_PlotMethods_obj, x=None, y=None, ax=None, row=None,
483-
col=None, col_wrap=None, xincrease=None, yincrease=None,
437+
col=None, col_wrap=None, xincrease=True, yincrease=True,
484438
add_colorbar=True, add_labels=True, vmin=None, vmax=None,
485439
cmap=None, colors=None, center=None, robust=False,
486440
extend=None, levels=None, subplot_kws=None, **kwargs):
@@ -506,7 +460,7 @@ def plotmethod(_PlotMethods_obj, x=None, y=None, ax=None, row=None,
506460
@_plot2d
507461
def imshow(x, y, z, ax, **kwargs):
508462
"""
509-
Image plot of 2d DataArray using matplotlib / pylab
463+
Image plot of 2d DataArray using matplotlib.pyplot
510464
511465
Wraps matplotlib.pyplot.imshow
512466
@@ -518,6 +472,11 @@ def imshow(x, y, z, ax, **kwargs):
518472
The pixels are centered on the coordinates values. Ie, if the coordinate
519473
value is 3.2 then the pixels for those coordinates will be centered on 3.2.
520474
"""
475+
476+
if x.ndim != 1 or y.ndim != 1:
477+
raise ValueError('imshow requires 1D coordinates, try using '
478+
'pcolormesh or contour(f)')
479+
521480
# Centering the pixels- Assumes uniform spacing
522481
xstep = (x[1] - x[0]) / 2.0
523482
ystep = (y[1] - y[0]) / 2.0
@@ -589,7 +548,7 @@ def pcolormesh(x, y, z, ax, **kwargs):
589548

590549
# by default, pcolormesh picks "round" values for bounds
591550
# this results in ugly looking plots with lots of surrounding whitespace
592-
if not hasattr(ax, 'projection'):
551+
if not hasattr(ax, 'projection') and x.ndim == 1 and y.ndim == 1:
593552
# not a cartopy geoaxis
594553
ax.set_xlim(x[0], x[-1])
595554
ax.set_ylim(y[0], y[-1])

xray/plot/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,21 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
177177

178178
return dict(vmin=vmin, vmax=vmax, cmap=cmap, extend=extend,
179179
levels=levels, cnorm=cnorm)
180+
181+
182+
def _infer_xy_labels(darray, x, y):
183+
"""
184+
Determine x and y labels. For use in _plot2d
185+
186+
darray must be a 2 dimensional data array.
187+
"""
188+
189+
if x is None and y is None:
190+
if darray.ndim != 2:
191+
raise ValueError('DataArray must be 2d')
192+
y, x = darray.dims
193+
elif x is None or y is None:
194+
raise ValueError('cannot supply only one of x and y')
195+
elif any(k not in darray.coords for k in (x, y)):
196+
raise ValueError('x and y must be coordinate variables')
197+
return x, y

0 commit comments

Comments
 (0)