1
1
import matplotlib .dates as mdates
2
2
import matplotlib .ticker as ticker
3
3
import numpy as np
4
- from ._data import _locate_channel_data , _locate_channel_dtype , _locate_channel_scale , _locate_channel_axis , _convert_to_mpl_date
4
+ from ._data import _convert_to_mpl_date
5
5
6
6
7
- def _set_limits (channel , scale ):
7
+ def _set_limits (channel , mark , ax ):
8
8
"""Set the axis limits on the Matplotlib axis
9
9
10
10
Parameters
11
11
----------
12
- channel : dict
13
- The mapping of the channel data and metadata
14
- scale : dict
15
- The mapping of the scale metadata and the scale data
12
+ channel : parse_chart.ChannelMetadata
13
+ The channel data and metadata
14
+ mark : str
15
+ The chart's mark
16
+ ax : matplotlib.axes
16
17
"""
17
18
18
19
_axis_kwargs = {
@@ -22,136 +23,142 @@ def _set_limits(channel, scale):
22
23
23
24
lims = {}
24
25
25
- if channel [ 'dtype' ] == 'quantitative' :
26
+ if channel . type == 'quantitative' :
26
27
# determine limits
27
- if 'domain' in scale : # domain takes precedence over zero in Altair
28
- if scale ['domain' ] == 'unaggregated' :
28
+ if 'domain' in channel . scale : # domain takes precedence over zero in Altair
29
+ if channel . scale ['domain' ] == 'unaggregated' :
29
30
raise NotImplementedError
30
31
else :
31
- lims [_axis_kwargs [channel [ 'axis' ]] .get ('min' )] = scale ['domain' ][0 ]
32
- lims [_axis_kwargs [channel [ 'axis' ]] .get ('max' )] = scale ['domain' ][1 ]
33
- elif 'type' in scale and scale ['type' ] != 'linear' :
34
- lims = _set_scale_type (channel , scale )
32
+ lims [_axis_kwargs [channel . name ] .get ('min' )] = channel . scale ['domain' ][0 ]
33
+ lims [_axis_kwargs [channel . name ] .get ('max' )] = channel . scale ['domain' ][1 ]
34
+ elif 'type' in channel . scale and channel . scale ['type' ] != 'linear' :
35
+ lims = _set_scale_type (channel , ax )
35
36
else :
36
- # Check that a positive minimum is zero if zero is True:
37
- if ('zero' not in scale or scale ['zero' ] == True ) and min (channel ['data' ]) > 0 :
38
- lims [_axis_kwargs [channel ['axis' ]].get ('min' )] = 0 # quantitative sets min to be 0 by default
37
+ # Include zero on the axis (or not).
38
+ # In Altair, scale.zero defaults to False unless the data is unbinned quantitative.
39
+ if mark == 'line' and channel .name == 'x' :
40
+ # Contrary to documentation, Altair defaults to scale.zero=False for the x-axis on line graphs.
41
+ # Pass to skip.
42
+ pass
43
+ else :
44
+ # Check that a positive minimum is zero if scale.zero is True:
45
+ if ('zero' not in channel .scale or channel .scale ['zero' ] == True ) and min (channel .data ) > 0 :
46
+ lims [_axis_kwargs [channel .name ].get ('min' )] = 0 # quantitative sets min to be 0 by default
39
47
40
- # Check that a negative maximum is zero if zero is True:
41
- if ('zero' not in scale or scale ['zero' ] == True ) and max (channel [ ' data' ] ) < 0 :
42
- lims [_axis_kwargs [channel [ 'axis' ] ].get ('max' )] = 0
48
+ # Check that a negative maximum is zero if scale. zero is True:
49
+ if ('zero' not in channel . scale or channel . scale ['zero' ] == True ) and max (channel . data ) < 0 :
50
+ lims [_axis_kwargs [channel . name ].get ('max' )] = 0
43
51
44
- elif channel [ 'dtype' ] == 'temporal' :
52
+ elif channel . type == 'temporal' :
45
53
# determine limits
46
- if 'domain' in scale :
47
- domain = _convert_to_mpl_date (scale ['domain' ])
48
- lims [_axis_kwargs [channel [ 'axis' ] ].get ('min' )] = domain [0 ]
49
- lims [_axis_kwargs [channel [ 'axis' ] ].get ('max' )] = domain [1 ]
50
- elif 'type' in scale and scale ['type' ] != 'time' :
51
- lims = _set_scale_type (channel , scale )
54
+ if 'domain' in channel . scale :
55
+ domain = _convert_to_mpl_date (channel . scale ['domain' ])
56
+ lims [_axis_kwargs [channel . name ].get ('min' )] = domain [0 ]
57
+ lims [_axis_kwargs [channel . name ].get ('max' )] = domain [1 ]
58
+ elif 'type' in channel . scale and channel . scale ['type' ] != 'time' :
59
+ lims = _set_scale_type (channel , channel . scale )
52
60
53
61
else :
54
62
raise NotImplementedError # Ordinal and Nominal go here?
55
63
56
64
# set the limits
57
- if channel [ 'axis' ] == 'x' :
58
- channel [ 'ax' ] .set_xlim (** lims )
65
+ if channel . name == 'x' :
66
+ ax .set_xlim (** lims )
59
67
else :
60
- channel [ 'ax' ] .set_ylim (** lims )
68
+ ax .set_ylim (** lims )
61
69
62
70
63
- def _set_scale_type (channel , scale ):
71
+ def _set_scale_type (channel , ax ):
64
72
"""If the scale is non-linear, change the scale and return appropriate axis limits.
65
73
The 'linear' and 'time' scale types are not included here because quantitative defaults to 'linear'
66
74
and temporal defaults to 'time'. The 'utc' and 'sequential' scales are currently not supported.
67
75
68
76
Parameters
69
77
----------
70
- channel : dict
71
- The mapping of the channel data and metadata
72
- scale : dict
73
- The mapping of the scale metadata and the scale data
78
+ channel : parse_chart.ChannelMetadata
79
+ The channel data and metadata
80
+ ax : matplotlib.axes
74
81
75
82
Returns
76
83
-------
77
84
lims : dict
78
85
The axis limit mapped to the appropriate axis parameter for scales that change axis limit behavior
79
86
"""
80
87
lims = {}
81
- if scale ['type' ] == 'log' :
88
+ if channel . scale ['type' ] == 'log' :
82
89
83
90
base = 10 # default base is 10 in altair
84
- if 'base' in scale :
85
- base = scale ['base' ]
91
+ if 'base' in channel . scale :
92
+ base = channel . scale ['base' ]
86
93
87
- if channel [ 'axis' ] == 'x' :
88
- channel [ 'ax' ] .set_xscale ('log' , basex = base )
94
+ if channel . name == 'x' :
95
+ ax .set_xscale ('log' , basex = base )
89
96
# lower limit: round down to nearest major tick (using log base change rule)
90
- lims ['left' ] = base ** np .floor (np .log10 (channel [ ' data' ] .min ())/ np .log10 (base ))
97
+ lims ['left' ] = base ** np .floor (np .log10 (channel . data .min ())/ np .log10 (base ))
91
98
else : # y-axis
92
- channel [ 'ax' ] .set_yscale ('log' , basey = base )
99
+ ax .set_yscale ('log' , basey = base )
93
100
# lower limit: round down to nearest major tick (using log base change rule)
94
- lims ['bottom' ] = base ** np .floor (np .log10 (channel [ ' data' ] .min ())/ np .log10 (base ))
101
+ lims ['bottom' ] = base ** np .floor (np .log10 (channel . data .min ())/ np .log10 (base ))
95
102
96
- elif scale ['type' ] == 'pow' or scale ['type' ] == 'sqrt' :
103
+ elif channel . scale ['type' ] == 'pow' or channel . scale ['type' ] == 'sqrt' :
97
104
"""The 'sqrt' scale is just the 'pow' scale with exponent = 0.5.
98
105
When Matplotlib gets a power scale, the following should work:
99
106
100
107
exponent = 2 # default exponent value for 'pow' scale
101
- if scale['type'] == 'sqrt':
108
+ if channel. scale['type'] == 'sqrt':
102
109
exponent = 0.5
103
- elif 'exponent' in scale:
104
- exponent = scale['exponent']
110
+ elif 'exponent' in channel. scale:
111
+ exponent = channel. scale['exponent']
105
112
106
- if channel['axis'] == 'x':
107
- channel['ax'] .set_xscale('power_scale', exponent=exponent)
113
+ if channel.name == 'x':
114
+ ax .set_xscale('power_scale', exponent=exponent)
108
115
else: # y-axis
109
- channel['ax'] .set_yscale('power_scale', exponent=exponent)
116
+ ax .set_yscale('power_scale', exponent=exponent)
110
117
"""
111
118
raise NotImplementedError
112
119
113
- elif scale ['type' ] == 'utc' :
120
+ elif channel . scale ['type' ] == 'utc' :
114
121
raise NotImplementedError
115
- elif scale ['type' ] == 'sequential' :
122
+ elif channel . scale ['type' ] == 'sequential' :
116
123
raise NotImplementedError ("sequential scales used primarily for continuous colors" )
117
124
else :
118
125
raise NotImplementedError
119
126
return lims
120
127
121
128
122
- def _set_tick_locator (channel , axis ):
129
+ def _set_tick_locator (channel , ax ):
123
130
"""Set the tick locator if it needs to vary from the default locator
124
131
125
132
Parameters
126
133
----------
127
- channel : dict
128
- The mapping of the channel data and metadata
129
- axis : dict
134
+ channel : parse_chart.ChannelMetadata
135
+ The channel data and metadata
136
+ ax : matplotlib.axes
130
137
The mapping of the axis metadata and the scale data
131
138
"""
132
- current_axis = {'x' : channel [ 'ax' ] .xaxis , 'y' : channel [ 'ax' ] .yaxis }
133
- if 'values' in axis :
134
- if channel [ 'dtype' ] == 'temporal' :
135
- current_axis [channel [ 'axis' ]] .set_major_locator (ticker .FixedLocator (_convert_to_mpl_date (axis .get ('values' ))))
136
- elif channel [ 'dtype' ] == 'quantitative' :
137
- current_axis [channel [ 'axis' ]] .set_major_locator (ticker .FixedLocator (axis .get ('values' )))
139
+ current_axis = {'x' : ax .xaxis , 'y' : ax .yaxis }
140
+ if 'values' in channel . axis :
141
+ if channel . type == 'temporal' :
142
+ current_axis [channel . name ] .set_major_locator (ticker .FixedLocator (_convert_to_mpl_date (channel . axis .get ('values' ))))
143
+ elif channel . type == 'quantitative' :
144
+ current_axis [channel . name ] .set_major_locator (ticker .FixedLocator (channel . axis .get ('values' )))
138
145
else :
139
146
raise NotImplementedError
140
- elif 'tickCount' in axis :
141
- current_axis [channel [ 'axis' ] ].set_major_locator (
142
- ticker .MaxNLocator (steps = [2 , 5 , 10 ], nbins = axis .get ('tickCount' )+ 1 , min_n_ticks = axis .get ('tickCount' ))
147
+ elif 'tickCount' in channel . axis :
148
+ current_axis [channel . name ].set_major_locator (
149
+ ticker .MaxNLocator (steps = [2 , 5 , 10 ], nbins = channel . axis .get ('tickCount' )+ 1 , min_n_ticks = channel . axis .get ('tickCount' ))
143
150
)
144
151
145
152
146
- def _set_tick_formatter (channel , axis ):
153
+ def _set_tick_formatter (channel , ax ):
147
154
"""Set the tick formatter.
148
155
149
156
150
157
Parameters
151
158
----------
152
- channel : dict
153
- The mapping of the channel data and metadata
154
- axis : dict
159
+ channel : parse_chart.ChannelMetadata
160
+ The channel data and metadata
161
+ ax : matplotlib.axes
155
162
The mapping of the axis metadata and the scale data
156
163
157
164
Notes
@@ -162,25 +169,22 @@ def _set_tick_formatter(channel, axis):
162
169
For formatting of temporal data, Matplotlib does not support some format strings that Altair supports (%L, %Q, %s).
163
170
Matplotlib only supports datetime.strftime formatting for dates.
164
171
"""
165
- current_axis = {'x' : channel ['ax' ].xaxis , 'y' : channel ['ax' ].yaxis }
166
- format_str = ''
167
-
168
- if 'format' in axis :
169
- format_str = axis ['format' ]
172
+ current_axis = {'x' : ax .xaxis , 'y' : ax .yaxis }
173
+ format_str = channel .axis .get ('format' , '' )
170
174
171
- if channel [ 'dtype' ] == 'temporal' :
175
+ if channel . type == 'temporal' :
172
176
if not format_str :
173
177
format_str = '%b %d, %Y'
174
178
175
- current_axis [channel [ 'axis' ] ].set_major_formatter (mdates .DateFormatter (format_str )) # May fail silently
179
+ current_axis [channel . name ].set_major_formatter (mdates .DateFormatter (format_str )) # May fail silently
176
180
177
- elif channel [ 'dtype' ] == 'quantitative' :
181
+ elif channel . type == 'quantitative' :
178
182
if format_str :
179
- current_axis [channel [ 'axis' ] ].set_major_formatter (ticker .StrMethodFormatter ('{x:' + format_str + '}' ))
183
+ current_axis [channel . name ].set_major_formatter (ticker .StrMethodFormatter ('{x:' + format_str + '}' ))
180
184
181
185
# Verify that the format string is valid for Matplotlib and exit nicely if not.
182
186
try :
183
- current_axis [channel [ 'axis' ] ].get_major_formatter ().__call__ (1 )
187
+ current_axis [channel . name ].get_major_formatter ().__call__ (1 )
184
188
except ValueError :
185
189
raise ValueError ("Matplotlib only supports format strings as used by `str.format()`."
186
190
"Some format strings that work in Altair may not work in Matplotlib."
@@ -189,18 +193,18 @@ def _set_tick_formatter(channel, axis):
189
193
raise NotImplementedError # Nominal and Ordinal go here
190
194
191
195
192
- def _set_label_angle (channel , axis ):
196
+ def _set_label_angle (channel , ax ):
193
197
"""Set the label angle. TODO: handle axis.labelAngle from Altair
194
198
195
199
Parameters
196
200
----------
197
- channel : dict
198
- The mapping of the channel data and metadata
199
- axis : dict
201
+ channel : parse_chart.ChannelMetadata
202
+ The channel data and metadata
203
+ ax : matplotlib.axes
200
204
The mapping of the axis metadata and the scale data
201
205
"""
202
- if channel [ 'dtype' ] == 'temporal' and channel [ 'axis' ] == 'x' :
203
- for label in channel [ 'ax' ] .get_xticklabels ():
206
+ if channel . type == 'temporal' and channel . name == 'x' :
207
+ for label in ax .get_xticklabels ():
204
208
# Rotate the labels on the x-axis so they don't run into each other.
205
209
label .set_rotation (30 )
206
210
label .set_ha ('right' )
@@ -213,22 +217,12 @@ def convert_axis(ax, chart):
213
217
----------
214
218
ax
215
219
The Matplotlib axis to be modified
216
- chart
217
- The Altair chart
220
+ chart : parse_chart.ChartMetadata
221
+ The chart data and metadata
218
222
"""
219
223
220
- for channel in chart .to_dict ()['encoding' ]:
221
- if channel in ['x' , 'y' ]:
222
- chart_info = {'ax' : ax , 'axis' : channel ,
223
- 'data' : _locate_channel_data (chart , channel ),
224
- 'dtype' : _locate_channel_dtype (chart , channel )}
225
- if chart_info ['dtype' ] == 'temporal' :
226
- chart_info ['data' ] = _convert_to_mpl_date (chart_info ['data' ])
227
-
228
- scale_info = _locate_channel_scale (chart , channel )
229
- axis_info = _locate_channel_axis (chart , channel )
230
-
231
- _set_limits (chart_info , scale_info )
232
- _set_tick_locator (chart_info , axis_info )
233
- _set_tick_formatter (chart_info , axis_info )
234
- _set_label_angle (chart_info , axis_info )
224
+ for channel in [chart .encoding ['x' ], chart .encoding ['y' ]]:
225
+ _set_limits (channel , chart .mark , ax )
226
+ _set_tick_locator (channel , ax )
227
+ _set_tick_formatter (channel , ax )
228
+ _set_label_angle (channel , ax )
0 commit comments