diff --git a/mplaltair/__init__.py b/mplaltair/__init__.py index 1f3d019..4d9fcfd 100644 --- a/mplaltair/__init__.py +++ b/mplaltair/__init__.py @@ -1,6 +1,10 @@ import matplotlib import altair +import matplotlib.pyplot as plt from ._convert import _convert +from ._data import _normalize_data +from ._axis import convert_axis +from ._marks import _handle_line def convert(chart): @@ -14,10 +18,24 @@ def convert(chart): Returns ------- - mapping : dict - Mapping from parts of the encoding to the Matplotlib artists. This is - for later customization. + fig : matplotlib.figure + ax : matplotlib.axes """ - return _convert(chart) + + fig, ax = plt.subplots() + _normalize_data(chart) + + if chart.mark in ['point', 'circle', 'square']: # scatter + mapping = _convert(chart) + ax.scatter(**mapping) + elif chart.mark == 'line': # line + _handle_line(chart, ax) + else: + raise NotImplementedError + + convert_axis(ax, chart) + fig.tight_layout() + + return fig, ax diff --git a/mplaltair/_axis.py b/mplaltair/_axis.py index 14c7d54..b501dcd 100644 --- a/mplaltair/_axis.py +++ b/mplaltair/_axis.py @@ -33,13 +33,19 @@ def _set_limits(channel, scale): elif 'type' in scale and scale['type'] != 'linear': lims = _set_scale_type(channel, scale) else: - # Check that a positive minimum is zero if zero is True: - if ('zero' not in scale or scale['zero'] == True) and min(channel['data']) > 0: - lims[_axis_kwargs[channel['axis']].get('min')] = 0 # quantitative sets min to be 0 by default + # Include zero on the axis (or not). + # In Altair, scale.zero defaults to False unless the data is unbinned quantitative. + if channel['mark'] == 'line' and channel['axis'] == 'x': + # Contrary to documentation, Altair defaults to scale.zero=False for the x-axis on line graphs. + pass + else: + # Check that a positive minimum is zero if scale.zero is True: + if ('zero' not in scale or scale['zero'] == True) and min(channel['data']) > 0: + lims[_axis_kwargs[channel['axis']].get('min')] = 0 # quantitative sets min to be 0 by default - # Check that a negative maximum is zero if zero is True: - if ('zero' not in scale or scale['zero'] == True) and max(channel['data']) < 0: - lims[_axis_kwargs[channel['axis']].get('max')] = 0 + # Check that a negative maximum is zero if scale.zero is True: + if ('zero' not in scale or scale['zero'] == True) and max(channel['data']) < 0: + lims[_axis_kwargs[channel['axis']].get('max')] = 0 elif channel['dtype'] == 'temporal': # determine limits @@ -221,9 +227,8 @@ def convert_axis(ax, chart): if channel in ['x', 'y']: chart_info = {'ax': ax, 'axis': channel, 'data': _locate_channel_data(chart, channel), - 'dtype': _locate_channel_dtype(chart, channel)} - if chart_info['dtype'] == 'temporal': - chart_info['data'] = _convert_to_mpl_date(chart_info['data']) + 'dtype': _locate_channel_dtype(chart, channel), + 'mark': chart.mark} scale_info = _locate_channel_scale(chart, channel) axis_info = _locate_channel_axis(chart, channel) diff --git a/mplaltair/_convert.py b/mplaltair/_convert.py index b74d99d..a63da30 100644 --- a/mplaltair/_convert.py +++ b/mplaltair/_convert.py @@ -1,6 +1,4 @@ -import matplotlib.dates as mdates -import numpy as np -from ._data import _locate_channel_data, _locate_channel_dtype, _convert_to_mpl_date +from ._data import _locate_channel_data, _locate_channel_dtype def _allowed_ranged_marks(enc_channel, mark): """TODO: DOCS @@ -80,6 +78,7 @@ def _process_stroke(dtype, data): """ raise NotImplementedError + _mappings = { 'x': _process_x, 'y': _process_y, @@ -120,9 +119,7 @@ def _convert(chart): for channel in chart.to_dict()['encoding']: data = _locate_channel_data(chart, channel) dtype = _locate_channel_dtype(chart, channel) - if dtype == 'temporal': - data = _convert_to_mpl_date(data) mapping[_mappings[channel](dtype, data)[0]] = _mappings[channel](dtype, data)[1] - + return mapping diff --git a/mplaltair/_data.py b/mplaltair/_data.py index 2a29e49..cdd4019 100644 --- a/mplaltair/_data.py +++ b/mplaltair/_data.py @@ -1,4 +1,6 @@ +import pandas as pd from ._exceptions import ValidationError +from ._utils import _fetch import matplotlib.dates as mdates import matplotlib.cbook as cbook from datetime import datetime @@ -51,13 +53,19 @@ def _locate_channel_data(chart, channel): channel_val = chart.to_dict()['encoding'][channel] if channel_val.get('value'): - return channel_val.get('value') + data = channel_val.get('value') elif channel_val.get('aggregate'): - return _aggregate_channel() + data = _aggregate_channel() elif channel_val.get('timeUnit'): - return _handle_timeUnit() + data = _handle_timeUnit() else: # field is required if the above are not present. - return chart.data[channel_val.get('field')].values + data = chart.data[channel_val.get('field')].values + + # Take care of temporal conversion immediately + if _locate_channel_dtype(chart, channel) == 'temporal': + return _convert_to_mpl_date(data) + else: + return data def _aggregate_channel(): @@ -109,6 +117,35 @@ def _locate_channel_axis(chart, channel): else: return {} + +def _locate_channel_field(chart, channel): + return chart.to_dict()['encoding'][channel]['field'] + + +# FROM ENCODINGS======================================================================================================= +def _normalize_data(chart): + """Converts the data to a Pandas dataframe. Originally Nabarun's code (PR #5). + + Parameters + ---------- + chart : altair.Chart + The Altair chart object + """ + spec = chart.to_dict() + if not spec['data']: + raise ValidationError('Please specify a data source.') + + if spec['data'].get('url'): + df = pd.DataFrame(_fetch(spec['data']['url'])) + elif spec['data'].get('values'): + return + else: + raise NotImplementedError('Given data specification is unsupported at the moment.') + + chart.data = df +# END STUFF FROM ENCODINGS============================================================================================= + + def _convert_to_mpl_date(data): """Converts datetime, datetime64, strings, and Altair DateTime objects to Matplotlib dates. @@ -127,7 +164,7 @@ def _convert_to_mpl_date(data): if len(data) == 0: return [] else: - return [_convert_to_mpl_date(i) for i in data] + return np.asarray([_convert_to_mpl_date(i) for i in data]) else: if isinstance(data, str): # string format for dates data = mdates.datestr2num(data) @@ -153,9 +190,8 @@ def _altair_DateTime_to_datetime(dt): A datetime object """ MONTHS = {'Jan': 1, 'January': 1, 'Feb': 2, 'February': 2, 'Mar': 3, 'March': 3, 'Apr': 4, 'April': 4, - 'May': 5, 'May': 5, 'Jun': 6, 'June': 6, 'Jul': 7, 'July': 7, 'Aug': 8, 'August': 8, - 'Sep': 9, 'Sept': 9, 'September': 9, 'Oct': 10, 'October': 10, 'Nov': 11, 'November': 11, - 'Dec': 12, 'December': 12} + 'May': 5, 'Jun': 6, 'June': 6, 'Jul': 7, 'July': 7, 'Aug': 8, 'August': 8, 'Sep': 9, 'Sept': 9, + 'September': 9, 'Oct': 10, 'October': 10, 'Nov': 11, 'November': 11, 'Dec': 12, 'December': 12} alt_to_datetime_kw_mapping = {'date': 'day', 'hours': 'hour', 'milliseconds': 'microsecond', 'minutes': 'minute', 'month': 'month', 'seconds': 'second', 'year': 'year'} diff --git a/mplaltair/_marks.py b/mplaltair/_marks.py new file mode 100644 index 0000000..05631c1 --- /dev/null +++ b/mplaltair/_marks.py @@ -0,0 +1,79 @@ +import matplotlib +import numpy as np +from ._data import _locate_channel_field, _locate_channel_data, _locate_channel_dtype, _convert_to_mpl_date + + +def _handle_line(chart, ax): + """Convert encodings, manipulate data if needed, and plot the line chart on an axes. + + Parameters + ---------- + chart : altair.Chart + The Altair chart object + + ax : matplotlib.axes + The Matplotlib axes object + + Notes + ----- + Fill isn't necessary until mpl-altair can handle multiple plot types in one plot. + Size is unsupported by both Matplotlib and Altair. + When both Color and Stroke are provided, color is ignored and stroke is used. + Shape is unsupported in line graphs unless another plot type is plotted at the same time. + """ + + groupbys = [] + kwargs = {} + + if 'opacity' in chart.to_dict()['encoding']: + groupbys.append('opacity') + + if 'stroke' in chart.to_dict()['encoding']: + groupbys.append('stroke') + elif 'color' in chart.to_dict()['encoding']: + groupbys.append('color') + + list_fields = lambda c, g: [_locate_channel_field(c, i) for i in g] + if len(groupbys) > 0: + for label, subset in chart.data.groupby(list_fields(chart, groupbys)): + if 'opacity' in groupbys: + kwargs['alpha'] = _opacity_norm(chart, _locate_channel_dtype(chart, 'opacity'), + subset[_locate_channel_field(chart, 'opacity')].iloc[0]) + + if 'color' not in groupbys and 'stroke' not in groupbys: + kwargs['color'] = matplotlib.rcParams['lines.color'] + ax.plot(subset[_locate_channel_field(chart, 'x')], subset[_locate_channel_field(chart, 'y')], **kwargs) + else: + ax.plot(_locate_channel_data(chart, 'x'), _locate_channel_data(chart, 'y')) + + +def _opacity_norm(chart, dtype, val): + """ + Normalize the values of a column to be between 0.15 and 1, which is a visible range for opacity. + + Parameters + ---------- + chart : altair.Chart + The Altair chart object + dtype : str + The data type of the column ('quantitative', 'nominal', 'ordinal', or 'temporal') + val + The specific value to be normalized. + + Returns + ------- + The normalized value (between 0.15 and 1) + """ + arr = _locate_channel_data(chart, 'opacity') + if dtype in ['ordinal', 'nominal', 'temporal']: + # map categoricals to numbers + unique, indices = np.unique(arr, return_inverse=True) + arr = indices + if dtype == 'temporal': + val = unique.tolist().index(_convert_to_mpl_date(val)) + else: + val = unique.tolist().index(val) + data_min = arr.min() + data_max = arr.max() + desired_min, desired_max = (0.15, 1) # Chosen so that the minimum value is visible + return ((val - data_min) / (data_max - data_min)) * (desired_max - desired_min) + desired_min \ No newline at end of file diff --git a/mplaltair/_utils.py b/mplaltair/_utils.py new file mode 100644 index 0000000..831e226 --- /dev/null +++ b/mplaltair/_utils.py @@ -0,0 +1,44 @@ +# FROM ENCODINGS======================================================================================================= +from urllib.request import urlopen +from urllib.error import HTTPError + +import pandas as pd + +_PD_READERS = { + 'json': pd.read_json, + 'csv': pd.read_csv +} + +def _get_format(url): + """Gives back the format of the file from url + + WARNING: It might break. Trying to find a better way. + """ + return url.split('.')[-1] + +def _fetch(url): + """Downloads the file from the given url as a Pandas DataFrame + + Parameters + ---------- + url : string + URL of the file to be downloaded + + Returns + ------- + pd.DataFrame + Data in the format of a DataFrame + + Raises + ------ + NotImplementedError + Raises when an unsupported file format is given as an URL + """ + try: + ext = _get_format(url) + reader = _PD_READERS[ext] + df = reader(urlopen(url).read()) + except KeyError: + raise NotImplementedError('File format not implemented') + return df +# END OF STUFF FROM ENCODINGS========================================================================================== diff --git a/mplaltair/tests/baseline_images/test_convert/test_line.png b/mplaltair/tests/baseline_images/test_convert/test_line.png new file mode 100644 index 0000000..b49cdcd Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_line_color.png b/mplaltair/tests/baseline_images/test_convert/test_line_color.png new file mode 100644 index 0000000..203b297 Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line_color.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_line_opacity_c:O.png b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_c:O.png new file mode 100644 index 0000000..703d21c Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_c:O.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_line_opacity_color_c:N-c:O.png b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_color_c:N-c:O.png new file mode 100644 index 0000000..d57ee36 Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_color_c:N-c:O.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_line_opacity_color_d:Q-d:Q.png b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_color_d:Q-d:Q.png new file mode 100644 index 0000000..d57ee36 Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_color_d:Q-d:Q.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_line_opacity_color_dates:T-dates:T.png b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_color_dates:T-dates:T.png new file mode 100644 index 0000000..d57ee36 Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_color_dates:T-dates:T.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_line_opacity_d:Q.png b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_d:Q.png new file mode 100644 index 0000000..703d21c Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_d:Q.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_line_opacity_dates:T.png b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_dates:T.png new file mode 100644 index 0000000..703d21c Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line_opacity_dates:T.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_line_stroke.png b/mplaltair/tests/baseline_images/test_convert/test_line_stroke.png new file mode 100644 index 0000000..ef08b0b Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line_stroke.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_line_stroke_a:Q-b:Q-d:Q.png b/mplaltair/tests/baseline_images/test_convert/test_line_stroke_a:Q-b:Q-d:Q.png new file mode 100644 index 0000000..203b297 Binary files /dev/null and b/mplaltair/tests/baseline_images/test_convert/test_line_stroke_a:Q-b:Q-d:Q.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_quantitative_scatter_chart0.png b/mplaltair/tests/baseline_images/test_convert/test_quantitative_scatter_chart0.png index 793f334..eb8c701 100644 Binary files a/mplaltair/tests/baseline_images/test_convert/test_quantitative_scatter_chart0.png and b/mplaltair/tests/baseline_images/test_convert/test_quantitative_scatter_chart0.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_quantitative_scatter_chart1.png b/mplaltair/tests/baseline_images/test_convert/test_quantitative_scatter_chart1.png index 1f3073c..3024b93 100644 Binary files a/mplaltair/tests/baseline_images/test_convert/test_quantitative_scatter_chart1.png and b/mplaltair/tests/baseline_images/test_convert/test_quantitative_scatter_chart1.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_scatter_temporal_channel0.png b/mplaltair/tests/baseline_images/test_convert/test_scatter_temporal_channel0.png index 6ec5288..a8f76b8 100644 Binary files a/mplaltair/tests/baseline_images/test_convert/test_scatter_temporal_channel0.png and b/mplaltair/tests/baseline_images/test_convert/test_scatter_temporal_channel0.png differ diff --git a/mplaltair/tests/baseline_images/test_convert/test_scatter_temporal_channel1.png b/mplaltair/tests/baseline_images/test_convert/test_scatter_temporal_channel1.png index 6ec5288..a8f76b8 100644 Binary files a/mplaltair/tests/baseline_images/test_convert/test_scatter_temporal_channel1.png and b/mplaltair/tests/baseline_images/test_convert/test_scatter_temporal_channel1.png differ diff --git a/mplaltair/tests/test_axis.py b/mplaltair/tests/test_axis.py index dfe23cb..8f65b70 100644 --- a/mplaltair/tests/test_axis.py +++ b/mplaltair/tests/test_axis.py @@ -3,7 +3,6 @@ import matplotlib.pyplot as plt import pandas as pd from mplaltair import convert -from .._axis import convert_axis import pytest df_quant = pd.DataFrame({ @@ -26,18 +25,13 @@ @pytest.mark.xfail(raises=TypeError) def test_invalid_temporal(): chart = alt.Chart(df_temp).mark_point().encode(alt.X('a:T')) - fig, ax = plt.subplots() - convert_axis(ax, chart) + convert(chart) @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_axis') def test_axis_more_than_x_and_y(): chart = alt.Chart(df_quant).mark_point().encode(alt.X('a'), alt.Y('b'), color=alt.Color('c')) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig @@ -46,11 +40,7 @@ def test_axis_more_than_x_and_y(): (df_temp, 'months', 'years'), (df_temp, 'years', 'months'), (df_temp, 'months', 'combination')]) def test_axis(df, x, y): chart = alt.Chart(df).mark_point().encode(alt.X(x), alt.Y(y)) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig @@ -71,11 +61,7 @@ def test_axis_zero_quantitative(x, y, zero): alt.X(x, scale=alt.Scale(zero=zero)), alt.Y(y, scale=alt.Scale(zero=zero)) ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig @@ -92,11 +78,7 @@ def test_axis_domain(df, x, y, x_dom, y_dom): alt.X(x, scale=alt.Scale(domain=x_dom)), alt.Y(y, scale=alt.Scale(domain=y_dom)) ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig @@ -106,11 +88,7 @@ def test_axis_unaggregated_quantitative(): alt.X('a', scale=alt.Scale(domain="unaggregated")), alt.Y('c', scale=alt.Scale(domain="unaggregated")) ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - plt.close() - convert_axis(ax, chart) + convert(chart) @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_axis') @@ -124,11 +102,7 @@ def test_axis_values(df, y, vals): alt.X('a', axis=alt.Axis(values=[-1, 1, 1.5, 2.125, 3])), alt.Y(y, axis=alt.Axis(values=vals)) ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig @@ -141,11 +115,7 @@ def test_axis_tickCount(df, x, tickCount): chart = alt.Chart(df).mark_point().encode( alt.X(x, axis=alt.Axis(tickCount=tickCount)), alt.Y('a', axis=alt.Axis(tickCount=tickCount)) ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig @@ -156,11 +126,7 @@ def test_axis_scale_basic(df, column, scale_type): alt.X(column, scale=alt.Scale(type=scale_type)), alt.Y('a') ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig @@ -171,11 +137,7 @@ def test_axis_scale_type_x_quantitative(column, type, base, exponent): alt.X(column, scale=alt.Scale(type=type, base=base, exponent=exponent)), alt.Y('a') ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig @@ -186,11 +148,8 @@ def test_axis_scale_type_y_quantitative(column, type, base, exponent): alt.X('a'), alt.Y(column, scale=alt.Scale(type=type, base=base, exponent=exponent)) ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) + plt.show() return fig @@ -202,10 +161,7 @@ def test_axis_scale_NotImplemented_quantitative(df, x, type): alt.X(x, scale=alt.Scale(type=type)), alt.Y('a') ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) + convert(chart) @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_axis') @@ -217,11 +173,7 @@ def test_axis_formatter(df, x, y, format_x, format_y): alt.X(x, axis=alt.Axis(format=format_x)), alt.Y(y, axis=alt.Axis(format=format_y)) ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig @@ -233,11 +185,7 @@ def test_axis_formatter_temporal(): alt.X('months:T', axis=alt.Axis(format='%b %Y')), alt.Y('hrs:T', axis=alt.Axis(format='%H:%M:%S')) ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) - fig.tight_layout() + fig, ax = convert(chart) return fig @@ -247,7 +195,4 @@ def test_axis_formatter_fail(): alt.X('c', axis=alt.Axis(format='-$.2g')), alt.Y('b', axis=alt.Axis(format='+.3r')) ) - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) - convert_axis(ax, chart) + convert(chart) diff --git a/mplaltair/tests/test_convert.py b/mplaltair/tests/test_convert.py index 3fe070a..9360303 100644 --- a/mplaltair/tests/test_convert.py +++ b/mplaltair/tests/test_convert.py @@ -2,10 +2,10 @@ import altair as alt import pandas as pd -import matplotlib.pyplot as plt import matplotlib.dates as mdates - +import matplotlib.pyplot as plt from mplaltair import convert +from mplaltair._convert import _convert df = pd.DataFrame({ @@ -25,172 +25,168 @@ def test_encoding_not_provided(): chart_spec = alt.Chart(df).mark_point() with pytest.raises(ValueError): - convert(chart_spec) + _convert(chart_spec) def test_invalid_encodings(): chart_spec = alt.Chart(df).encode(x2='quant').mark_point() with pytest.raises(ValueError): - convert(chart_spec) + _convert(chart_spec) @pytest.mark.xfail(raises=TypeError) def test_invalid_temporal(): chart = alt.Chart(df).mark_point().encode(alt.X('quant:T')) - convert(chart) + _convert(chart) @pytest.mark.parametrize('channel', ['quant', 'ord', 'nom']) def test_convert_x_success(channel): chart_spec = alt.Chart(df).encode(x=channel).mark_point() - mapping = convert(chart_spec) + mapping = _convert(chart_spec) assert list(mapping['x']) == list(df[channel].values) @pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) def test_convert_x_success_temporal(column): chart = alt.Chart(df).mark_point().encode(alt.X(column)) - mapping = convert(chart) + mapping = _convert(chart) assert list(mapping['x']) == list(mdates.date2num(df[column].values)) - # assert list(mapping['x']) == list(df[column].values) def test_convert_x_fail(): chart_spec = alt.Chart(df).encode(x='b:N').mark_point() with pytest.raises(KeyError): - convert(chart_spec) + _convert(chart_spec) @pytest.mark.parametrize('channel', ['quant', 'ord', 'nom']) def test_convert_y_success(channel): chart_spec = alt.Chart(df).encode(y=channel).mark_point() - mapping = convert(chart_spec) + mapping = _convert(chart_spec) assert list(mapping['y']) == list(df[channel].values) @pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) def test_convert_y_success_temporal(column): chart = alt.Chart(df).mark_point().encode(alt.Y(column)) - mapping = convert(chart) + mapping = _convert(chart) assert list(mapping['y']) == list(mdates.date2num(df[column].values)) - # assert list(mapping['y']) == list(df[column].values) def test_convert_y_fail(): chart_spec = alt.Chart(df).encode(y='b:N').mark_point() with pytest.raises(KeyError): - convert(chart_spec) + _convert(chart_spec) @pytest.mark.xfail(raises=ValueError, reason="It doesn't make sense to have x2 and y2 on scatter plots") def test_quantitative_x2_y2(): chart = alt.Chart(df_quant).mark_point().encode(alt.X('a'), alt.Y('b'), alt.X2('c'), alt.Y2('alpha')) - convert(chart) + _convert(chart) @pytest.mark.xfail(raises=ValueError) @pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) def test_convert_x2_y2_fail_temporal(column): chart = alt.Chart(df).mark_point().encode(alt.X2(column), alt.Y2(column)) - convert(chart) + _convert(chart) @pytest.mark.parametrize('channel,dtype', [('quant','quantitative'), ('ord','ordinal')]) def test_convert_color_success(channel, dtype): chart_spec = alt.Chart(df).encode(color=alt.Color(field=channel, type=dtype)).mark_point() - mapping = convert(chart_spec) + mapping = _convert(chart_spec) assert list(mapping['c']) == list(df[channel].values) def test_convert_color_success_nominal(): chart_spec = alt.Chart(df).encode(color='nom').mark_point() with pytest.raises(NotImplementedError): - convert(chart_spec) + _convert(chart_spec) @pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) def test_convert_color_success_temporal(column): chart = alt.Chart(df).mark_point().encode(alt.Color(column)) - mapping = convert(chart) + mapping = _convert(chart) assert list(mapping['c']) == list(mdates.date2num(df[column].values)) - # assert list(mapping['c']) == list(df[column].values) def test_convert_color_fail(): chart_spec = alt.Chart(df).encode(color='b:N').mark_point() with pytest.raises(KeyError): - convert(chart_spec) + _convert(chart_spec) @pytest.mark.parametrize('channel,type', [('quant', 'Q'), ('ord', 'O')]) def test_convert_fill(channel, type): chart_spec = alt.Chart(df).encode(fill='{}:{}'.format(channel, type)).mark_point() - mapping = convert(chart_spec) + mapping = _convert(chart_spec) assert list(mapping['c']) == list(df[channel].values) def test_convert_fill_success_nominal(): chart_spec = alt.Chart(df).encode(fill='nom').mark_point() with pytest.raises(NotImplementedError): - convert(chart_spec) + _convert(chart_spec) @pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) def test_convert_fill_success_temporal(column): chart = alt.Chart(df).mark_point().encode(alt.Fill(column)) - mapping = convert(chart) + mapping = _convert(chart) assert list(mapping['c']) == list(mdates.date2num(df[column].values)) - # assert list(mapping['c']) == list(df[column].values) def test_convert_fill_fail(): chart_spec = alt.Chart(df).encode(fill='b:N').mark_point() with pytest.raises(KeyError): - convert(chart_spec) + _convert(chart_spec) @pytest.mark.xfail(raises=NotImplementedError, reason="The marker argument in scatter() cannot take arrays") def test_quantitative_shape(): chart = alt.Chart(df_quant).mark_point().encode(alt.Shape('shape')) - mapping = convert(chart) + mapping = _convert(chart) @pytest.mark.xfail(raises=NotImplementedError, reason="The marker argument in scatter() cannot take arrays") @pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) def test_convert_shape_fail_temporal(column): chart = alt.Chart(df).mark_point().encode(alt.Shape(column)) - mapping = convert(chart) + mapping = _convert(chart) @pytest.mark.xfail(raises=NotImplementedError, reason="Merge: the dtype for opacity isn't assumed to be quantitative") def test_quantitative_opacity_value(): chart = alt.Chart(df_quant).mark_point().encode(opacity=alt.value(.5)) - mapping = convert(chart) + mapping = _convert(chart) @pytest.mark.xfail(raises=NotImplementedError, reason="The alpha argument in scatter() cannot take arrays") def test_quantitative_opacity_array(): chart = alt.Chart(df_quant).mark_point().encode(alt.Opacity('alpha')) - convert(chart) + _convert(chart) @pytest.mark.xfail(raises=NotImplementedError, reason="The alpha argument in scatter() cannot take arrays") @pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) def test_convert_opacity_fail_temporal(column): chart = alt.Chart(df).mark_point().encode(alt.Opacity(column)) - convert(chart) + _convert(chart) @pytest.mark.parametrize('channel,type', [('quant', 'Q'), ('ord', 'O')]) def test_convert_size_success(channel, type): chart_spec = alt.Chart(df).encode(size='{}:{}'.format(channel, type)).mark_point() - mapping = convert(chart_spec) + mapping = _convert(chart_spec) assert list(mapping['s']) == list(df[channel].values) def test_convert_size_success_nominal(): chart_spec = alt.Chart(df).encode(size='nom').mark_point() with pytest.raises(NotImplementedError): - convert(chart_spec) + _convert(chart_spec) def test_convert_size_fail(): chart_spec = alt.Chart(df).encode(size='b:N').mark_point() with pytest.raises(KeyError): - convert(chart_spec) + _convert(chart_spec) @pytest.mark.xfail(raises=NotImplementedError, reason="Dates would need to be normalized for the size.") @pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) def test_convert_size_fail_temporal(column): chart = alt.Chart(df).mark_point().encode(alt.Size(column)) - convert(chart) + _convert(chart) @pytest.mark.xfail(raises=NotImplementedError, reason="Stroke is not well supported in Altair") def test_quantitative_stroke(): chart = alt.Chart(df_quant).mark_point().encode(alt.Stroke('fill')) - convert(chart) + _convert(chart) @pytest.mark.xfail(raises=NotImplementedError, reason="Stroke is not well defined in Altair") @pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) def test_convert_stroke_fail_temporal(column): chart = alt.Chart(df).mark_point().encode(alt.Stroke(column)) - convert(chart) + _convert(chart) # Aggregations @@ -199,12 +195,12 @@ def test_convert_stroke_fail_temporal(column): def test_quantitative_x_count_y(): df_count = pd.DataFrame({"a": [1, 1, 2, 3, 5], "b": [1.4, 1.4, 2.9, 3.18, 5.3]}) chart = alt.Chart(df_count).mark_point().encode(alt.X('a'), alt.Y('count()')) - mapping = convert(chart) + mapping = _convert(chart) @pytest.mark.xfail(raises=NotImplementedError, reason="specifying timeUnit is not supported yet") def test_timeUnit(): chart = alt.Chart(df).mark_point().encode(alt.X('date(combination)')) - convert(chart) + _convert(chart) # Plots @@ -218,17 +214,95 @@ def test_timeUnit(): @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_convert') @pytest.mark.parametrize("chart", [chart_quant, chart_fill_quant]) def test_quantitative_scatter(chart): - mapping = convert(chart) - fig, ax = plt.subplots() - ax.scatter(**mapping) + fig, ax = convert(chart) return fig @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_convert') @pytest.mark.parametrize("channel", [alt.Color("years"), alt.Fill("years")]) def test_scatter_temporal(channel): - chart = alt.Chart(df).mark_point().encode(alt.X("years"), channel) - mapping = convert(chart) - mapping['y'] = df['quantitative'].values - fig, ax = plt.subplots() - ax.scatter(**mapping) + chart = alt.Chart(df).mark_point().encode( + alt.X("years"), + alt.Y("quantitative"), + channel + ) + fig, ax = convert(chart) return fig + + +# Line plots +df_line = pd.DataFrame({ + 'a': [1, 2, 3, 1, 2, 3, 1, 2, 3], + 'b': [3, 2, 1, 7, 8, 9, 4, 5, 6], + 'c': ['a', 'a', 'a', 'b', 'b', 'b', 'c', 'c', 'c'], + 'd': [1, 1, 1, 2, 2, 2, 3, 3, 3], + 'dates': ['1968-08-01', '1968-08-01', '1968-08-01', '2010-08-08', '2010-08-08', '2010-08-08', '2015-03-14', '2015-03-14', '2015-03-14'] + }) + + +class TestLines(object): + @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_convert') + def test_line(self): + chart = alt.Chart(df_line).mark_line().encode( + alt.X('a'), + alt.Y('b'), + ) + fig, _ = convert(chart) + return fig + + @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_convert') + @pytest.mark.parametrize('x,y,s', [ + ('a:Q', 'b:Q', 'd:Q'), + pytest.param('a:N', 'b:N', 'c:N', marks=pytest.mark.xfail(raises=NotImplementedError)), + pytest.param('a:O', 'b:O', 'c:O', marks=pytest.mark.xfail(raises=NotImplementedError)) + ]) + def test_line_stroke(self, x, y, s): + chart = alt.Chart(df_line).mark_line().encode( + alt.X(x), + alt.Y(y), + alt.Stroke(s) + ) + fig, _ = convert(chart) + return fig + + @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_convert') + def test_line_color(self): + chart = alt.Chart(df_line).mark_line().encode( + alt.X('a'), + alt.Y('b'), + alt.Color('d') + ) + fig, _ = convert(chart) + return fig + + @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_convert') + @pytest.mark.parametrize('o', ['d:Q', 'c:O', 'dates:T']) + def test_line_opacity(self, o): + chart = alt.Chart(df_line).mark_line().encode( + alt.X('a'), + alt.Y('b'), + alt.Opacity(o) + ) + fig, ax = convert(chart) + return fig + + @pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_convert') + @pytest.mark.parametrize('c,o', [('d:Q', 'd:Q'), ('c:N', 'c:O'), ('dates:T', 'dates:T')]) + def test_line_opacity_color(self, c, o): + chart = alt.Chart(df_line).mark_line().encode( + alt.X('a'), + alt.Y('b'), + alt.Color(c), + alt.Opacity(o) + ) + fig, ax = convert(chart) + return fig + + +class TestBars(object): + @pytest.mark.xfail(raises=NotImplementedError) + def test_bar_fail(self): + chart = alt.Chart(df_line).mark_bar().encode( + alt.X('a'), + alt.Y('b'), + ) + convert(chart) diff --git a/mplaltair/tests/test_data.py b/mplaltair/tests/test_data.py index bdd4154..d6b0335 100644 --- a/mplaltair/tests/test_data.py +++ b/mplaltair/tests/test_data.py @@ -23,10 +23,13 @@ def test_data_field_quantitative(column, dtype): chart = alt.Chart(df).mark_point().encode(alt.X(field=column, type=dtype)) for channel in chart.to_dict()['encoding']: data = _data._locate_channel_data(chart, channel) - assert list(data) == list(df[column].values) + if dtype == 'temporal': + assert list(data) == list(_data._convert_to_mpl_date(df[column].values)) + else: + assert list(data) == list(df[column].values) -@pytest.mark.parametrize("column", ['a', 'b', 'c', 'combination']) +@pytest.mark.parametrize("column", ['a', 'b', 'c']) def test_data_shorthand_quantitative(column): chart = alt.Chart(df).mark_point().encode(alt.X(column)) for channel in chart.to_dict()['encoding']: @@ -34,11 +37,11 @@ def test_data_shorthand_quantitative(column): assert list(data) == list(df[column].values) -def test_data_value_quantitative(): - chart = alt.Chart(df).mark_point().encode(opacity=alt.value(0.5)) +def test_data_shorthand_temporal(): + chart = alt.Chart(df).mark_point().encode(alt.X('combination')) for channel in chart.to_dict()['encoding']: data = _data._locate_channel_data(chart, channel) - assert data == 0.5 + assert list(data) == list(_data._convert_to_mpl_date(df['combination'].values)) @pytest.mark.parametrize("column", ['a', 'b', 'c'])