Skip to content

Use native legends when converting from matplotlib #5312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 36 additions & 79 deletions plotly/matplotlylib/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def __init__(self):
self.mpl_x_bounds = (0, 1)
self.mpl_y_bounds = (0, 1)
self.msg = "Initialized PlotlyRenderer\n"
self._processing_legend = False
self._legend_visible = False

def open_figure(self, fig, props):
"""Creates a new figure by beginning to fill out layout dict.
Expand Down Expand Up @@ -108,7 +110,6 @@ def close_figure(self, fig):
fig -- a matplotlib.figure.Figure object.

"""
self.plotly_fig["layout"]["showlegend"] = False
self.msg += "Closing figure\n"

def open_axes(self, ax, props):
Expand Down Expand Up @@ -198,6 +199,35 @@ def close_axes(self, ax):
self.msg += " Closing axes\n"
self.x_is_mpl_date = False

def open_legend(self, legend, props):
"""Enable Plotly's native legend when matplotlib legend is detected.

This method is called when a matplotlib legend is found. It enables
Plotly's showlegend only if the matplotlib legend is visible.

Positional arguments:
legend -- matplotlib.legend.Legend object
props -- legend properties dictionary
"""
self.msg += " Opening legend\n"
self._processing_legend = True
self._legend_visible = props.get("visible", True)
if self._legend_visible:
self.msg += " Enabling native plotly legend (matplotlib legend is visible)\n"
self.plotly_fig["layout"]["showlegend"] = True
else:
self.msg += " Not enabling legend (matplotlib legend is not visible)\n"

def close_legend(self, legend):
"""Finalize legend processing.

Positional arguments:
legend -- matplotlib.legend.Legend object
"""
self.msg += " Closing legend\n"
self._processing_legend = False
self._legend_visible = False

def draw_bars(self, bars):

# sort bars according to bar containers
Expand Down Expand Up @@ -310,83 +340,6 @@ def draw_bar(self, coll):
"assuming data redundancy, not plotting."
)

def draw_legend_shapes(self, mode, shape, **props):
"""Create a shape that matches lines or markers in legends.

Main issue is that path for circles do not render, so we have to use 'circle'
instead of 'path'.
"""
for single_mode in mode.split("+"):
x = props["data"][0][0]
y = props["data"][0][1]
if single_mode == "markers" and props.get("markerstyle"):
size = shape.pop("size", 6)
symbol = shape.pop("symbol")
# aligning to "center"
x0 = 0
y0 = 0
x1 = size
y1 = size
markerpath = props["markerstyle"].get("markerpath")
if markerpath is None and symbol != "circle":
self.msg += (
"not sure how to handle this marker without a valid path\n"
)
return
# marker path to SVG path conversion
path = " ".join(
[f"{a} {t[0]},{t[1]}" for a, t in zip(markerpath[1], markerpath[0])]
)

if symbol == "circle":
# symbols like . and o in matplotlib, use circle
# plotly also maps many other markers to circle, such as 1,8 and p
path = None
shape_type = "circle"
x0 = -size / 2
y0 = size / 2
x1 = size / 2
y1 = size + size / 2
else:
# triangles, star etc
shape_type = "path"
legend_shape = go.layout.Shape(
type=shape_type,
xref="paper",
yref="paper",
x0=x0,
y0=y0,
x1=x1,
y1=y1,
xsizemode="pixel",
ysizemode="pixel",
xanchor=x,
yanchor=y,
path=path,
**shape,
)

elif single_mode == "lines":
mode = "line"
x1 = props["data"][1][0]
y1 = props["data"][1][1]

legend_shape = go.layout.Shape(
type=mode,
xref="paper",
yref="paper",
x0=x,
y0=y + 0.02,
x1=x1,
y1=y1 + 0.02,
**shape,
)
else:
self.msg += "not sure how to handle this element\n"
return
self.plotly_fig.add_shape(legend_shape)
self.msg += " Heck yeah, I drew that shape\n"

def draw_marked_line(self, **props):
"""Create a data dict for a line obj.

Expand Down Expand Up @@ -502,7 +455,7 @@ def draw_marked_line(self, **props):
self.msg += " Heck yeah, I drew that line\n"
elif props["coordinates"] == "axes":
# dealing with legend graphical elements
self.draw_legend_shapes(mode=mode, shape=shape, **props)
self.msg += " Using native legend\n"
else:
self.msg += " Line didn't have 'data' coordinates, " "not drawing\n"
warnings.warn(
Expand Down Expand Up @@ -668,6 +621,10 @@ def draw_text(self, **props):
self.draw_title(**props)
else: # just a regular text annotation...
self.msg += " Text object is a normal annotation\n"
# Skip creating annotations for legend text when using native legend
if self._processing_legend and self._legend_visible and props["coordinates"] == "axes":
self.msg += " Skipping legend text annotation (using native legend)\n"
return
if props["coordinates"] != "data":
self.msg += (
" Text object isn't linked to 'data' " "coordinates\n"
Expand Down
4 changes: 4 additions & 0 deletions plotly/matplotlylib/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
81 changes: 81 additions & 0 deletions plotly/matplotlylib/tests/test_renderer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import plotly.tools as tls

from . import plt


def test_native_legend_enabled_when_matplotlib_legend_present():
"""Test that when matplotlib legend is present, Plotly uses native legend."""
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], label="Line 1")
ax.plot([0, 1], [1, 0], label="Line 2")
ax.legend()

plotly_fig = tls.mpl_to_plotly(fig)

# Should enable native legend
assert plotly_fig.layout.showlegend == True
# Should have 2 traces with names
assert len(plotly_fig.data) == 2
assert plotly_fig.data[0].name == "Line 1"
assert plotly_fig.data[1].name == "Line 2"


def test_no_fake_legend_shapes_with_native_legend():
"""Test that fake legend shapes are not created when using native legend."""
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], "o-", label="Data with markers")
ax.legend()

plotly_fig = tls.mpl_to_plotly(fig)

# Should use native legend
assert plotly_fig.layout.showlegend == True
# Should not create fake legend elements
assert len(plotly_fig.layout.shapes) == 0
assert len(plotly_fig.layout.annotations) == 0


def test_legend_disabled_when_no_matplotlib_legend():
"""Test that legend is not enabled when no matplotlib legend is present."""
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], label="Line 1") # Has label but no legend() call

plotly_fig = tls.mpl_to_plotly(fig)

# Should not have showlegend explicitly set to True
# (Plotly's default behavior when no legend elements exist)
assert not hasattr(plotly_fig.layout, 'showlegend') or plotly_fig.layout.showlegend != True


def test_legend_disabled_when_matplotlib_legend_not_visible():
"""Test that legend is not enabled when no matplotlib legend is not visible."""
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], label="Line 1")
legend = ax.legend()
legend.set_visible(False) # Hide the legend

plotly_fig = tls.mpl_to_plotly(fig)

# Should not enable legend when matplotlib legend is hidden
assert not hasattr(plotly_fig.layout, 'showlegend') or plotly_fig.layout.showlegend != True


def test_multiple_traces_native_legend():
"""Test native legend works with multiple traces of different types."""
fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 0], '-', label="Line")
ax.plot([0, 1, 2], [1, 0, 1], 'o', label="Markers")
ax.plot([0, 1, 2], [0.5, 0.5, 0.5], 's-', label="Line+Markers")
ax.legend()

plotly_fig = tls.mpl_to_plotly(fig)

assert plotly_fig.layout.showlegend == True
assert len(plotly_fig.data) == 3
assert plotly_fig.data[0].name == "Line"
assert plotly_fig.data[1].name == "Markers"
assert plotly_fig.data[2].name == "Line+Markers"
# Verify modes are correct
assert plotly_fig.data[0].mode == "lines"
assert plotly_fig.data[1].mode == "markers"
assert plotly_fig.data[2].mode == "lines+markers"