diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 70d7c4f9..41c33bab 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -1249,11 +1249,40 @@ def _add_colorbar( width=tickwidth * tickwidthratio, ) # noqa: E501 if label is not None: - obj.set_label(label) + # Note for some reason axis.set_label does not work here. We need to use set_x/ylabel explicitly + match loc: + case "top" | "bottom": + if labelloc in (None, "top", "bottom"): + obj.set_label(label) + elif labelloc in ("left", "right"): + obj.ax.set_ylabel(label) + else: + raise ValueError("Could not determined position") + case "left" | "right": + if labelloc in (None, "left", "right"): + obj.set_label(label) + elif labelloc in ("top", "bottom"): + obj.ax.set_xlabel(label) + else: + raise ValueError("Could not determined position") + # Default to setting label on long axis + case _: + obj.set_label(label) if labelloc is not None: + # Temporarily modify the axis to set the label and its properties + match loc: + case "top" | "bottom": + if labelloc in ("left", "right"): + axis = obj._short_axis() + case "left" | "right": + if labelloc in ("top", "bottom"): + axis = obj._short_axis() + case _: + raise ValueError("Location not understood.") axis.set_label_position(labelloc) axis.label.update(kw_label) - for label in axis.get_ticklabels(): + # Assume ticks are set on the long axis(!) + for label in obj._long_axis().get_ticklabels(): label.update(kw_ticklabels) kw_outline = {"edgecolor": color, "linewidth": linewidth} if obj.outline is not None: diff --git a/ultraplot/tests/test_colorbar.py b/ultraplot/tests/test_colorbar.py index 28866c71..98a2ec11 100644 --- a/ultraplot/tests/test_colorbar.py +++ b/ultraplot/tests/test_colorbar.py @@ -275,3 +275,16 @@ def test_draw_edges(): axi.colorbar(h, drawedges=drawedges) axi.set_title(f"{drawedges=}") return fig + + +def test_label_placement_colorbar(): + """ + Ensure that all potential combinations of colorbar + label placement is possible. + """ + data = np.random.rand(10, 10) + fig, ax = uplt.subplots() + h = ax.imshow(data) + locs = "top bottom left right".split() + for loc, labelloc in zip(locs, locs): + ax.colorbar(h, loc=loc, labelloc=labelloc)