Skip to content

Add Line and Refactor chart object #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Aug 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions mplaltair/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import mplaltair.parse_chart
import matplotlib
import altair
import matplotlib.pyplot as plt
from ._convert import _convert
from ._axis import convert_axis
from ._marks import _handle_line


def convert(chart):
def convert(alt_chart):
"""Convert an altair encoding to a Matplotlib figure


Expand All @@ -14,10 +18,21 @@ def convert(chart):

Returns
-------
mapping : dict
Mapping from parts of the encoding to the Matplotlib artists. This is
for later customization.
fig : matplotlib.figure

ax : matplotlib.axes

"""
return _convert(chart)
chart = mplaltair.parse_chart.ChartMetadata(alt_chart)
fig, ax = plt.subplots()

if chart.mark in ['point', 'circle', 'square']: # scatter
mapping = _convert(chart)
ax.scatter(**mapping)
elif chart.mark == 'line': # line
_handle_line(chart, ax)
else:
raise NotImplementedError
convert_axis(ax, chart)
fig.tight_layout()
return fig, ax
196 changes: 95 additions & 101 deletions mplaltair/_axis.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import matplotlib.dates as mdates
import matplotlib.ticker as ticker
import numpy as np
from ._data import _locate_channel_data, _locate_channel_dtype, _locate_channel_scale, _locate_channel_axis, _convert_to_mpl_date
from ._data import _convert_to_mpl_date


def _set_limits(channel, scale):
def _set_limits(channel, mark, ax):
"""Set the axis limits on the Matplotlib axis

Parameters
----------
channel : dict
The mapping of the channel data and metadata
scale : dict
The mapping of the scale metadata and the scale data
channel : parse_chart.ChannelMetadata
The channel data and metadata
mark : str
The chart's mark
ax : matplotlib.axes
"""

_axis_kwargs = {
Expand All @@ -22,136 +23,142 @@ def _set_limits(channel, scale):

lims = {}

if channel['dtype'] == 'quantitative':
if channel.type == 'quantitative':
# determine limits
if 'domain' in scale: # domain takes precedence over zero in Altair
if scale['domain'] == 'unaggregated':
if 'domain' in channel.scale: # domain takes precedence over zero in Altair
if channel.scale['domain'] == 'unaggregated':
raise NotImplementedError
else:
lims[_axis_kwargs[channel['axis']].get('min')] = scale['domain'][0]
lims[_axis_kwargs[channel['axis']].get('max')] = scale['domain'][1]
elif 'type' in scale and scale['type'] != 'linear':
lims = _set_scale_type(channel, scale)
lims[_axis_kwargs[channel.name].get('min')] = channel.scale['domain'][0]
lims[_axis_kwargs[channel.name].get('max')] = channel.scale['domain'][1]
elif 'type' in channel.scale and channel.scale['type'] != 'linear':
lims = _set_scale_type(channel, ax)
else:
# Check that a positive minimum is zero if zero is True:
if ('zero' not in scale or scale['zero'] == True) and min(channel['data']) > 0:
lims[_axis_kwargs[channel['axis']].get('min')] = 0 # quantitative sets min to be 0 by default
# Include zero on the axis (or not).
# In Altair, scale.zero defaults to False unless the data is unbinned quantitative.
if mark == 'line' and channel.name == 'x':
# Contrary to documentation, Altair defaults to scale.zero=False for the x-axis on line graphs.
# Pass to skip.
pass
else:
# Check that a positive minimum is zero if scale.zero is True:
if ('zero' not in channel.scale or channel.scale['zero'] == True) and min(channel.data) > 0:
lims[_axis_kwargs[channel.name].get('min')] = 0 # quantitative sets min to be 0 by default

# Check that a negative maximum is zero if zero is True:
if ('zero' not in scale or scale['zero'] == True) and max(channel['data']) < 0:
lims[_axis_kwargs[channel['axis']].get('max')] = 0
# Check that a negative maximum is zero if scale.zero is True:
if ('zero' not in channel.scale or channel.scale['zero'] == True) and max(channel.data) < 0:
lims[_axis_kwargs[channel.name].get('max')] = 0

elif channel['dtype'] == 'temporal':
elif channel.type == 'temporal':
# determine limits
if 'domain' in scale:
domain = _convert_to_mpl_date(scale['domain'])
lims[_axis_kwargs[channel['axis']].get('min')] = domain[0]
lims[_axis_kwargs[channel['axis']].get('max')] = domain[1]
elif 'type' in scale and scale['type'] != 'time':
lims = _set_scale_type(channel, scale)
if 'domain' in channel.scale:
domain = _convert_to_mpl_date(channel.scale['domain'])
lims[_axis_kwargs[channel.name].get('min')] = domain[0]
lims[_axis_kwargs[channel.name].get('max')] = domain[1]
elif 'type' in channel.scale and channel.scale['type'] != 'time':
lims = _set_scale_type(channel, channel.scale)

else:
raise NotImplementedError # Ordinal and Nominal go here?

# set the limits
if channel['axis'] == 'x':
channel['ax'].set_xlim(**lims)
if channel.name == 'x':
ax.set_xlim(**lims)
else:
channel['ax'].set_ylim(**lims)
ax.set_ylim(**lims)


def _set_scale_type(channel, scale):
def _set_scale_type(channel, ax):
"""If the scale is non-linear, change the scale and return appropriate axis limits.
The 'linear' and 'time' scale types are not included here because quantitative defaults to 'linear'
and temporal defaults to 'time'. The 'utc' and 'sequential' scales are currently not supported.

Parameters
----------
channel : dict
The mapping of the channel data and metadata
scale : dict
The mapping of the scale metadata and the scale data
channel : parse_chart.ChannelMetadata
The channel data and metadata
ax : matplotlib.axes

Returns
-------
lims : dict
The axis limit mapped to the appropriate axis parameter for scales that change axis limit behavior
"""
lims = {}
if scale['type'] == 'log':
if channel.scale['type'] == 'log':

base = 10 # default base is 10 in altair
if 'base' in scale:
base = scale['base']
if 'base' in channel.scale:
base = channel.scale['base']

if channel['axis'] == 'x':
channel['ax'].set_xscale('log', basex=base)
if channel.name == 'x':
ax.set_xscale('log', basex=base)
# lower limit: round down to nearest major tick (using log base change rule)
lims['left'] = base**np.floor(np.log10(channel['data'].min())/np.log10(base))
lims['left'] = base**np.floor(np.log10(channel.data.min())/np.log10(base))
else: # y-axis
channel['ax'].set_yscale('log', basey=base)
ax.set_yscale('log', basey=base)
# lower limit: round down to nearest major tick (using log base change rule)
lims['bottom'] = base**np.floor(np.log10(channel['data'].min())/np.log10(base))
lims['bottom'] = base**np.floor(np.log10(channel.data.min())/np.log10(base))

elif scale['type'] == 'pow' or scale['type'] == 'sqrt':
elif channel.scale['type'] == 'pow' or channel.scale['type'] == 'sqrt':
"""The 'sqrt' scale is just the 'pow' scale with exponent = 0.5.
When Matplotlib gets a power scale, the following should work:

exponent = 2 # default exponent value for 'pow' scale
if scale['type'] == 'sqrt':
if channel.scale['type'] == 'sqrt':
exponent = 0.5
elif 'exponent' in scale:
exponent = scale['exponent']
elif 'exponent' in channel.scale:
exponent = channel.scale['exponent']

if channel['axis'] == 'x':
channel['ax'].set_xscale('power_scale', exponent=exponent)
if channel.name == 'x':
ax.set_xscale('power_scale', exponent=exponent)
else: # y-axis
channel['ax'].set_yscale('power_scale', exponent=exponent)
ax.set_yscale('power_scale', exponent=exponent)
"""
raise NotImplementedError

elif scale['type'] == 'utc':
elif channel.scale['type'] == 'utc':
raise NotImplementedError
elif scale['type'] == 'sequential':
elif channel.scale['type'] == 'sequential':
raise NotImplementedError("sequential scales used primarily for continuous colors")
else:
raise NotImplementedError
return lims


def _set_tick_locator(channel, axis):
def _set_tick_locator(channel, ax):
"""Set the tick locator if it needs to vary from the default locator

Parameters
----------
channel : dict
The mapping of the channel data and metadata
axis : dict
channel : parse_chart.ChannelMetadata
The channel data and metadata
ax : matplotlib.axes
The mapping of the axis metadata and the scale data
"""
current_axis = {'x': channel['ax'].xaxis, 'y': channel['ax'].yaxis}
if 'values' in axis:
if channel['dtype'] == 'temporal':
current_axis[channel['axis']].set_major_locator(ticker.FixedLocator(_convert_to_mpl_date(axis.get('values'))))
elif channel['dtype'] == 'quantitative':
current_axis[channel['axis']].set_major_locator(ticker.FixedLocator(axis.get('values')))
current_axis = {'x': ax.xaxis, 'y': ax.yaxis}
if 'values' in channel.axis:
if channel.type == 'temporal':
current_axis[channel.name].set_major_locator(ticker.FixedLocator(_convert_to_mpl_date(channel.axis.get('values'))))
elif channel.type == 'quantitative':
current_axis[channel.name].set_major_locator(ticker.FixedLocator(channel.axis.get('values')))
else:
raise NotImplementedError
elif 'tickCount' in axis:
current_axis[channel['axis']].set_major_locator(
ticker.MaxNLocator(steps=[2, 5, 10], nbins=axis.get('tickCount')+1, min_n_ticks=axis.get('tickCount'))
elif 'tickCount' in channel.axis:
current_axis[channel.name].set_major_locator(
ticker.MaxNLocator(steps=[2, 5, 10], nbins=channel.axis.get('tickCount')+1, min_n_ticks=channel.axis.get('tickCount'))
)


def _set_tick_formatter(channel, axis):
def _set_tick_formatter(channel, ax):
"""Set the tick formatter.


Parameters
----------
channel : dict
The mapping of the channel data and metadata
axis : dict
channel : parse_chart.ChannelMetadata
The channel data and metadata
ax : matplotlib.axes
The mapping of the axis metadata and the scale data

Notes
Expand All @@ -162,25 +169,22 @@ def _set_tick_formatter(channel, axis):
For formatting of temporal data, Matplotlib does not support some format strings that Altair supports (%L, %Q, %s).
Matplotlib only supports datetime.strftime formatting for dates.
"""
current_axis = {'x': channel['ax'].xaxis, 'y': channel['ax'].yaxis}
format_str = ''

if 'format' in axis:
format_str = axis['format']
current_axis = {'x': ax.xaxis, 'y': ax.yaxis}
format_str = channel.axis.get('format', '')

if channel['dtype'] == 'temporal':
if channel.type == 'temporal':
if not format_str:
format_str = '%b %d, %Y'

current_axis[channel['axis']].set_major_formatter(mdates.DateFormatter(format_str)) # May fail silently
current_axis[channel.name].set_major_formatter(mdates.DateFormatter(format_str)) # May fail silently

elif channel['dtype'] == 'quantitative':
elif channel.type == 'quantitative':
if format_str:
current_axis[channel['axis']].set_major_formatter(ticker.StrMethodFormatter('{x:' + format_str + '}'))
current_axis[channel.name].set_major_formatter(ticker.StrMethodFormatter('{x:' + format_str + '}'))

# Verify that the format string is valid for Matplotlib and exit nicely if not.
try:
current_axis[channel['axis']].get_major_formatter().__call__(1)
current_axis[channel.name].get_major_formatter().__call__(1)
except ValueError:
raise ValueError("Matplotlib only supports format strings as used by `str.format()`."
"Some format strings that work in Altair may not work in Matplotlib."
Expand All @@ -189,18 +193,18 @@ def _set_tick_formatter(channel, axis):
raise NotImplementedError # Nominal and Ordinal go here


def _set_label_angle(channel, axis):
def _set_label_angle(channel, ax):
"""Set the label angle. TODO: handle axis.labelAngle from Altair

Parameters
----------
channel : dict
The mapping of the channel data and metadata
axis : dict
channel : parse_chart.ChannelMetadata
The channel data and metadata
ax : matplotlib.axes
The mapping of the axis metadata and the scale data
"""
if channel['dtype'] == 'temporal' and channel['axis'] == 'x':
for label in channel['ax'].get_xticklabels():
if channel.type == 'temporal' and channel.name == 'x':
for label in ax.get_xticklabels():
# Rotate the labels on the x-axis so they don't run into each other.
label.set_rotation(30)
label.set_ha('right')
Expand All @@ -213,22 +217,12 @@ def convert_axis(ax, chart):
----------
ax
The Matplotlib axis to be modified
chart
The Altair chart
chart : parse_chart.ChartMetadata
The chart data and metadata
"""

for channel in chart.to_dict()['encoding']:
if channel in ['x', 'y']:
chart_info = {'ax': ax, 'axis': channel,
'data': _locate_channel_data(chart, channel),
'dtype': _locate_channel_dtype(chart, channel)}
if chart_info['dtype'] == 'temporal':
chart_info['data'] = _convert_to_mpl_date(chart_info['data'])

scale_info = _locate_channel_scale(chart, channel)
axis_info = _locate_channel_axis(chart, channel)

_set_limits(chart_info, scale_info)
_set_tick_locator(chart_info, axis_info)
_set_tick_formatter(chart_info, axis_info)
_set_label_angle(chart_info, axis_info)
for channel in [chart.encoding['x'], chart.encoding['y']]:
_set_limits(channel, chart.mark, ax)
_set_tick_locator(channel, ax)
_set_tick_formatter(channel, ax)
_set_label_angle(channel, ax)
Loading