Skip to content

Commit 337ab54

Browse files
authored
Merge pull request #21 from kdorr/refactor-chart-object
Add Line and Refactor chart object
2 parents 7bcab2d + da74595 commit 337ab54

24 files changed

+566
-454
lines changed

mplaltair/__init__.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
import mplaltair.parse_chart
12
import matplotlib
23
import altair
4+
import matplotlib.pyplot as plt
35
from ._convert import _convert
6+
from ._axis import convert_axis
7+
from ._marks import _handle_line
48

59

6-
def convert(chart):
10+
def convert(alt_chart):
711
"""Convert an altair encoding to a Matplotlib figure
812
913
@@ -14,10 +18,21 @@ def convert(chart):
1418
1519
Returns
1620
-------
17-
mapping : dict
18-
Mapping from parts of the encoding to the Matplotlib artists. This is
19-
for later customization.
21+
fig : matplotlib.figure
2022
23+
ax : matplotlib.axes
2124
2225
"""
23-
return _convert(chart)
26+
chart = mplaltair.parse_chart.ChartMetadata(alt_chart)
27+
fig, ax = plt.subplots()
28+
29+
if chart.mark in ['point', 'circle', 'square']: # scatter
30+
mapping = _convert(chart)
31+
ax.scatter(**mapping)
32+
elif chart.mark == 'line': # line
33+
_handle_line(chart, ax)
34+
else:
35+
raise NotImplementedError
36+
convert_axis(ax, chart)
37+
fig.tight_layout()
38+
return fig, ax

mplaltair/_axis.py

Lines changed: 95 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
import matplotlib.dates as mdates
22
import matplotlib.ticker as ticker
33
import numpy as np
4-
from ._data import _locate_channel_data, _locate_channel_dtype, _locate_channel_scale, _locate_channel_axis, _convert_to_mpl_date
4+
from ._data import _convert_to_mpl_date
55

66

7-
def _set_limits(channel, scale):
7+
def _set_limits(channel, mark, ax):
88
"""Set the axis limits on the Matplotlib axis
99
1010
Parameters
1111
----------
12-
channel : dict
13-
The mapping of the channel data and metadata
14-
scale : dict
15-
The mapping of the scale metadata and the scale data
12+
channel : parse_chart.ChannelMetadata
13+
The channel data and metadata
14+
mark : str
15+
The chart's mark
16+
ax : matplotlib.axes
1617
"""
1718

1819
_axis_kwargs = {
@@ -22,136 +23,142 @@ def _set_limits(channel, scale):
2223

2324
lims = {}
2425

25-
if channel['dtype'] == 'quantitative':
26+
if channel.type == 'quantitative':
2627
# determine limits
27-
if 'domain' in scale: # domain takes precedence over zero in Altair
28-
if scale['domain'] == 'unaggregated':
28+
if 'domain' in channel.scale: # domain takes precedence over zero in Altair
29+
if channel.scale['domain'] == 'unaggregated':
2930
raise NotImplementedError
3031
else:
31-
lims[_axis_kwargs[channel['axis']].get('min')] = scale['domain'][0]
32-
lims[_axis_kwargs[channel['axis']].get('max')] = scale['domain'][1]
33-
elif 'type' in scale and scale['type'] != 'linear':
34-
lims = _set_scale_type(channel, scale)
32+
lims[_axis_kwargs[channel.name].get('min')] = channel.scale['domain'][0]
33+
lims[_axis_kwargs[channel.name].get('max')] = channel.scale['domain'][1]
34+
elif 'type' in channel.scale and channel.scale['type'] != 'linear':
35+
lims = _set_scale_type(channel, ax)
3536
else:
36-
# Check that a positive minimum is zero if zero is True:
37-
if ('zero' not in scale or scale['zero'] == True) and min(channel['data']) > 0:
38-
lims[_axis_kwargs[channel['axis']].get('min')] = 0 # quantitative sets min to be 0 by default
37+
# Include zero on the axis (or not).
38+
# In Altair, scale.zero defaults to False unless the data is unbinned quantitative.
39+
if mark == 'line' and channel.name == 'x':
40+
# Contrary to documentation, Altair defaults to scale.zero=False for the x-axis on line graphs.
41+
# Pass to skip.
42+
pass
43+
else:
44+
# Check that a positive minimum is zero if scale.zero is True:
45+
if ('zero' not in channel.scale or channel.scale['zero'] == True) and min(channel.data) > 0:
46+
lims[_axis_kwargs[channel.name].get('min')] = 0 # quantitative sets min to be 0 by default
3947

40-
# Check that a negative maximum is zero if zero is True:
41-
if ('zero' not in scale or scale['zero'] == True) and max(channel['data']) < 0:
42-
lims[_axis_kwargs[channel['axis']].get('max')] = 0
48+
# Check that a negative maximum is zero if scale.zero is True:
49+
if ('zero' not in channel.scale or channel.scale['zero'] == True) and max(channel.data) < 0:
50+
lims[_axis_kwargs[channel.name].get('max')] = 0
4351

44-
elif channel['dtype'] == 'temporal':
52+
elif channel.type == 'temporal':
4553
# determine limits
46-
if 'domain' in scale:
47-
domain = _convert_to_mpl_date(scale['domain'])
48-
lims[_axis_kwargs[channel['axis']].get('min')] = domain[0]
49-
lims[_axis_kwargs[channel['axis']].get('max')] = domain[1]
50-
elif 'type' in scale and scale['type'] != 'time':
51-
lims = _set_scale_type(channel, scale)
54+
if 'domain' in channel.scale:
55+
domain = _convert_to_mpl_date(channel.scale['domain'])
56+
lims[_axis_kwargs[channel.name].get('min')] = domain[0]
57+
lims[_axis_kwargs[channel.name].get('max')] = domain[1]
58+
elif 'type' in channel.scale and channel.scale['type'] != 'time':
59+
lims = _set_scale_type(channel, channel.scale)
5260

5361
else:
5462
raise NotImplementedError # Ordinal and Nominal go here?
5563

5664
# set the limits
57-
if channel['axis'] == 'x':
58-
channel['ax'].set_xlim(**lims)
65+
if channel.name == 'x':
66+
ax.set_xlim(**lims)
5967
else:
60-
channel['ax'].set_ylim(**lims)
68+
ax.set_ylim(**lims)
6169

6270

63-
def _set_scale_type(channel, scale):
71+
def _set_scale_type(channel, ax):
6472
"""If the scale is non-linear, change the scale and return appropriate axis limits.
6573
The 'linear' and 'time' scale types are not included here because quantitative defaults to 'linear'
6674
and temporal defaults to 'time'. The 'utc' and 'sequential' scales are currently not supported.
6775
6876
Parameters
6977
----------
70-
channel : dict
71-
The mapping of the channel data and metadata
72-
scale : dict
73-
The mapping of the scale metadata and the scale data
78+
channel : parse_chart.ChannelMetadata
79+
The channel data and metadata
80+
ax : matplotlib.axes
7481
7582
Returns
7683
-------
7784
lims : dict
7885
The axis limit mapped to the appropriate axis parameter for scales that change axis limit behavior
7986
"""
8087
lims = {}
81-
if scale['type'] == 'log':
88+
if channel.scale['type'] == 'log':
8289

8390
base = 10 # default base is 10 in altair
84-
if 'base' in scale:
85-
base = scale['base']
91+
if 'base' in channel.scale:
92+
base = channel.scale['base']
8693

87-
if channel['axis'] == 'x':
88-
channel['ax'].set_xscale('log', basex=base)
94+
if channel.name == 'x':
95+
ax.set_xscale('log', basex=base)
8996
# lower limit: round down to nearest major tick (using log base change rule)
90-
lims['left'] = base**np.floor(np.log10(channel['data'].min())/np.log10(base))
97+
lims['left'] = base**np.floor(np.log10(channel.data.min())/np.log10(base))
9198
else: # y-axis
92-
channel['ax'].set_yscale('log', basey=base)
99+
ax.set_yscale('log', basey=base)
93100
# lower limit: round down to nearest major tick (using log base change rule)
94-
lims['bottom'] = base**np.floor(np.log10(channel['data'].min())/np.log10(base))
101+
lims['bottom'] = base**np.floor(np.log10(channel.data.min())/np.log10(base))
95102

96-
elif scale['type'] == 'pow' or scale['type'] == 'sqrt':
103+
elif channel.scale['type'] == 'pow' or channel.scale['type'] == 'sqrt':
97104
"""The 'sqrt' scale is just the 'pow' scale with exponent = 0.5.
98105
When Matplotlib gets a power scale, the following should work:
99106
100107
exponent = 2 # default exponent value for 'pow' scale
101-
if scale['type'] == 'sqrt':
108+
if channel.scale['type'] == 'sqrt':
102109
exponent = 0.5
103-
elif 'exponent' in scale:
104-
exponent = scale['exponent']
110+
elif 'exponent' in channel.scale:
111+
exponent = channel.scale['exponent']
105112
106-
if channel['axis'] == 'x':
107-
channel['ax'].set_xscale('power_scale', exponent=exponent)
113+
if channel.name == 'x':
114+
ax.set_xscale('power_scale', exponent=exponent)
108115
else: # y-axis
109-
channel['ax'].set_yscale('power_scale', exponent=exponent)
116+
ax.set_yscale('power_scale', exponent=exponent)
110117
"""
111118
raise NotImplementedError
112119

113-
elif scale['type'] == 'utc':
120+
elif channel.scale['type'] == 'utc':
114121
raise NotImplementedError
115-
elif scale['type'] == 'sequential':
122+
elif channel.scale['type'] == 'sequential':
116123
raise NotImplementedError("sequential scales used primarily for continuous colors")
117124
else:
118125
raise NotImplementedError
119126
return lims
120127

121128

122-
def _set_tick_locator(channel, axis):
129+
def _set_tick_locator(channel, ax):
123130
"""Set the tick locator if it needs to vary from the default locator
124131
125132
Parameters
126133
----------
127-
channel : dict
128-
The mapping of the channel data and metadata
129-
axis : dict
134+
channel : parse_chart.ChannelMetadata
135+
The channel data and metadata
136+
ax : matplotlib.axes
130137
The mapping of the axis metadata and the scale data
131138
"""
132-
current_axis = {'x': channel['ax'].xaxis, 'y': channel['ax'].yaxis}
133-
if 'values' in axis:
134-
if channel['dtype'] == 'temporal':
135-
current_axis[channel['axis']].set_major_locator(ticker.FixedLocator(_convert_to_mpl_date(axis.get('values'))))
136-
elif channel['dtype'] == 'quantitative':
137-
current_axis[channel['axis']].set_major_locator(ticker.FixedLocator(axis.get('values')))
139+
current_axis = {'x': ax.xaxis, 'y': ax.yaxis}
140+
if 'values' in channel.axis:
141+
if channel.type == 'temporal':
142+
current_axis[channel.name].set_major_locator(ticker.FixedLocator(_convert_to_mpl_date(channel.axis.get('values'))))
143+
elif channel.type == 'quantitative':
144+
current_axis[channel.name].set_major_locator(ticker.FixedLocator(channel.axis.get('values')))
138145
else:
139146
raise NotImplementedError
140-
elif 'tickCount' in axis:
141-
current_axis[channel['axis']].set_major_locator(
142-
ticker.MaxNLocator(steps=[2, 5, 10], nbins=axis.get('tickCount')+1, min_n_ticks=axis.get('tickCount'))
147+
elif 'tickCount' in channel.axis:
148+
current_axis[channel.name].set_major_locator(
149+
ticker.MaxNLocator(steps=[2, 5, 10], nbins=channel.axis.get('tickCount')+1, min_n_ticks=channel.axis.get('tickCount'))
143150
)
144151

145152

146-
def _set_tick_formatter(channel, axis):
153+
def _set_tick_formatter(channel, ax):
147154
"""Set the tick formatter.
148155
149156
150157
Parameters
151158
----------
152-
channel : dict
153-
The mapping of the channel data and metadata
154-
axis : dict
159+
channel : parse_chart.ChannelMetadata
160+
The channel data and metadata
161+
ax : matplotlib.axes
155162
The mapping of the axis metadata and the scale data
156163
157164
Notes
@@ -162,25 +169,22 @@ def _set_tick_formatter(channel, axis):
162169
For formatting of temporal data, Matplotlib does not support some format strings that Altair supports (%L, %Q, %s).
163170
Matplotlib only supports datetime.strftime formatting for dates.
164171
"""
165-
current_axis = {'x': channel['ax'].xaxis, 'y': channel['ax'].yaxis}
166-
format_str = ''
167-
168-
if 'format' in axis:
169-
format_str = axis['format']
172+
current_axis = {'x': ax.xaxis, 'y': ax.yaxis}
173+
format_str = channel.axis.get('format', '')
170174

171-
if channel['dtype'] == 'temporal':
175+
if channel.type == 'temporal':
172176
if not format_str:
173177
format_str = '%b %d, %Y'
174178

175-
current_axis[channel['axis']].set_major_formatter(mdates.DateFormatter(format_str)) # May fail silently
179+
current_axis[channel.name].set_major_formatter(mdates.DateFormatter(format_str)) # May fail silently
176180

177-
elif channel['dtype'] == 'quantitative':
181+
elif channel.type == 'quantitative':
178182
if format_str:
179-
current_axis[channel['axis']].set_major_formatter(ticker.StrMethodFormatter('{x:' + format_str + '}'))
183+
current_axis[channel.name].set_major_formatter(ticker.StrMethodFormatter('{x:' + format_str + '}'))
180184

181185
# Verify that the format string is valid for Matplotlib and exit nicely if not.
182186
try:
183-
current_axis[channel['axis']].get_major_formatter().__call__(1)
187+
current_axis[channel.name].get_major_formatter().__call__(1)
184188
except ValueError:
185189
raise ValueError("Matplotlib only supports format strings as used by `str.format()`."
186190
"Some format strings that work in Altair may not work in Matplotlib."
@@ -189,18 +193,18 @@ def _set_tick_formatter(channel, axis):
189193
raise NotImplementedError # Nominal and Ordinal go here
190194

191195

192-
def _set_label_angle(channel, axis):
196+
def _set_label_angle(channel, ax):
193197
"""Set the label angle. TODO: handle axis.labelAngle from Altair
194198
195199
Parameters
196200
----------
197-
channel : dict
198-
The mapping of the channel data and metadata
199-
axis : dict
201+
channel : parse_chart.ChannelMetadata
202+
The channel data and metadata
203+
ax : matplotlib.axes
200204
The mapping of the axis metadata and the scale data
201205
"""
202-
if channel['dtype'] == 'temporal' and channel['axis'] == 'x':
203-
for label in channel['ax'].get_xticklabels():
206+
if channel.type == 'temporal' and channel.name == 'x':
207+
for label in ax.get_xticklabels():
204208
# Rotate the labels on the x-axis so they don't run into each other.
205209
label.set_rotation(30)
206210
label.set_ha('right')
@@ -213,22 +217,12 @@ def convert_axis(ax, chart):
213217
----------
214218
ax
215219
The Matplotlib axis to be modified
216-
chart
217-
The Altair chart
220+
chart : parse_chart.ChartMetadata
221+
The chart data and metadata
218222
"""
219223

220-
for channel in chart.to_dict()['encoding']:
221-
if channel in ['x', 'y']:
222-
chart_info = {'ax': ax, 'axis': channel,
223-
'data': _locate_channel_data(chart, channel),
224-
'dtype': _locate_channel_dtype(chart, channel)}
225-
if chart_info['dtype'] == 'temporal':
226-
chart_info['data'] = _convert_to_mpl_date(chart_info['data'])
227-
228-
scale_info = _locate_channel_scale(chart, channel)
229-
axis_info = _locate_channel_axis(chart, channel)
230-
231-
_set_limits(chart_info, scale_info)
232-
_set_tick_locator(chart_info, axis_info)
233-
_set_tick_formatter(chart_info, axis_info)
234-
_set_label_angle(chart_info, axis_info)
224+
for channel in [chart.encoding['x'], chart.encoding['y']]:
225+
_set_limits(channel, chart.mark, ax)
226+
_set_tick_locator(channel, ax)
227+
_set_tick_formatter(channel, ax)
228+
_set_label_angle(channel, ax)

0 commit comments

Comments
 (0)