diff --git a/mplaltair/__init__.py b/mplaltair/__init__.py index 1f3d019..8f0771b 100644 --- a/mplaltair/__init__.py +++ b/mplaltair/__init__.py @@ -1,9 +1,13 @@ +import mplaltair.parse_chart import matplotlib import altair +import matplotlib.pyplot as plt from ._convert import _convert +from ._axis import convert_axis +from ._marks import _handle_line -def convert(chart): +def convert(alt_chart): """Convert an altair encoding to a Matplotlib figure @@ -14,10 +18,21 @@ def convert(chart): Returns ------- - mapping : dict - Mapping from parts of the encoding to the Matplotlib artists. This is - for later customization. + fig : matplotlib.figure + ax : matplotlib.axes """ - return _convert(chart) + chart = mplaltair.parse_chart.ChartMetadata(alt_chart) + fig, ax = plt.subplots() + + if chart.mark in ['point', 'circle', 'square']: # scatter + mapping = _convert(chart) + ax.scatter(**mapping) + elif chart.mark == 'line': # line + _handle_line(chart, ax) + else: + raise NotImplementedError + convert_axis(ax, chart) + fig.tight_layout() + return fig, ax diff --git a/mplaltair/_axis.py b/mplaltair/_axis.py index 14c7d54..87fbf3f 100644 --- a/mplaltair/_axis.py +++ b/mplaltair/_axis.py @@ -1,18 +1,19 @@ import matplotlib.dates as mdates import matplotlib.ticker as ticker import numpy as np -from ._data import _locate_channel_data, _locate_channel_dtype, _locate_channel_scale, _locate_channel_axis, _convert_to_mpl_date +from ._data import _convert_to_mpl_date -def _set_limits(channel, scale): +def _set_limits(channel, mark, ax): """Set the axis limits on the Matplotlib axis Parameters ---------- - channel : dict - The mapping of the channel data and metadata - scale : dict - The mapping of the scale metadata and the scale data + channel : parse_chart.ChannelMetadata + The channel data and metadata + mark : str + The chart's mark + ax : matplotlib.axes """ _axis_kwargs = { @@ -22,55 +23,61 @@ def _set_limits(channel, scale): lims = {} - if channel['dtype'] == 'quantitative': + if channel.type == 'quantitative': # determine limits - if 'domain' in scale: # domain takes precedence over zero in Altair - if scale['domain'] == 'unaggregated': + if 'domain' in channel.scale: # domain takes precedence over zero in Altair + if channel.scale['domain'] == 'unaggregated': raise NotImplementedError else: - lims[_axis_kwargs[channel['axis']].get('min')] = scale['domain'][0] - lims[_axis_kwargs[channel['axis']].get('max')] = scale['domain'][1] - elif 'type' in scale and scale['type'] != 'linear': - lims = _set_scale_type(channel, scale) + lims[_axis_kwargs[channel.name].get('min')] = channel.scale['domain'][0] + lims[_axis_kwargs[channel.name].get('max')] = channel.scale['domain'][1] + elif 'type' in channel.scale and channel.scale['type'] != 'linear': + lims = _set_scale_type(channel, ax) else: - # Check that a positive minimum is zero if zero is True: - if ('zero' not in scale or scale['zero'] == True) and min(channel['data']) > 0: - lims[_axis_kwargs[channel['axis']].get('min')] = 0 # quantitative sets min to be 0 by default + # Include zero on the axis (or not). + # In Altair, scale.zero defaults to False unless the data is unbinned quantitative. + if mark == 'line' and channel.name == 'x': + # Contrary to documentation, Altair defaults to scale.zero=False for the x-axis on line graphs. + # Pass to skip. + pass + else: + # Check that a positive minimum is zero if scale.zero is True: + if ('zero' not in channel.scale or channel.scale['zero'] == True) and min(channel.data) > 0: + lims[_axis_kwargs[channel.name].get('min')] = 0 # quantitative sets min to be 0 by default - # Check that a negative maximum is zero if zero is True: - if ('zero' not in scale or scale['zero'] == True) and max(channel['data']) < 0: - lims[_axis_kwargs[channel['axis']].get('max')] = 0 + # Check that a negative maximum is zero if scale.zero is True: + if ('zero' not in channel.scale or channel.scale['zero'] == True) and max(channel.data) < 0: + lims[_axis_kwargs[channel.name].get('max')] = 0 - elif channel['dtype'] == 'temporal': + elif channel.type == 'temporal': # determine limits - if 'domain' in scale: - domain = _convert_to_mpl_date(scale['domain']) - lims[_axis_kwargs[channel['axis']].get('min')] = domain[0] - lims[_axis_kwargs[channel['axis']].get('max')] = domain[1] - elif 'type' in scale and scale['type'] != 'time': - lims = _set_scale_type(channel, scale) + if 'domain' in channel.scale: + domain = _convert_to_mpl_date(channel.scale['domain']) + lims[_axis_kwargs[channel.name].get('min')] = domain[0] + lims[_axis_kwargs[channel.name].get('max')] = domain[1] + elif 'type' in channel.scale and channel.scale['type'] != 'time': + lims = _set_scale_type(channel, channel.scale) else: raise NotImplementedError # Ordinal and Nominal go here? # set the limits - if channel['axis'] == 'x': - channel['ax'].set_xlim(**lims) + if channel.name == 'x': + ax.set_xlim(**lims) else: - channel['ax'].set_ylim(**lims) + ax.set_ylim(**lims) -def _set_scale_type(channel, scale): +def _set_scale_type(channel, ax): """If the scale is non-linear, change the scale and return appropriate axis limits. The 'linear' and 'time' scale types are not included here because quantitative defaults to 'linear' and temporal defaults to 'time'. The 'utc' and 'sequential' scales are currently not supported. Parameters ---------- - channel : dict - The mapping of the channel data and metadata - scale : dict - The mapping of the scale metadata and the scale data + channel : parse_chart.ChannelMetadata + The channel data and metadata + ax : matplotlib.axes Returns ------- @@ -78,80 +85,80 @@ def _set_scale_type(channel, scale): The axis limit mapped to the appropriate axis parameter for scales that change axis limit behavior """ lims = {} - if scale['type'] == 'log': + if channel.scale['type'] == 'log': base = 10 # default base is 10 in altair - if 'base' in scale: - base = scale['base'] + if 'base' in channel.scale: + base = channel.scale['base'] - if channel['axis'] == 'x': - channel['ax'].set_xscale('log', basex=base) + if channel.name == 'x': + ax.set_xscale('log', basex=base) # lower limit: round down to nearest major tick (using log base change rule) - lims['left'] = base**np.floor(np.log10(channel['data'].min())/np.log10(base)) + lims['left'] = base**np.floor(np.log10(channel.data.min())/np.log10(base)) else: # y-axis - channel['ax'].set_yscale('log', basey=base) + ax.set_yscale('log', basey=base) # lower limit: round down to nearest major tick (using log base change rule) - lims['bottom'] = base**np.floor(np.log10(channel['data'].min())/np.log10(base)) + lims['bottom'] = base**np.floor(np.log10(channel.data.min())/np.log10(base)) - elif scale['type'] == 'pow' or scale['type'] == 'sqrt': + elif channel.scale['type'] == 'pow' or channel.scale['type'] == 'sqrt': """The 'sqrt' scale is just the 'pow' scale with exponent = 0.5. When Matplotlib gets a power scale, the following should work: exponent = 2 # default exponent value for 'pow' scale - if scale['type'] == 'sqrt': + if channel.scale['type'] == 'sqrt': exponent = 0.5 - elif 'exponent' in scale: - exponent = scale['exponent'] + elif 'exponent' in channel.scale: + exponent = channel.scale['exponent'] - if channel['axis'] == 'x': - channel['ax'].set_xscale('power_scale', exponent=exponent) + if channel.name == 'x': + ax.set_xscale('power_scale', exponent=exponent) else: # y-axis - channel['ax'].set_yscale('power_scale', exponent=exponent) + ax.set_yscale('power_scale', exponent=exponent) """ raise NotImplementedError - elif scale['type'] == 'utc': + elif channel.scale['type'] == 'utc': raise NotImplementedError - elif scale['type'] == 'sequential': + elif channel.scale['type'] == 'sequential': raise NotImplementedError("sequential scales used primarily for continuous colors") else: raise NotImplementedError return lims -def _set_tick_locator(channel, axis): +def _set_tick_locator(channel, ax): """Set the tick locator if it needs to vary from the default locator Parameters ---------- - channel : dict - The mapping of the channel data and metadata - axis : dict + channel : parse_chart.ChannelMetadata + The channel data and metadata + ax : matplotlib.axes The mapping of the axis metadata and the scale data """ - current_axis = {'x': channel['ax'].xaxis, 'y': channel['ax'].yaxis} - if 'values' in axis: - if channel['dtype'] == 'temporal': - current_axis[channel['axis']].set_major_locator(ticker.FixedLocator(_convert_to_mpl_date(axis.get('values')))) - elif channel['dtype'] == 'quantitative': - current_axis[channel['axis']].set_major_locator(ticker.FixedLocator(axis.get('values'))) + current_axis = {'x': ax.xaxis, 'y': ax.yaxis} + if 'values' in channel.axis: + if channel.type == 'temporal': + current_axis[channel.name].set_major_locator(ticker.FixedLocator(_convert_to_mpl_date(channel.axis.get('values')))) + elif channel.type == 'quantitative': + current_axis[channel.name].set_major_locator(ticker.FixedLocator(channel.axis.get('values'))) else: raise NotImplementedError - elif 'tickCount' in axis: - current_axis[channel['axis']].set_major_locator( - ticker.MaxNLocator(steps=[2, 5, 10], nbins=axis.get('tickCount')+1, min_n_ticks=axis.get('tickCount')) + elif 'tickCount' in channel.axis: + current_axis[channel.name].set_major_locator( + ticker.MaxNLocator(steps=[2, 5, 10], nbins=channel.axis.get('tickCount')+1, min_n_ticks=channel.axis.get('tickCount')) ) -def _set_tick_formatter(channel, axis): +def _set_tick_formatter(channel, ax): """Set the tick formatter. Parameters ---------- - channel : dict - The mapping of the channel data and metadata - axis : dict + channel : parse_chart.ChannelMetadata + The channel data and metadata + ax : matplotlib.axes The mapping of the axis metadata and the scale data Notes @@ -162,25 +169,22 @@ def _set_tick_formatter(channel, axis): For formatting of temporal data, Matplotlib does not support some format strings that Altair supports (%L, %Q, %s). Matplotlib only supports datetime.strftime formatting for dates. """ - current_axis = {'x': channel['ax'].xaxis, 'y': channel['ax'].yaxis} - format_str = '' - - if 'format' in axis: - format_str = axis['format'] + current_axis = {'x': ax.xaxis, 'y': ax.yaxis} + format_str = channel.axis.get('format', '') - if channel['dtype'] == 'temporal': + if channel.type == 'temporal': if not format_str: format_str = '%b %d, %Y' - current_axis[channel['axis']].set_major_formatter(mdates.DateFormatter(format_str)) # May fail silently + current_axis[channel.name].set_major_formatter(mdates.DateFormatter(format_str)) # May fail silently - elif channel['dtype'] == 'quantitative': + elif channel.type == 'quantitative': if format_str: - current_axis[channel['axis']].set_major_formatter(ticker.StrMethodFormatter('{x:' + format_str + '}')) + current_axis[channel.name].set_major_formatter(ticker.StrMethodFormatter('{x:' + format_str + '}')) # Verify that the format string is valid for Matplotlib and exit nicely if not. try: - current_axis[channel['axis']].get_major_formatter().__call__(1) + current_axis[channel.name].get_major_formatter().__call__(1) except ValueError: raise ValueError("Matplotlib only supports format strings as used by `str.format()`." "Some format strings that work in Altair may not work in Matplotlib." @@ -189,18 +193,18 @@ def _set_tick_formatter(channel, axis): raise NotImplementedError # Nominal and Ordinal go here -def _set_label_angle(channel, axis): +def _set_label_angle(channel, ax): """Set the label angle. TODO: handle axis.labelAngle from Altair Parameters ---------- - channel : dict - The mapping of the channel data and metadata - axis : dict + channel : parse_chart.ChannelMetadata + The channel data and metadata + ax : matplotlib.axes The mapping of the axis metadata and the scale data """ - if channel['dtype'] == 'temporal' and channel['axis'] == 'x': - for label in channel['ax'].get_xticklabels(): + if channel.type == 'temporal' and channel.name == 'x': + for label in ax.get_xticklabels(): # Rotate the labels on the x-axis so they don't run into each other. label.set_rotation(30) label.set_ha('right') @@ -213,22 +217,12 @@ def convert_axis(ax, chart): ---------- ax The Matplotlib axis to be modified - chart - The Altair chart + chart : parse_chart.ChartMetadata + The chart data and metadata """ - for channel in chart.to_dict()['encoding']: - if channel in ['x', 'y']: - chart_info = {'ax': ax, 'axis': channel, - 'data': _locate_channel_data(chart, channel), - 'dtype': _locate_channel_dtype(chart, channel)} - if chart_info['dtype'] == 'temporal': - chart_info['data'] = _convert_to_mpl_date(chart_info['data']) - - scale_info = _locate_channel_scale(chart, channel) - axis_info = _locate_channel_axis(chart, channel) - - _set_limits(chart_info, scale_info) - _set_tick_locator(chart_info, axis_info) - _set_tick_formatter(chart_info, axis_info) - _set_label_angle(chart_info, axis_info) + for channel in [chart.encoding['x'], chart.encoding['y']]: + _set_limits(channel, chart.mark, ax) + _set_tick_locator(channel, ax) + _set_tick_formatter(channel, ax) + _set_label_angle(channel, ax) diff --git a/mplaltair/_convert.py b/mplaltair/_convert.py index f3ad9d0..23e54a6 100644 --- a/mplaltair/_convert.py +++ b/mplaltair/_convert.py @@ -1,5 +1,3 @@ -import matplotlib.dates as mdates -from ._data import _locate_channel_data, _locate_channel_dtype, _convert_to_mpl_date, _normalize_data def _allowed_ranged_marks(enc_channel, mark): """TODO: DOCS @@ -79,6 +77,7 @@ def _process_stroke(dtype, data): """ raise NotImplementedError + _mappings = { 'x': _process_x, 'y': _process_y, @@ -98,8 +97,8 @@ def _convert(chart): Parameters ---------- - chart - The Altair chart. + chart : parse_chart.ChartMetadata + Data and metadata for the Altair chart Returns ------- @@ -109,21 +108,12 @@ def _convert(chart): """ mapping = {} - _normalize_data(chart) - - if not chart.to_dict().get('encoding'): - raise ValueError("Encoding not provided with the chart specification") - for enc_channel, enc_spec in chart.to_dict()['encoding'].items(): - if not _allowed_ranged_marks(enc_channel, chart.to_dict()['mark']): - raise ValueError("Ranged encoding channels like x2, y2 not allowed for Mark: {}".format(chart['mark'])) + for enc_channel in chart.encoding: + if not _allowed_ranged_marks(enc_channel, chart.mark): + raise ValueError("Ranged encoding channels like x2, y2 not allowed for Mark: {}".format(chart.mark)) - for channel in chart.to_dict()['encoding']: - data = _locate_channel_data(chart, channel) - dtype = _locate_channel_dtype(chart, channel) - if dtype == 'temporal': - data = _convert_to_mpl_date(data) + for k, channel in chart.encoding.items(): + mapping[_mappings[k](channel.type, channel.data)[0]] = _mappings[k](channel.type, channel.data)[1] - mapping[_mappings[channel](dtype, data)[0]] = _mappings[channel](dtype, data)[1] - return mapping diff --git a/mplaltair/_data.py b/mplaltair/_data.py index c56e244..55dc6bc 100644 --- a/mplaltair/_data.py +++ b/mplaltair/_data.py @@ -1,12 +1,11 @@ +import pandas as pd from ._exceptions import ValidationError +from ._utils import _fetch import matplotlib.dates as mdates import matplotlib.cbook as cbook from datetime import datetime import numpy as np -import pandas as pd - -from ._utils import _fetch def _normalize_data(chart): """Converts the data to a Pandas dataframe @@ -40,110 +39,6 @@ def _normalize_data(chart): chart.data = df -def _locate_channel_dtype(chart, channel): - """Locates dtype used for each channel - Parameters - ---------- - chart - The Altair chart - channel - The Altair channel being examined - - Returns - ------- - A string representing the data type from the Altair chart ('quantitative', 'ordinal', 'numeric', 'temporal') - """ - - channel_val = chart.to_dict()['encoding'][channel] - if channel_val.get('type'): - return channel_val.get('type') - else: - # TODO: find some way to deal with 'value' so that, opacity, for instance, can be plotted with a value defined - if channel_val.get('value'): - raise NotImplementedError - raise NotImplementedError - - -def _locate_channel_data(chart, channel): - """Locates data used for each channel - - Parameters - ---------- - chart - The Altair chart - channel - The Altair channel being examined - - Returns - ------- - A numpy ndarray containing the data used for the channel - - Raises - ------ - ValidationError - Raised when the specification does not contain any data attribute - - """ - - channel_val = chart.to_dict()['encoding'][channel] - if channel_val.get('value'): - return channel_val.get('value') - elif channel_val.get('aggregate'): - return _aggregate_channel() - elif channel_val.get('timeUnit'): - return _handle_timeUnit() - else: # field is required if the above are not present. - return chart.data[channel_val.get('field')].values - - -def _aggregate_channel(): - raise NotImplementedError - - -def _handle_timeUnit(): - raise NotImplementedError - - -def _locate_channel_scale(chart, channel): - """Locates the channel's scale information. - - Parameters - ---------- - chart - The Altair chart - channel - The Altair channel being examined - - Returns - ------- - A dictionary with the scale information - """ - channel_val = chart.to_dict()['encoding'][channel] - if channel_val.get('scale'): - return channel_val.get('scale') - else: - return {} - - -def _locate_channel_axis(chart, channel): - """Locates the channel's scale information. - - Parameters - ---------- - chart - The Altair chart - channel - The Altair channel being examined - - Returns - ------- - A dictionary with the axis information - """ - channel_val = chart.to_dict()['encoding'][channel] - if channel_val.get('axis'): - return channel_val.get('axis') - else: - return {} def _convert_to_mpl_date(data): """Converts datetime, datetime64, strings, and Altair DateTime objects to Matplotlib dates. @@ -163,7 +58,7 @@ def _convert_to_mpl_date(data): if len(data) == 0: return [] else: - return [_convert_to_mpl_date(i) for i in data] + return np.asarray([_convert_to_mpl_date(i) for i in data]) else: if isinstance(data, str): # string format for dates data = mdates.datestr2num(data) @@ -189,9 +84,8 @@ def _altair_DateTime_to_datetime(dt): A datetime object """ MONTHS = {'Jan': 1, 'January': 1, 'Feb': 2, 'February': 2, 'Mar': 3, 'March': 3, 'Apr': 4, 'April': 4, - 'May': 5, 'May': 5, 'Jun': 6, 'June': 6, 'Jul': 7, 'July': 7, 'Aug': 8, 'August': 8, - 'Sep': 9, 'Sept': 9, 'September': 9, 'Oct': 10, 'October': 10, 'Nov': 11, 'November': 11, - 'Dec': 12, 'December': 12} + 'May': 5, 'Jun': 6, 'June': 6, 'Jul': 7, 'July': 7, 'Aug': 8, 'August': 8, 'Sep': 9, 'Sept': 9, + 'September': 9, 'Oct': 10, 'October': 10, 'Nov': 11, 'November': 11, 'Dec': 12, 'December': 12} alt_to_datetime_kw_mapping = {'date': 'day', 'hours': 'hour', 'milliseconds': 'microsecond', 'minutes': 'minute', 'month': 'month', 'seconds': 'second', 'year': 'year'} diff --git a/mplaltair/_marks.py b/mplaltair/_marks.py new file mode 100644 index 0000000..0e59a52 --- /dev/null +++ b/mplaltair/_marks.py @@ -0,0 +1,72 @@ +import matplotlib +import numpy as np +from ._data import _convert_to_mpl_date + + +def _handle_line(chart, ax): + """Convert encodings, manipulate data if needed, plot on ax. + + Parameters + ---------- + chart : altair.Chart + The Altair chart object + + ax + The Matplotlib axes object + + Notes + ----- + Fill isn't necessary until mpl-altair can handle multiple plot types in one plot. + Size is unsupported by both Matplotlib and Altair. + When both Color and Stroke are provided, color is ignored and stroke is used. + Shape is unsupported in line graphs unless another plot type is plotted at the same time. + """ + groups = [] + kwargs = {} + + if chart.encoding.get('opacity'): + groups.append('opacity') + if chart.encoding.get('stroke'): + groups.append('stroke') + elif chart.encoding.get('color'): + groups.append('color') + + list_fields = lambda c, g: [chart.encoding[i].field for i in g] + try: + for label, subset in chart.data.groupby(list_fields(chart, groups)): + if 'opacity' in groups: + kwargs['alpha'] = _opacity_norm(chart, subset[chart.encoding['opacity'].field].iloc[0]) + + if 'color' not in groups and 'stroke' not in groups: + kwargs['color'] = matplotlib.rcParams['lines.color'] + ax.plot(subset[chart.encoding['x'].field], subset[chart.encoding['y'].field], **kwargs) + except ValueError: + ax.plot(chart.encoding['x'].data, chart.encoding['y'].data) + + +def _opacity_norm(chart, val): + """ + Normalize the values of a column to be between 0.15 and 1, which is a visible range for opacity. + + Parameters + ---------- + chart : parse_chart.ChartMetadata + The Altair chart object + val + The specific value to be normalized. + + Returns + ------- + The normalized value (between 0.15 and 1) + """ + arr = chart.encoding['opacity'].data + if chart.encoding['opacity'].type in ['ordinal', 'nominal', 'temporal']: + unique, indices = np.unique(arr, return_inverse=True) + arr = indices + if chart.encoding['opacity'].type == "temporal": + val = unique.tolist().index(_convert_to_mpl_date(val)) + else: + val = unique.tolist().index(val) + data_min, data_max = (arr.min(), arr.max()) + desired_min, desired_max = (0.15, 1) # Chosen so that the minimum value is visible (aka nonzero) + return ((val - data_min) / (data_max - data_min)) * (desired_max - desired_min) + desired_min \ No newline at end of file diff --git a/mplaltair/parse_chart.py b/mplaltair/parse_chart.py new file mode 100644 index 0000000..2ea5171 --- /dev/null +++ b/mplaltair/parse_chart.py @@ -0,0 +1,117 @@ +from mplaltair._data import _convert_to_mpl_date, _normalize_data + + +class ChannelMetadata(object): + """ + Stores relevant encoding channel information. + + Attributes + ---------- + name : str + The name of the encoding channel + data : np.array + The data linked to the channel (temporal data is converted) + axis : dict + bin : boolean, None + field : str + scale : dict + sort + stack + timeUnit + title + type : str + """ + def __init__(self, channel, alt_chart): + chart_dict = alt_chart.to_dict() + self.name = channel + self.data = self._locate_channel_data(alt_chart) + self.axis = chart_dict['encoding'][self.name].get('axis', {}) + self.bin = chart_dict['encoding'][self.name].get('bin', None) + self.field = chart_dict['encoding'][self.name].get('field', None) + self.scale = chart_dict['encoding'][self.name].get('scale', {}) + self.sort = chart_dict['encoding'][self.name].get('sort', None) + self.stack = chart_dict['encoding'][self.name].get('stack', None) + self.timeUnit = chart_dict['encoding'][self.name].get('aggregate', None) + self.title = chart_dict['encoding'][self.name].get('title', None) + self.type = self._locate_channel_dtype(alt_chart) + + if self.type == 'temporal': + self.data = _convert_to_mpl_date(self.data) + + def _aggregate_channel(self): + raise NotImplementedError + + def _handle_timeUnit(self): + raise NotImplementedError + + def _locate_channel_data(self, alt_chart): + """Locates data used for each channel + + Parameters + ---------- + alt_chart : altair.Chart + The Altair chart + + Returns + ------- + A numpy ndarray containing the data used for the channel + + """ + + channel_val = alt_chart.to_dict()['encoding'][self.name] + if channel_val.get('value'): + return channel_val.get('value') + elif channel_val.get('aggregate'): + return self._aggregate_channel() + elif channel_val.get('timeUnit'): + return self._handle_timeUnit() + else: # field is required if the above are not present. + return alt_chart.data[channel_val.get('field')].values + + def _locate_channel_dtype(self, alt_chart): + """Locates dtype used for each channel + + Parameters + ---------- + alt_chart : altair.Chart + The Altair chart + + Returns + ------- + A string representing the data type from the Altair chart ('quantitative', 'ordinal', 'numeric', 'temporal') + """ + + channel_val = alt_chart.to_dict()['encoding'][self.name] + if channel_val.get('type'): + return channel_val.get('type') + else: + # TODO: find some way to deal with 'value' so that, opacity, for instance, can be plotted with a value defined + if channel_val.get('value'): + raise NotImplementedError + raise NotImplementedError + + + +class ChartMetadata(object): + """ + Stores Altair chart information usefully. Use this class for initially converting the Altair chart. + + Attributes + ---------- + data : pd.DataFrame + mark : str + encoding : dict of ChannelMetadata + """ + + def __init__(self, alt_chart): + + if not alt_chart.to_dict().get('encoding'): + raise ValueError("Encoding is not provided with the chart specification") + + _normalize_data(alt_chart) + self.data = alt_chart.data + self.mark = alt_chart.mark + + self.encoding = {} + for k, v in alt_chart.to_dict()['encoding'].items(): + self.encoding[k] = ChannelMetadata(k, alt_chart) \ No newline at end of file diff --git a/mplaltair/tests/baseline_images/test_convert/test_line.png b/mplaltair/tests/baseline_images/test_convert/test_line.png new file mode 100644 index 0000000..b49cdcd Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_line_color.png b/mplaltair/tests/baseline_images/test_convert/test_line_color.png new file mode 100644 index 0000000..203b297 Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line_color.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_line_opacity_c:O.png b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_c:O.png new file mode 100644 index 0000000..703d21c Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_c:O.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_line_opacity_color_c:N-c:O.png b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_color_c:N-c:O.png new file mode 100644 index 0000000..d57ee36 Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_color_c:N-c:O.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_line_opacity_color_d:Q-d:Q.png b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_color_d:Q-d:Q.png new file mode 100644 index 0000000..d57ee36 Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_color_d:Q-d:Q.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_line_opacity_color_dates:T-dates:T.png b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_color_dates:T-dates:T.png new file mode 100644 index 0000000..d57ee36 Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_color_dates:T-dates:T.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_line_opacity_d:Q.png b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_d:Q.png new file mode 100644 index 0000000..703d21c Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_d:Q.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_line_opacity_dates:T.png b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_dates:T.png new file mode 100644 index 0000000..703d21c Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_dates:T.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_line_stroke.png b/mplaltair/tests/baseline_images/test_convert/test_line_stroke.png new file mode 100644 index 0000000..ef08b0b Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line_stroke.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_line_stroke_a:Q-b:Q-d:Q.png b/mplaltair/tests/baseline_images/test_convert/test_line_stroke_a:Q-b:Q-d:Q.png new file mode 100644 index 0000000..203b297 Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line_stroke_a:Q-b:Q-d:Q.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_quantitative_scatter_chart0.png b/mplaltair/tests/baseline_images/test_convert/test_quantitative_scatter_chart0.png index 793f334..eb8c701 100644 Binary files a/mplaltair/tests/baseline_images/test_convert/test_quantitative_scatter_chart0.png and b/mplaltair/tests/baseline_images/test_convert/test_quantitative_scatter_chart0.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_quantitative_scatter_chart1.png b/mplaltair/tests/baseline_images/test_convert/test_quantitative_scatter_chart1.png index 1f3073c..3024b93 100644 Binary files a/mplaltair/tests/baseline_images/test_convert/test_quantitative_scatter_chart1.png and b/mplaltair/tests/baseline_images/test_convert/test_quantitative_scatter_chart1.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_scatter_temporal_channel0.png b/mplaltair/tests/baseline_images/test_convert/test_scatter_temporal_channel0.png index 6ec5288..a8f76b8 100644 Binary files a/mplaltair/tests/baseline_images/test_convert/test_scatter_temporal_channel0.png and b/mplaltair/tests/baseline_images/test_convert/test_scatter_temporal_channel0.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_scatter_temporal_channel1.png b/mplaltair/tests/baseline_images/test_convert/test_scatter_temporal_channel1.png index 6ec5288..a8f76b8 100644 Binary files a/mplaltair/tests/baseline_images/test_convert/test_scatter_temporal_channel1.png and b/mplaltair/tests/baseline_images/test_convert/test_scatter_temporal_channel1.png differ diff --git a/mplaltair/tests/test_axis.py b/mplaltair/tests/test_axis.py index dfe23cb..53322d0 100644 --- a/mplaltair/tests/test_axis.py +++ b/mplaltair/tests/test_axis.py @@ -4,6 +4,7 @@ import pandas as pd from mplaltair import convert from .._axis import convert_axis +from ..parse_chart import ChartMetadata import pytest df_quant = pd.DataFrame({ @@ -26,18 +27,13 @@ @pytest.mark.xfail(raises=TypeError) def test_invalid_temporal(): chart = alt.Chart(df_temp).mark_point().encode(alt.X('a:T')) - fig, ax = plt.subplots() - convert_axis(ax, chart) + convert(chart) @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_axis') def test_axis_more_than_x_and_y(): chart = alt.Chart(df_quant).mark_point().encode(alt.X('a'), alt.Y('b'), color=alt.Color('c')) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig @@ -46,11 +42,7 @@ def test_axis_more_than_x_and_y(): (df_temp, 'months', 'years'), (df_temp, 'years', 'months'), (df_temp, 'months', 'combination')]) def test_axis(df, x, y): chart = alt.Chart(df).mark_point().encode(alt.X(x), alt.Y(y)) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig @@ -60,7 +52,7 @@ def test_axis_set_tick_formatter_fail(): This test is just for temporary coverage purposes.""" from .._axis import _set_tick_formatter _, ax = plt.subplots() - _set_tick_formatter({'ax': ax, 'dtype': 'ordinal'}, {}) + chart = ChartMetadata(alt.Chart(df_quant).mark_point().encode('a:N', 'c:O')) @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_axis') @@ -71,12 +63,9 @@ def test_axis_zero_quantitative(x, y, zero): alt.X(x, scale=alt.Scale(zero=zero)), alt.Y(y, scale=alt.Scale(zero=zero)) ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig + # plt.show() @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_axis') @@ -92,11 +81,7 @@ def test_axis_domain(df, x, y, x_dom, y_dom): alt.X(x, scale=alt.Scale(domain=x_dom)), alt.Y(y, scale=alt.Scale(domain=y_dom)) ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig @@ -106,11 +91,7 @@ def test_axis_unaggregated_quantitative(): alt.X('a', scale=alt.Scale(domain="unaggregated")), alt.Y('c', scale=alt.Scale(domain="unaggregated")) ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - plt.close() - convert_axis(ax, chart) + convert(chart) @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_axis') @@ -124,11 +105,7 @@ def test_axis_values(df, y, vals): alt.X('a', axis=alt.Axis(values=[-1, 1, 1.5, 2.125, 3])), alt.Y(y, axis=alt.Axis(values=vals)) ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig @@ -141,11 +118,7 @@ def test_axis_tickCount(df, x, tickCount): chart = alt.Chart(df).mark_point().encode( alt.X(x, axis=alt.Axis(tickCount=tickCount)), alt.Y('a', axis=alt.Axis(tickCount=tickCount)) ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig @@ -156,11 +129,7 @@ def test_axis_scale_basic(df, column, scale_type): alt.X(column, scale=alt.Scale(type=scale_type)), alt.Y('a') ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig @@ -171,11 +140,7 @@ def test_axis_scale_type_x_quantitative(column, type, base, exponent): alt.X(column, scale=alt.Scale(type=type, base=base, exponent=exponent)), alt.Y('a') ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig @@ -186,11 +151,8 @@ def test_axis_scale_type_y_quantitative(column, type, base, exponent): alt.X('a'), alt.Y(column, scale=alt.Scale(type=type, base=base, exponent=exponent)) ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) + plt.show() return fig @@ -202,10 +164,7 @@ def test_axis_scale_NotImplemented_quantitative(df, x, type): alt.X(x, scale=alt.Scale(type=type)), alt.Y('a') ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) + convert(chart) @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_axis') @@ -217,11 +176,7 @@ def test_axis_formatter(df, x, y, format_x, format_y): alt.X(x, axis=alt.Axis(format=format_x)), alt.Y(y, axis=alt.Axis(format=format_y)) ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig @@ -233,11 +188,7 @@ def test_axis_formatter_temporal(): alt.X('months:T', axis=alt.Axis(format='%b %Y')), alt.Y('hrs:T', axis=alt.Axis(format='%H:%M:%S')) ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig @@ -247,7 +198,4 @@ def test_axis_formatter_fail(): alt.X('c', axis=alt.Axis(format='-$.2g')), alt.Y('b', axis=alt.Axis(format='+.3r')) ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) + convert(chart) diff --git a/mplaltair/tests/test_convert.py b/mplaltair/tests/test_convert.py index 3fe070a..5c84b09 100644 --- a/mplaltair/tests/test_convert.py +++ b/mplaltair/tests/test_convert.py @@ -2,10 +2,11 @@ import altair as alt import pandas as pd -import matplotlib.pyplot as plt import matplotlib.dates as mdates - +import matplotlib.pyplot as plt from mplaltair import convert +from mplaltair._convert import _convert +from mplaltair.parse_chart import ChartMetadata df = pd.DataFrame({ @@ -22,175 +23,168 @@ }) -def test_encoding_not_provided(): +def test_encoding_not_provided(): # TODO: move to the parse_chart tests chart_spec = alt.Chart(df).mark_point() with pytest.raises(ValueError): - convert(chart_spec) + chart = ChartMetadata(chart_spec) def test_invalid_encodings(): chart_spec = alt.Chart(df).encode(x2='quant').mark_point() + chart = ChartMetadata(chart_spec) with pytest.raises(ValueError): - convert(chart_spec) + _convert(chart) @pytest.mark.xfail(raises=TypeError) -def test_invalid_temporal(): +def test_invalid_temporal(): # TODO: move to parse_chart tests??? chart = alt.Chart(df).mark_point().encode(alt.X('quant:T')) - convert(chart) + ChartMetadata(chart) @pytest.mark.parametrize('channel', ['quant', 'ord', 'nom']) def test_convert_x_success(channel): chart_spec = alt.Chart(df).encode(x=channel).mark_point() - mapping = convert(chart_spec) + chart = ChartMetadata(chart_spec) + mapping = _convert(chart) assert list(mapping['x']) == list(df[channel].values) @pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) def test_convert_x_success_temporal(column): chart = alt.Chart(df).mark_point().encode(alt.X(column)) - mapping = convert(chart) + chart = ChartMetadata(chart) + mapping = _convert(chart) assert list(mapping['x']) == list(mdates.date2num(df[column].values)) - # assert list(mapping['x']) == list(df[column].values) def test_convert_x_fail(): - chart_spec = alt.Chart(df).encode(x='b:N').mark_point() with pytest.raises(KeyError): - convert(chart_spec) + chart_spec = ChartMetadata(alt.Chart(df).encode(x='b:N').mark_point()) @pytest.mark.parametrize('channel', ['quant', 'ord', 'nom']) def test_convert_y_success(channel): - chart_spec = alt.Chart(df).encode(y=channel).mark_point() - mapping = convert(chart_spec) + chart_spec = ChartMetadata(alt.Chart(df).encode(y=channel).mark_point()) + mapping = _convert(chart_spec) assert list(mapping['y']) == list(df[channel].values) @pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) def test_convert_y_success_temporal(column): - chart = alt.Chart(df).mark_point().encode(alt.Y(column)) - mapping = convert(chart) + chart = ChartMetadata(alt.Chart(df).mark_point().encode(alt.Y(column))) + mapping = _convert(chart) assert list(mapping['y']) == list(mdates.date2num(df[column].values)) - # assert list(mapping['y']) == list(df[column].values) def test_convert_y_fail(): - chart_spec = alt.Chart(df).encode(y='b:N').mark_point() with pytest.raises(KeyError): - convert(chart_spec) + chart_spec = ChartMetadata(alt.Chart(df).encode(y='b:N').mark_point()) @pytest.mark.xfail(raises=ValueError, reason="It doesn't make sense to have x2 and y2 on scatter plots") def test_quantitative_x2_y2(): - chart = alt.Chart(df_quant).mark_point().encode(alt.X('a'), alt.Y('b'), alt.X2('c'), alt.Y2('alpha')) - convert(chart) + chart = ChartMetadata(alt.Chart(df_quant).mark_point().encode(alt.X('a'), alt.Y('b'), alt.X2('c'), alt.Y2('alpha'))) + _convert(chart) @pytest.mark.xfail(raises=ValueError) @pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) def test_convert_x2_y2_fail_temporal(column): - chart = alt.Chart(df).mark_point().encode(alt.X2(column), alt.Y2(column)) - convert(chart) + chart = ChartMetadata(alt.Chart(df).mark_point().encode(alt.X2(column), alt.Y2(column))) + _convert(chart) @pytest.mark.parametrize('channel,dtype', [('quant','quantitative'), ('ord','ordinal')]) def test_convert_color_success(channel, dtype): - chart_spec = alt.Chart(df).encode(color=alt.Color(field=channel, type=dtype)).mark_point() - mapping = convert(chart_spec) + chart_spec = ChartMetadata(alt.Chart(df).encode(color=alt.Color(field=channel, type=dtype)).mark_point()) + mapping = _convert(chart_spec) assert list(mapping['c']) == list(df[channel].values) def test_convert_color_success_nominal(): - chart_spec = alt.Chart(df).encode(color='nom').mark_point() + chart_spec = ChartMetadata(alt.Chart(df).encode(color='nom').mark_point()) with pytest.raises(NotImplementedError): - convert(chart_spec) + _convert(chart_spec) @pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) def test_convert_color_success_temporal(column): - chart = alt.Chart(df).mark_point().encode(alt.Color(column)) - mapping = convert(chart) + chart = ChartMetadata(alt.Chart(df).mark_point().encode(alt.Color(column))) + mapping = _convert(chart) assert list(mapping['c']) == list(mdates.date2num(df[column].values)) - # assert list(mapping['c']) == list(df[column].values) def test_convert_color_fail(): - chart_spec = alt.Chart(df).encode(color='b:N').mark_point() with pytest.raises(KeyError): - convert(chart_spec) + chart_spec = ChartMetadata(alt.Chart(df).encode(color='b:N').mark_point()) @pytest.mark.parametrize('channel,type', [('quant', 'Q'), ('ord', 'O')]) def test_convert_fill(channel, type): - chart_spec = alt.Chart(df).encode(fill='{}:{}'.format(channel, type)).mark_point() - mapping = convert(chart_spec) + chart_spec = ChartMetadata(alt.Chart(df).encode(fill='{}:{}'.format(channel, type)).mark_point()) + mapping = _convert(chart_spec) assert list(mapping['c']) == list(df[channel].values) def test_convert_fill_success_nominal(): - chart_spec = alt.Chart(df).encode(fill='nom').mark_point() + chart_spec = ChartMetadata(alt.Chart(df).encode(fill='nom').mark_point()) with pytest.raises(NotImplementedError): - convert(chart_spec) + _convert(chart_spec) @pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) def test_convert_fill_success_temporal(column): - chart = alt.Chart(df).mark_point().encode(alt.Fill(column)) - mapping = convert(chart) + chart = ChartMetadata(alt.Chart(df).mark_point().encode(alt.Fill(column))) + mapping = _convert(chart) assert list(mapping['c']) == list(mdates.date2num(df[column].values)) - # assert list(mapping['c']) == list(df[column].values) def test_convert_fill_fail(): - chart_spec = alt.Chart(df).encode(fill='b:N').mark_point() with pytest.raises(KeyError): - convert(chart_spec) + chart_spec = ChartMetadata(alt.Chart(df).encode(fill='b:N').mark_point()) @pytest.mark.xfail(raises=NotImplementedError, reason="The marker argument in scatter() cannot take arrays") def test_quantitative_shape(): - chart = alt.Chart(df_quant).mark_point().encode(alt.Shape('shape')) - mapping = convert(chart) + chart = ChartMetadata(alt.Chart(df_quant).mark_point().encode(alt.Shape('shape'))) + mapping = _convert(chart) @pytest.mark.xfail(raises=NotImplementedError, reason="The marker argument in scatter() cannot take arrays") @pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) def test_convert_shape_fail_temporal(column): chart = alt.Chart(df).mark_point().encode(alt.Shape(column)) - mapping = convert(chart) + convert(chart) @pytest.mark.xfail(raises=NotImplementedError, reason="Merge: the dtype for opacity isn't assumed to be quantitative") def test_quantitative_opacity_value(): - chart = alt.Chart(df_quant).mark_point().encode(opacity=alt.value(.5)) - mapping = convert(chart) + chart = ChartMetadata(alt.Chart(df_quant).mark_point().encode(opacity=alt.value(.5))) @pytest.mark.xfail(raises=NotImplementedError, reason="The alpha argument in scatter() cannot take arrays") def test_quantitative_opacity_array(): - chart = alt.Chart(df_quant).mark_point().encode(alt.Opacity('alpha')) - convert(chart) + chart = ChartMetadata(alt.Chart(df_quant).mark_point().encode(alt.Opacity('alpha'))) + _convert(chart) @pytest.mark.xfail(raises=NotImplementedError, reason="The alpha argument in scatter() cannot take arrays") @pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) def test_convert_opacity_fail_temporal(column): - chart = alt.Chart(df).mark_point().encode(alt.Opacity(column)) - convert(chart) + chart = ChartMetadata(alt.Chart(df).mark_point().encode(alt.Opacity(column))) + _convert(chart) @pytest.mark.parametrize('channel,type', [('quant', 'Q'), ('ord', 'O')]) def test_convert_size_success(channel, type): - chart_spec = alt.Chart(df).encode(size='{}:{}'.format(channel, type)).mark_point() - mapping = convert(chart_spec) + chart_spec = ChartMetadata(alt.Chart(df).encode(size='{}:{}'.format(channel, type)).mark_point()) + mapping = _convert(chart_spec) assert list(mapping['s']) == list(df[channel].values) def test_convert_size_success_nominal(): - chart_spec = alt.Chart(df).encode(size='nom').mark_point() with pytest.raises(NotImplementedError): - convert(chart_spec) + chart_spec = ChartMetadata(alt.Chart(df).encode(size='nom').mark_point()) + _convert(chart_spec) def test_convert_size_fail(): - chart_spec = alt.Chart(df).encode(size='b:N').mark_point() with pytest.raises(KeyError): - convert(chart_spec) + chart_spec = ChartMetadata(alt.Chart(df).encode(size='b:N').mark_point()) @pytest.mark.xfail(raises=NotImplementedError, reason="Dates would need to be normalized for the size.") @pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) def test_convert_size_fail_temporal(column): - chart = alt.Chart(df).mark_point().encode(alt.Size(column)) - convert(chart) + chart = ChartMetadata(alt.Chart(df).mark_point().encode(alt.Size(column))) + _convert(chart) @pytest.mark.xfail(raises=NotImplementedError, reason="Stroke is not well supported in Altair") def test_quantitative_stroke(): - chart = alt.Chart(df_quant).mark_point().encode(alt.Stroke('fill')) - convert(chart) + chart = ChartMetadata(alt.Chart(df_quant).mark_point().encode(alt.Stroke('fill'))) + _convert(chart) @pytest.mark.xfail(raises=NotImplementedError, reason="Stroke is not well defined in Altair") @pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) def test_convert_stroke_fail_temporal(column): - chart = alt.Chart(df).mark_point().encode(alt.Stroke(column)) - convert(chart) + chart = ChartMetadata(alt.Chart(df).mark_point().encode(alt.Stroke(column))) + _convert(chart) # Aggregations @@ -198,13 +192,12 @@ def test_convert_stroke_fail_temporal(column): @pytest.mark.xfail(raises=NotImplementedError, reason="Aggregate functions are not supported yet") def test_quantitative_x_count_y(): df_count = pd.DataFrame({"a": [1, 1, 2, 3, 5], "b": [1.4, 1.4, 2.9, 3.18, 5.3]}) - chart = alt.Chart(df_count).mark_point().encode(alt.X('a'), alt.Y('count()')) - mapping = convert(chart) + chart = ChartMetadata(alt.Chart(df_count).mark_point().encode(alt.X('a'), alt.Y('count()'))) + @pytest.mark.xfail(raises=NotImplementedError, reason="specifying timeUnit is not supported yet") def test_timeUnit(): - chart = alt.Chart(df).mark_point().encode(alt.X('date(combination)')) - convert(chart) + chart = ChartMetadata(alt.Chart(df).mark_point().encode(alt.X('date(combination)'))) # Plots @@ -218,17 +211,95 @@ def test_timeUnit(): @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_convert') @pytest.mark.parametrize("chart", [chart_quant, chart_fill_quant]) def test_quantitative_scatter(chart): - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) + fig, ax = convert(chart) return fig @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_convert') @pytest.mark.parametrize("channel", [alt.Color("years"), alt.Fill("years")]) def test_scatter_temporal(channel): - chart = alt.Chart(df).mark_point().encode(alt.X("years"), channel) - mapping = convert(chart) - mapping['y'] = df['quantitative'].values - fig, ax = plt.subplots() - ax.scatter(**mapping) + chart = alt.Chart(df).mark_point().encode( + alt.X("years"), + alt.Y("quantitative"), + channel + ) + fig, ax = convert(chart) return fig + + +# Line plots +df_line = pd.DataFrame({ + 'a': [1, 2, 3, 1, 2, 3, 1, 2, 3], + 'b': [3, 2, 1, 7, 8, 9, 4, 5, 6], + 'c': ['a', 'a', 'a', 'b', 'b', 'b', 'c', 'c', 'c'], + 'd': [1, 1, 1, 2, 2, 2, 3, 3, 3], + 'dates': ['1968-08-01', '1968-08-01', '1968-08-01', '2010-08-08', '2010-08-08', '2010-08-08', '2015-03-14', '2015-03-14', '2015-03-14'] + }) + + +class TestLines(object): + @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_convert') + def test_line(self): + chart = alt.Chart(df_line).mark_line().encode( + alt.X('a'), + alt.Y('b'), + ) + fig, _ = convert(chart) + return fig + + @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_convert') + @pytest.mark.parametrize('x,y,s', [ + ('a:Q', 'b:Q', 'd:Q'), + pytest.param('a:N', 'b:N', 'c:N', marks=pytest.mark.xfail(raises=NotImplementedError)), + pytest.param('a:O', 'b:O', 'c:O', marks=pytest.mark.xfail(raises=NotImplementedError)) + ]) + def test_line_stroke(self, x, y, s): + chart = alt.Chart(df_line).mark_line().encode( + alt.X(x), + alt.Y(y), + alt.Stroke(s) + ) + fig, _ = convert(chart) + return fig + + @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_convert') + def test_line_color(self): + chart = alt.Chart(df_line).mark_line().encode( + alt.X('a'), + alt.Y('b'), + alt.Color('d') + ) + fig, _ = convert(chart) + return fig + + @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_convert') + @pytest.mark.parametrize('o', ['d:Q', 'c:O', 'dates:T']) + def test_line_opacity(self, o): + chart = alt.Chart(df_line).mark_line().encode( + alt.X('a'), + alt.Y('b'), + alt.Opacity(o) + ) + fig, ax = convert(chart) + return fig + + @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_convert') + @pytest.mark.parametrize('c,o', [('d:Q', 'd:Q'), ('c:N', 'c:O'), ('dates:T', 'dates:T')]) + def test_line_opacity_color(self, c, o): + chart = alt.Chart(df_line).mark_line().encode( + alt.X('a'), + alt.Y('b'), + alt.Color(c), + alt.Opacity(o) + ) + fig, ax = convert(chart) + return fig + + +class TestBars(object): + @pytest.mark.xfail(raises=NotImplementedError) + def test_bar_fail(self): + chart = alt.Chart(df_line).mark_bar().encode( + alt.X('a'), + alt.Y('b'), + ) + convert(chart) diff --git a/mplaltair/tests/test_data.py b/mplaltair/tests/test_data.py index 19f6059..0f4871d 100644 --- a/mplaltair/tests/test_data.py +++ b/mplaltair/tests/test_data.py @@ -5,8 +5,6 @@ import pytest from vega_datasets import data -from mplaltair._data import _normalize_data -from mplaltair._exceptions import ValidationError df = pd.DataFrame({ "a": [1, 2, 3, 4, 5], "b": [1.1, 2.2, 3.3, 4.4, 5.5], "c": [1, 2.2, 3, 4.4, 5], @@ -17,86 +15,17 @@ "quantitative": [1.1, 2.1, 3.1, 4.1, 5.1] }) + def test_data_list(): chart = alt.Chart(pd.DataFrame({'a': [1], 'b': [2], 'c': [3]})).mark_point() - _normalize_data(chart) + _data._normalize_data(chart) assert type(chart.data) == pd.DataFrame def test_data_url(): chart = alt.Chart(data.cars.url).mark_point() - _normalize_data(chart) + _data._normalize_data(chart) assert type(chart.data) == pd.DataFrame -# _locate_channel_data() tests - -@pytest.mark.parametrize("column, dtype", [ - ('a', 'quantitative'), ('b', 'quantitative'), ('c', 'quantitative'), ('combination', 'temporal') -]) -def test_data_field_quantitative(column, dtype): - chart = alt.Chart(df).mark_point().encode(alt.X(field=column, type=dtype)) - for channel in chart.to_dict()['encoding']: - data = _data._locate_channel_data(chart, channel) - assert list(data) == list(df[column].values) - - -@pytest.mark.parametrize("column", ['a', 'b', 'c', 'combination']) -def test_data_shorthand_quantitative(column): - chart = alt.Chart(df).mark_point().encode(alt.X(column)) - for channel in chart.to_dict()['encoding']: - data = _data._locate_channel_data(chart, channel) - assert list(data) == list(df[column].values) - - -def test_data_value_quantitative(): - chart = alt.Chart(df).mark_point().encode(opacity=alt.value(0.5)) - for channel in chart.to_dict()['encoding']: - data = _data._locate_channel_data(chart, channel) - assert data == 0.5 - - -@pytest.mark.parametrize("column", ['a', 'b', 'c']) -def test_data_aggregate_quantitative_fail(column): - """"'Passes' if it raises a NotImplementedError""" - chart = alt.Chart(df).mark_point().encode(alt.X(field=column, type='quantitative', aggregate='average')) - for channel in chart.to_dict()['encoding']: - with pytest.raises(NotImplementedError): - data = _data._locate_channel_data(chart, channel) - - -def test_data_timeUnit_shorthand_temporal_fail(): - chart = alt.Chart(df).mark_point().encode(alt.X('month(combination):T')) - for channel in chart.to_dict()['encoding']: - with pytest.raises(NotImplementedError): - data = _data._locate_channel_data(chart, channel) - - -def test_data_timeUnit_field_temporal_fail(): - """"'Passes' if it raises a NotImplementedError""" - chart = alt.Chart(df).mark_point().encode(alt.X(field='combination', type='temporal', timeUnit='month')) - for channel in chart.to_dict()['encoding']: - with pytest.raises(NotImplementedError): - data = _data._locate_channel_data(chart, channel) - - -# _locate_channel_dtype() tests - -@pytest.mark.parametrize('column, expected', [ - ('a:Q', 'quantitative'), ('nom:N', 'nominal'), ('ord:O', 'ordinal'), ('combination:T', 'temporal') -]) -def test_data_dtype(column, expected): - chart = alt.Chart(df).mark_point().encode(alt.X(column)) - for channel in chart.to_dict()['encoding']: - dtype = _data._locate_channel_dtype(chart, channel) - assert dtype == expected - - -def test_data_dtype_fail(): - """"'Passes' if it raises a NotImplementedError""" - chart = alt.Chart(df).mark_point().encode(opacity=alt.value(.5)) - for channel in chart.to_dict()['encoding']: - with pytest.raises(NotImplementedError): - dtype = _data._locate_channel_dtype(chart, channel) - # test date conversion: df_nonstandard = pd.DataFrame({ diff --git a/mplaltair/tests/test_parse_chart.py b/mplaltair/tests/test_parse_chart.py new file mode 100644 index 0000000..06beb65 --- /dev/null +++ b/mplaltair/tests/test_parse_chart.py @@ -0,0 +1,82 @@ +import pytest +import altair as alt +import pandas as pd +import mplaltair.parse_chart as parse_chart +from .._data import _convert_to_mpl_date + +df = pd.DataFrame({ + "a": [1, 2, 3, 4, 5], "b": [1.1, 2.2, 3.3, 4.4, 5.5], "c": [1, 2.2, 3, 4.4, 5], + "nom": ['a', 'b', 'c', 'd', 'e'], "ord": [1, 2, 3, 4, 5], + "years": pd.date_range('01/01/2015', periods=5, freq='Y'), "months": pd.date_range('1/1/2015', periods=5, freq='M'), + "days": pd.date_range('1/1/2015', periods=5, freq='D'), "hrs": pd.date_range('1/1/2015', periods=5, freq='H'), + "combination": pd.to_datetime(['1/1/2015', '1/1/2015 10:00:00', '1/2/2015 00:00', '1/4/2016 10:00', '5/1/2016']), + "quantitative": [1.1, 2.1, 3.1, 4.1, 5.1] +}) + + +# _locate_channel_data() tests + +@pytest.mark.parametrize("column, dtype", [ + ('a', 'quantitative'), ('b', 'quantitative'), ('c', 'quantitative'), ('combination', 'temporal') +]) +def test_data_field_quantitative(column, dtype): + chart = alt.Chart(df).mark_point().encode(alt.X(field=column, type=dtype)) + chart = parse_chart.ChartMetadata(chart) + for channel in chart.encoding: + data = chart.encoding[channel].data + if chart.encoding[channel].type == 'temporal': + assert list(data) == list(_convert_to_mpl_date(df[column].values)) + else: + assert list(data) == list(df[column].values) + + +@pytest.mark.parametrize("column", ['a', 'b', 'c', 'combination']) +def test_data_shorthand_quantitative(column): + chart = alt.Chart(df).mark_point().encode(alt.X(column)) + chart = parse_chart.ChartMetadata(chart) + for channel in chart.encoding: + data = chart.encoding[channel].data + if chart.encoding[channel].type == 'temporal': + assert list(data) == list(_convert_to_mpl_date(df[column].values)) + else: + assert list(data) == list(df[column].values) + + +@pytest.mark.parametrize("column", ['a', 'b', 'c']) +def test_data_aggregate_quantitative_fail(column): + """"'Passes' if it raises a NotImplementedError""" + chart = alt.Chart(df).mark_point().encode(alt.X(field=column, type='quantitative', aggregate='average')) + with pytest.raises(NotImplementedError): + chart = parse_chart.ChartMetadata(chart) + + +def test_data_timeUnit_shorthand_temporal_fail(): + chart = alt.Chart(df).mark_point().encode(alt.X('month(combination):T')) + with pytest.raises(NotImplementedError): + chart = parse_chart.ChartMetadata(chart) + + +def test_data_timeUnit_field_temporal_fail(): + """"'Passes' if it raises a NotImplementedError""" + chart = alt.Chart(df).mark_point().encode(alt.X(field='combination', type='temporal', timeUnit='month')) + with pytest.raises(NotImplementedError): + chart = parse_chart.ChartMetadata(chart) + + +# _locate_channel_dtype() tests + +@pytest.mark.parametrize('column, expected', [ + ('a:Q', 'quantitative'), ('nom:N', 'nominal'), ('ord:O', 'ordinal'), ('combination:T', 'temporal') +]) +def test_data_dtype(column, expected): + chart = alt.Chart(df).mark_point().encode(alt.X(column)) + chart = parse_chart.ChartMetadata(chart) + for channel in chart.encoding: + assert chart.encoding[channel].type == expected + + +def test_data_dtype_fail(): + """"'Passes' if it raises a NotImplementedError""" + chart = alt.Chart(df).mark_point().encode(opacity=alt.value(.5)) + with pytest.raises(NotImplementedError): + chart = parse_chart.ChartMetadata(chart) \ No newline at end of file