Skip to content

Commit b52b29d

Browse files
authored
Allow plotting categorical data (#5464)
1 parent 5381962 commit b52b29d

File tree

4 files changed

+57
-22
lines changed

4 files changed

+57
-22
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ New Features
4444
By `Thomas Hirtz <https://github.com/thomashirtz>`_.
4545
- allow passing a function to ``combine_attrs`` (:pull:`4896`).
4646
By `Justus Magin <https://github.com/keewis>`_.
47+
- Allow plotting categorical data (:pull:`5464`).
48+
By `Jimmy Westling <https://github.com/illviljan>`_.
4749

4850
Breaking changes
4951
~~~~~~~~~~~~~~~~

xarray/plot/plot.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -925,18 +925,26 @@ def imshow(x, y, z, ax, **kwargs):
925925
"imshow requires 1D coordinates, try using pcolormesh or contour(f)"
926926
)
927927

928-
# Centering the pixels- Assumes uniform spacing
929-
try:
930-
xstep = (x[1] - x[0]) / 2.0
931-
except IndexError:
932-
# Arbitrary default value, similar to matplotlib behaviour
933-
xstep = 0.1
934-
try:
935-
ystep = (y[1] - y[0]) / 2.0
936-
except IndexError:
937-
ystep = 0.1
938-
left, right = x[0] - xstep, x[-1] + xstep
939-
bottom, top = y[-1] + ystep, y[0] - ystep
928+
def _center_pixels(x):
929+
"""Center the pixels on the coordinates."""
930+
if np.issubdtype(x.dtype, str):
931+
# When using strings as inputs imshow converts it to
932+
# integers. Choose extent values which puts the indices in
933+
# in the center of the pixels:
934+
return 0 - 0.5, len(x) - 0.5
935+
936+
try:
937+
# Center the pixels assuming uniform spacing:
938+
xstep = 0.5 * (x[1] - x[0])
939+
except IndexError:
940+
# Arbitrary default value, similar to matplotlib behaviour:
941+
xstep = 0.1
942+
943+
return x[0] - xstep, x[-1] + xstep
944+
945+
# Center the pixels:
946+
left, right = _center_pixels(x)
947+
top, bottom = _center_pixels(y)
940948

941949
defaults = {"origin": "upper", "interpolation": "nearest"}
942950

@@ -967,6 +975,13 @@ def imshow(x, y, z, ax, **kwargs):
967975

968976
primitive = ax.imshow(z, **defaults)
969977

978+
# If x or y are strings the ticklabels have been replaced with
979+
# integer indices. Replace them back to strings:
980+
for axis, v in [("x", x), ("y", y)]:
981+
if np.issubdtype(v.dtype, str):
982+
getattr(ax, f"set_{axis}ticks")(np.arange(len(v)))
983+
getattr(ax, f"set_{axis}ticklabels")(v)
984+
970985
return primitive
971986

972987

@@ -1011,9 +1026,13 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs):
10111026
else:
10121027
infer_intervals = True
10131028

1014-
if infer_intervals and (
1015-
(np.shape(x)[0] == np.shape(z)[1])
1016-
or ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1]))
1029+
if (
1030+
infer_intervals
1031+
and not np.issubdtype(x.dtype, str)
1032+
and (
1033+
(np.shape(x)[0] == np.shape(z)[1])
1034+
or ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1]))
1035+
)
10171036
):
10181037
if len(x.shape) == 1:
10191038
x = _infer_interval_breaks(x, check_monotonic=True)
@@ -1022,7 +1041,11 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs):
10221041
x = _infer_interval_breaks(x, axis=1)
10231042
x = _infer_interval_breaks(x, axis=0)
10241043

1025-
if infer_intervals and (np.shape(y)[0] == np.shape(z)[0]):
1044+
if (
1045+
infer_intervals
1046+
and not np.issubdtype(y.dtype, str)
1047+
and (np.shape(y)[0] == np.shape(z)[0])
1048+
):
10261049
if len(y.shape) == 1:
10271050
y = _infer_interval_breaks(y, check_monotonic=True)
10281051
else:

xarray/plot/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,14 @@ def _ensure_plottable(*args):
604604
Raise exception if there is anything in args that can't be plotted on an
605605
axis by matplotlib.
606606
"""
607-
numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64, np.bool_]
607+
numpy_types = [
608+
np.floating,
609+
np.integer,
610+
np.timedelta64,
611+
np.datetime64,
612+
np.bool_,
613+
np.str_,
614+
]
608615
other_types = [datetime]
609616
try:
610617
import cftime

xarray/tests/test_plot.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -684,10 +684,9 @@ def test_format_string(self):
684684
def test_can_pass_in_axis(self):
685685
self.pass_in_axis(self.darray.plot.line)
686686

687-
def test_nonnumeric_index_raises_typeerror(self):
687+
def test_nonnumeric_index(self):
688688
a = DataArray([1, 2, 3], {"letter": ["a", "b", "c"]}, dims="letter")
689-
with pytest.raises(TypeError, match=r"[Pp]lot"):
690-
a.plot.line()
689+
a.plot.line()
691690

692691
def test_primitive_returned(self):
693692
p = self.darray.plot.line()
@@ -1162,9 +1161,13 @@ def test_3d_raises_valueerror(self):
11621161
with pytest.raises(ValueError, match=r"DataArray must be 2d"):
11631162
self.plotfunc(a)
11641163

1165-
def test_nonnumeric_index_raises_typeerror(self):
1164+
def test_nonnumeric_index(self):
11661165
a = DataArray(easy_array((3, 2)), coords=[["a", "b", "c"], ["d", "e"]])
1167-
with pytest.raises(TypeError, match=r"[Pp]lot"):
1166+
if self.plotfunc.__name__ == "surface":
1167+
# ax.plot_surface errors with nonnumerics:
1168+
with pytest.raises(Exception):
1169+
self.plotfunc(a)
1170+
else:
11681171
self.plotfunc(a)
11691172

11701173
def test_multiindex_raises_typeerror(self):

0 commit comments

Comments
 (0)