Skip to content

Commit 36f05d7

Browse files
Use .to_numpy() for quantified facetgrids (#5886)
Co-authored-by: Illviljan <[email protected]>
1 parent c210f8b commit 36f05d7

File tree

4 files changed

+40
-13
lines changed

4 files changed

+40
-13
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ Bug fixes
8282
By `Jimmy Westling <https://github.com/illviljan>`_.
8383
- Numbers are properly formatted in a plot's title (:issue:`5788`, :pull:`5789`).
8484
By `Maxime Liquet <https://github.com/maximlt>`_.
85+
- Faceted plots will no longer raise a `pint.UnitStrippedWarning` when a `pint.Quantity` array is plotted,
86+
and will correctly display the units of the data in the colorbar (if there is one) (:pull:`5886`).
87+
By `Tom Nicholas <https://github.com/TomNicholas>`_.
8588
- With backends, check for path-like objects rather than ``pathlib.Path``
8689
type, use ``os.fspath`` (:pull:`5879`).
8790
By `Mike Taves <https://github.com/mwtoews>`_.

xarray/plot/facetgrid.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -173,11 +173,11 @@ def __init__(
173173
)
174174

175175
# Set up the lists of names for the row and column facet variables
176-
col_names = list(data[col].values) if col else []
177-
row_names = list(data[row].values) if row else []
176+
col_names = list(data[col].to_numpy()) if col else []
177+
row_names = list(data[row].to_numpy()) if row else []
178178

179179
if single_group:
180-
full = [{single_group: x} for x in data[single_group].values]
180+
full = [{single_group: x} for x in data[single_group].to_numpy()]
181181
empty = [None for x in range(nrow * ncol - len(full))]
182182
name_dicts = full + empty
183183
else:
@@ -251,7 +251,7 @@ def map_dataarray(self, func, x, y, **kwargs):
251251
raise ValueError("cbar_ax not supported by FacetGrid.")
252252

253253
cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
254-
func, self.data.values, **kwargs
254+
func, self.data.to_numpy(), **kwargs
255255
)
256256

257257
self._cmap_extend = cmap_params.get("extend")
@@ -347,7 +347,7 @@ def map_dataset(
347347

348348
if hue and meta_data["hue_style"] == "continuous":
349349
cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
350-
func, self.data[hue].values, **kwargs
350+
func, self.data[hue].to_numpy(), **kwargs
351351
)
352352
kwargs["meta_data"]["cmap_params"] = cmap_params
353353
kwargs["meta_data"]["cbar_kwargs"] = cbar_kwargs
@@ -423,7 +423,7 @@ def _adjust_fig_for_guide(self, guide):
423423
def add_legend(self, **kwargs):
424424
self.figlegend = self.fig.legend(
425425
handles=self._mappables[-1],
426-
labels=list(self._hue_var.values),
426+
labels=list(self._hue_var.to_numpy()),
427427
title=self._hue_label,
428428
loc="center right",
429429
**kwargs,
@@ -619,7 +619,7 @@ def map(self, func, *args, **kwargs):
619619
if namedict is not None:
620620
data = self.data.loc[namedict]
621621
plt.sca(ax)
622-
innerargs = [data[a].values for a in args]
622+
innerargs = [data[a].to_numpy() for a in args]
623623
maybe_mappable = func(*innerargs, **kwargs)
624624
# TODO: better way to verify that an artist is mappable?
625625
# https://stackoverflow.com/questions/33023036/is-it-possible-to-detect-if-a-matplotlib-artist-is-a-mappable-suitable-for-use-w#33023522

xarray/plot/plot.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,7 +1075,7 @@ def newplotfunc(
10751075
# Matplotlib does not support normalising RGB data, so do it here.
10761076
# See eg. https://github.com/matplotlib/matplotlib/pull/10220
10771077
if robust or vmax is not None or vmin is not None:
1078-
darray = _rescale_imshow_rgb(darray, vmin, vmax, robust)
1078+
darray = _rescale_imshow_rgb(darray.as_numpy(), vmin, vmax, robust)
10791079
vmin, vmax, robust = None, None, False
10801080

10811081
if subplot_kws is None:
@@ -1146,10 +1146,6 @@ def newplotfunc(
11461146
else:
11471147
dims = (yval.dims[0], xval.dims[0])
11481148

1149-
# better to pass the ndarrays directly to plotting functions
1150-
xval = xval.to_numpy()
1151-
yval = yval.to_numpy()
1152-
11531149
# May need to transpose for correct x, y labels
11541150
# xlab may be the name of a coord, we have to check for dim names
11551151
if imshow_rgb:
@@ -1162,6 +1158,10 @@ def newplotfunc(
11621158
if dims != darray.dims:
11631159
darray = darray.transpose(*dims, transpose_coords=True)
11641160

1161+
# better to pass the ndarrays directly to plotting functions
1162+
xval = xval.to_numpy()
1163+
yval = yval.to_numpy()
1164+
11651165
# Pass the data as a masked ndarray too
11661166
zval = darray.to_masked_array(copy=False)
11671167

xarray/tests/test_units.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5614,11 +5614,35 @@ def test_units_in_line_plot_labels(self):
56145614
assert ax.get_ylabel() == "pressure [pascal]"
56155615
assert ax.get_xlabel() == "x [meters]"
56165616

5617-
def test_units_in_2d_plot_labels(self):
5617+
def test_units_in_2d_plot_colorbar_label(self):
56185618
arr = np.ones((2, 3)) * unit_registry.Pa
56195619
da = xr.DataArray(data=arr, dims=["x", "y"], name="pressure")
56205620

56215621
fig, (ax, cax) = plt.subplots(1, 2)
56225622
ax = da.plot.contourf(ax=ax, cbar_ax=cax, add_colorbar=True)
56235623

56245624
assert cax.get_ylabel() == "pressure [pascal]"
5625+
5626+
def test_units_facetgrid_plot_labels(self):
5627+
arr = np.ones((2, 3)) * unit_registry.Pa
5628+
da = xr.DataArray(data=arr, dims=["x", "y"], name="pressure")
5629+
5630+
fig, (ax, cax) = plt.subplots(1, 2)
5631+
fgrid = da.plot.line(x="x", col="y")
5632+
5633+
assert fgrid.axes[0, 0].get_ylabel() == "pressure [pascal]"
5634+
5635+
def test_units_facetgrid_2d_imshow_plot_colorbar_labels(self):
5636+
arr = np.ones((2, 3, 4, 5)) * unit_registry.Pa
5637+
da = xr.DataArray(data=arr, dims=["x", "y", "z", "w"], name="pressure")
5638+
5639+
da.plot.imshow(x="x", y="y", col="w") # no colorbar to check labels of
5640+
5641+
def test_units_facetgrid_2d_contourf_plot_colorbar_labels(self):
5642+
arr = np.ones((2, 3, 4)) * unit_registry.Pa
5643+
da = xr.DataArray(data=arr, dims=["x", "y", "z"], name="pressure")
5644+
5645+
fig, (ax1, ax2, ax3, cax) = plt.subplots(1, 4)
5646+
fgrid = da.plot.contourf(x="x", y="y", col="z")
5647+
5648+
assert fgrid.cbar.ax.get_ylabel() == "pressure [pascal]"

0 commit comments

Comments
 (0)