Skip to content

Commit aed5d30

Browse files
committed
Set matplotlib axis scale type earlier so axis limits are correct
When the axis scale is changed to "log", matplotlib changes the default axis limits from (0, 1) to something around (0.1, 10). Since MplDrawer got the limits and the set them back on the axes after setting the scale, it was reapplying a minimum of 0 to the log scale plot, leading to a user warning. matplotlib by default changes the limits to the range of the data when data is added, so this case only came up when generating a figure with no data.
1 parent 9805d50 commit aed5d30

File tree

2 files changed

+25
-17
lines changed

2 files changed

+25
-17
lines changed

qiskit_experiments/visualization/drawers/mpl_drawer.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,31 @@ def format_canvas(self):
147147
else:
148148
all_axes = [self._axis]
149149

150+
# Set axes scale. This needs to be done before anything tries work with
151+
# the axis limits because if no limits or data are set explicitly the
152+
# default limits depend on the scale method (for example, the minimum
153+
# value is 0 for linear scaling but not for log scaling).
154+
def signed_sqrt(x):
155+
return np.sign(x) * np.sqrt(abs(x))
156+
157+
def signed_square(x):
158+
return np.sign(x) * x**2
159+
160+
for ax_type in ("x", "y"):
161+
for sub_ax in all_axes:
162+
scale = self.figure_options.get(f"{ax_type}scale")
163+
if ax_type == "x":
164+
mpl_setscale = sub_ax.set_xscale
165+
else:
166+
mpl_setscale = sub_ax.set_yscale
167+
168+
# Apply non linear axis spacing
169+
if scale is not None:
170+
if scale == "quadratic":
171+
mpl_setscale("function", functions=(signed_square, signed_sqrt))
172+
else:
173+
mpl_setscale(scale)
174+
150175
# Get axis formatter from drawing options
151176
formatter_opts = {}
152177
for ax_type in ("x", "y"):
@@ -181,12 +206,6 @@ def format_canvas(self):
181206
"max_ax_vals": max_vals,
182207
}
183208

184-
def signed_sqrt(x):
185-
return np.sign(x) * np.sqrt(abs(x))
186-
187-
def signed_square(x):
188-
return np.sign(x) * x**2
189-
190209
for i, sub_ax in enumerate(all_axes):
191210
# Add data labels if there are multiple labels registered per sub_ax.
192211
_, labels = sub_ax.get_legend_handles_labels()
@@ -197,18 +216,15 @@ def signed_square(x):
197216
limit = formatter_opts[ax_type]["limit"][i]
198217
unit = formatter_opts[ax_type]["unit"][i]
199218
unit_scale = formatter_opts[ax_type]["unit_scale"][i]
200-
scale = self.figure_options.get(f"{ax_type}scale")
201219
min_ax_vals = formatter_opts[ax_type]["min_ax_vals"]
202220
max_ax_vals = formatter_opts[ax_type]["max_ax_vals"]
203221
share_axis = self.figure_options.get(f"share{ax_type}")
204222

205223
if ax_type == "x":
206-
mpl_setscale = sub_ax.set_xscale
207224
mpl_axis_obj = getattr(sub_ax, "xaxis")
208225
mpl_setlimit = sub_ax.set_xlim
209226
mpl_share = sub_ax.sharex
210227
else:
211-
mpl_setscale = sub_ax.set_yscale
212228
mpl_axis_obj = getattr(sub_ax, "yaxis")
213229
mpl_setlimit = sub_ax.set_ylim
214230
mpl_share = sub_ax.sharey
@@ -219,13 +235,6 @@ def signed_square(x):
219235
else:
220236
limit = min_ax_vals[i], max_ax_vals[i]
221237

222-
# Apply non linear axis spacing
223-
if scale is not None:
224-
if scale == "quadratic":
225-
mpl_setscale("function", functions=(signed_square, signed_sqrt))
226-
else:
227-
mpl_setscale(scale)
228-
229238
# Create formatter for axis tick label notation
230239
if unit and unit_scale:
231240
# If value is specified, automatically scale axis magnitude

test/visualization/test_plotter_mpldrawer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ def test_scale(self):
104104
plotter.set_figure_options(
105105
xscale="quadratic",
106106
yscale="log",
107-
ylim=(0.1, 1.0),
108107
)
109108

110109
plotter.figure()

0 commit comments

Comments
 (0)