Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions mplaltair/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import matplotlib
import altair
import matplotlib.pyplot as plt
from ._convert import _convert
from ._data import _normalize_data
from ._axis import convert_axis
from ._marks import _handle_line


def convert(chart):
Expand All @@ -14,10 +18,24 @@ def convert(chart):

Returns
-------
mapping : dict
Mapping from parts of the encoding to the Matplotlib artists. This is
for later customization.
fig : matplotlib.figure

ax : matplotlib.axes

"""
return _convert(chart)

fig, ax = plt.subplots()
_normalize_data(chart)

if chart.mark in ['point', 'circle', 'square']: # scatter
mapping = _convert(chart)
ax.scatter(**mapping)
elif chart.mark == 'line': # line
_handle_line(chart, ax)
else:
raise NotImplementedError

convert_axis(ax, chart)
fig.tight_layout()

return fig, ax
23 changes: 14 additions & 9 deletions mplaltair/_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,19 @@ def _set_limits(channel, scale):
elif 'type' in scale and scale['type'] != 'linear':
lims = _set_scale_type(channel, scale)
else:
# Check that a positive minimum is zero if zero is True:
if ('zero' not in scale or scale['zero'] == True) and min(channel['data']) > 0:
lims[_axis_kwargs[channel['axis']].get('min')] = 0 # quantitative sets min to be 0 by default
# Include zero on the axis (or not).
# In Altair, scale.zero defaults to False unless the data is unbinned quantitative.
if channel['mark'] == 'line' and channel['axis'] == 'x':
# Contrary to documentation, Altair defaults to scale.zero=False for the x-axis on line graphs.
pass
else:
# Check that a positive minimum is zero if scale.zero is True:
if ('zero' not in scale or scale['zero'] == True) and min(channel['data']) > 0:
lims[_axis_kwargs[channel['axis']].get('min')] = 0 # quantitative sets min to be 0 by default

# Check that a negative maximum is zero if zero is True:
if ('zero' not in scale or scale['zero'] == True) and max(channel['data']) < 0:
lims[_axis_kwargs[channel['axis']].get('max')] = 0
# Check that a negative maximum is zero if scale.zero is True:
if ('zero' not in scale or scale['zero'] == True) and max(channel['data']) < 0:
lims[_axis_kwargs[channel['axis']].get('max')] = 0

elif channel['dtype'] == 'temporal':
# determine limits
Expand Down Expand Up @@ -221,9 +227,8 @@ def convert_axis(ax, chart):
if channel in ['x', 'y']:
chart_info = {'ax': ax, 'axis': channel,
'data': _locate_channel_data(chart, channel),
'dtype': _locate_channel_dtype(chart, channel)}
if chart_info['dtype'] == 'temporal':
chart_info['data'] = _convert_to_mpl_date(chart_info['data'])
'dtype': _locate_channel_dtype(chart, channel),
'mark': chart.mark}

scale_info = _locate_channel_scale(chart, channel)
axis_info = _locate_channel_axis(chart, channel)
Expand Down
9 changes: 3 additions & 6 deletions mplaltair/_convert.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import matplotlib.dates as mdates
import numpy as np
from ._data import _locate_channel_data, _locate_channel_dtype, _convert_to_mpl_date
from ._data import _locate_channel_data, _locate_channel_dtype

def _allowed_ranged_marks(enc_channel, mark):
"""TODO: DOCS
Expand Down Expand Up @@ -80,6 +78,7 @@ def _process_stroke(dtype, data):
"""
raise NotImplementedError


_mappings = {
'x': _process_x,
'y': _process_y,
Expand Down Expand Up @@ -120,9 +119,7 @@ def _convert(chart):
for channel in chart.to_dict()['encoding']:
data = _locate_channel_data(chart, channel)
dtype = _locate_channel_dtype(chart, channel)
if dtype == 'temporal':
data = _convert_to_mpl_date(data)

mapping[_mappings[channel](dtype, data)[0]] = _mappings[channel](dtype, data)[1]

return mapping
52 changes: 44 additions & 8 deletions mplaltair/_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pandas as pd
from ._exceptions import ValidationError
from ._utils import _fetch
import matplotlib.dates as mdates
import matplotlib.cbook as cbook
from datetime import datetime
Expand Down Expand Up @@ -51,13 +53,19 @@ def _locate_channel_data(chart, channel):

channel_val = chart.to_dict()['encoding'][channel]
if channel_val.get('value'):
return channel_val.get('value')
data = channel_val.get('value')
elif channel_val.get('aggregate'):
return _aggregate_channel()
data = _aggregate_channel()
elif channel_val.get('timeUnit'):
return _handle_timeUnit()
data = _handle_timeUnit()
else: # field is required if the above are not present.
return chart.data[channel_val.get('field')].values
data = chart.data[channel_val.get('field')].values

# Take care of temporal conversion immediately
if _locate_channel_dtype(chart, channel) == 'temporal':
return _convert_to_mpl_date(data)
else:
return data


def _aggregate_channel():
Expand Down Expand Up @@ -109,6 +117,35 @@ def _locate_channel_axis(chart, channel):
else:
return {}


def _locate_channel_field(chart, channel):
return chart.to_dict()['encoding'][channel]['field']


# FROM ENCODINGS=======================================================================================================
def _normalize_data(chart):
"""Converts the data to a Pandas dataframe. Originally Nabarun's code (PR #5).

Parameters
----------
chart : altair.Chart
The Altair chart object
"""
spec = chart.to_dict()
if not spec['data']:
raise ValidationError('Please specify a data source.')

if spec['data'].get('url'):
df = pd.DataFrame(_fetch(spec['data']['url']))
elif spec['data'].get('values'):
return
else:
raise NotImplementedError('Given data specification is unsupported at the moment.')

chart.data = df
# END STUFF FROM ENCODINGS=============================================================================================


def _convert_to_mpl_date(data):
"""Converts datetime, datetime64, strings, and Altair DateTime objects to Matplotlib dates.

Expand All @@ -127,7 +164,7 @@ def _convert_to_mpl_date(data):
if len(data) == 0:
return []
else:
return [_convert_to_mpl_date(i) for i in data]
return np.asarray([_convert_to_mpl_date(i) for i in data])
else:
if isinstance(data, str): # string format for dates
data = mdates.datestr2num(data)
Expand All @@ -153,9 +190,8 @@ def _altair_DateTime_to_datetime(dt):
A datetime object
"""
MONTHS = {'Jan': 1, 'January': 1, 'Feb': 2, 'February': 2, 'Mar': 3, 'March': 3, 'Apr': 4, 'April': 4,
'May': 5, 'May': 5, 'Jun': 6, 'June': 6, 'Jul': 7, 'July': 7, 'Aug': 8, 'August': 8,
'Sep': 9, 'Sept': 9, 'September': 9, 'Oct': 10, 'October': 10, 'Nov': 11, 'November': 11,
'Dec': 12, 'December': 12}
'May': 5, 'Jun': 6, 'June': 6, 'Jul': 7, 'July': 7, 'Aug': 8, 'August': 8, 'Sep': 9, 'Sept': 9,
'September': 9, 'Oct': 10, 'October': 10, 'Nov': 11, 'November': 11, 'Dec': 12, 'December': 12}

alt_to_datetime_kw_mapping = {'date': 'day', 'hours': 'hour', 'milliseconds': 'microsecond', 'minutes': 'minute',
'month': 'month', 'seconds': 'second', 'year': 'year'}
Expand Down
79 changes: 79 additions & 0 deletions mplaltair/_marks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import matplotlib
import numpy as np
from ._data import _locate_channel_field, _locate_channel_data, _locate_channel_dtype, _convert_to_mpl_date


def _handle_line(chart, ax):
"""Convert encodings, manipulate data if needed, and plot the line chart on an axes.

Parameters
----------
chart : altair.Chart
The Altair chart object

ax : matplotlib.axes
The Matplotlib axes object

Notes
-----
Fill isn't necessary until mpl-altair can handle multiple plot types in one plot.
Size is unsupported by both Matplotlib and Altair.
When both Color and Stroke are provided, color is ignored and stroke is used.
Shape is unsupported in line graphs unless another plot type is plotted at the same time.
"""

groupbys = []
kwargs = {}

if 'opacity' in chart.to_dict()['encoding']:
groupbys.append('opacity')

if 'stroke' in chart.to_dict()['encoding']:
groupbys.append('stroke')
elif 'color' in chart.to_dict()['encoding']:
groupbys.append('color')

list_fields = lambda c, g: [_locate_channel_field(c, i) for i in g]
if len(groupbys) > 0:
for label, subset in chart.data.groupby(list_fields(chart, groupbys)):
if 'opacity' in groupbys:
kwargs['alpha'] = _opacity_norm(chart, _locate_channel_dtype(chart, 'opacity'),
subset[_locate_channel_field(chart, 'opacity')].iloc[0])

if 'color' not in groupbys and 'stroke' not in groupbys:
kwargs['color'] = matplotlib.rcParams['lines.color']
ax.plot(subset[_locate_channel_field(chart, 'x')], subset[_locate_channel_field(chart, 'y')], **kwargs)
else:
ax.plot(_locate_channel_data(chart, 'x'), _locate_channel_data(chart, 'y'))


def _opacity_norm(chart, dtype, val):
"""
Normalize the values of a column to be between 0.15 and 1, which is a visible range for opacity.

Parameters
----------
chart : altair.Chart
The Altair chart object
dtype : str
The data type of the column ('quantitative', 'nominal', 'ordinal', or 'temporal')
val
The specific value to be normalized.

Returns
-------
The normalized value (between 0.15 and 1)
"""
arr = _locate_channel_data(chart, 'opacity')
if dtype in ['ordinal', 'nominal', 'temporal']:
# map categoricals to numbers
unique, indices = np.unique(arr, return_inverse=True)
arr = indices
if dtype == 'temporal':
val = unique.tolist().index(_convert_to_mpl_date(val))
else:
val = unique.tolist().index(val)
data_min = arr.min()
data_max = arr.max()
desired_min, desired_max = (0.15, 1) # Chosen so that the minimum value is visible
return ((val - data_min) / (data_max - data_min)) * (desired_max - desired_min) + desired_min
44 changes: 44 additions & 0 deletions mplaltair/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# FROM ENCODINGS=======================================================================================================
from urllib.request import urlopen
from urllib.error import HTTPError

import pandas as pd

_PD_READERS = {
'json': pd.read_json,
'csv': pd.read_csv
}

def _get_format(url):
"""Gives back the format of the file from url

WARNING: It might break. Trying to find a better way.
"""
return url.split('.')[-1]

def _fetch(url):
"""Downloads the file from the given url as a Pandas DataFrame

Parameters
----------
url : string
URL of the file to be downloaded

Returns
-------
pd.DataFrame
Data in the format of a DataFrame

Raises
------
NotImplementedError
Raises when an unsupported file format is given as an URL
"""
try:
ext = _get_format(url)
reader = _PD_READERS[ext]
df = reader(urlopen(url).read())
except KeyError:
raise NotImplementedError('File format not implemented')
return df
# END OF STUFF FROM ENCODINGS==========================================================================================
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading