diff --git a/.travis.yml b/.travis.yml index 3165293..797b012 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,13 +5,13 @@ python: - "3.4" env: matrix: - - DJANGO="django==1.7.9" REST="djangorestframework==2.4.6" - - DJANGO="django==1.7.9" REST="djangorestframework==3.1.1" - - DJANGO="django==1.8.3" REST="djangorestframework==2.4.6" - - DJANGO="django==1.8.3" REST="djangorestframework==3.1.3" - - DJANGO="django==1.8.3" REST="djangorestframework==3.1.3" LINT=1 + - DJANGO="django==1.7.10" REST="djangorestframework==2.4.8" + - DJANGO="django==1.7.10" REST="djangorestframework==3.2.3" + - DJANGO="django==1.8.4" REST="djangorestframework==2.4.8" + - DJANGO="django==1.8.4" REST="djangorestframework==3.2.3" + - DJANGO="django==1.8.4" REST="djangorestframework==3.2.3" LINT=1 global: - - PANDAS="pandas==0.16.0" + - PANDAS="pandas==0.16.2" install: - pip install $DJANGO - pip install $REST diff --git a/rest_pandas/serializers.py b/rest_pandas/serializers.py index eef5571..07f3c14 100644 --- a/rest_pandas/serializers.py +++ b/rest_pandas/serializers.py @@ -1,5 +1,7 @@ from rest_framework import serializers from pandas import DataFrame +from django.core.exceptions import ImproperlyConfigured +import datetime if hasattr(serializers, 'ListSerializer'): @@ -20,10 +22,7 @@ class PandasSerializer(BaseSerializer): index_none_value = None def get_index(self, dataframe): - model_serializer = getattr(self, 'child', self) - if getattr(model_serializer.Meta, 'model', None): - return ['id'] - return None + return self.get_index_fields() def get_dataframe(self, data): dataframe = DataFrame(data) @@ -53,6 +52,313 @@ def data(self): else: return DataFrame([]) + @property + def model_serializer(self): + if USE_LIST_SERIALIZERS: + serializer = type(self.child) + else: + serializer = type(self) + if serializer.__name__ == 'SerializerWithListSerializer': + for base in serializer.__bases__: + if not issubclass(base, PandasSerializer): + return base + return serializer + + @property + def model_serializer_meta(self): + return getattr(self.model_serializer, 'Meta', object()) + + def get_index_fields(self): + """ + List of fields to use for index + """ + default_fields = [] + if getattr(self.model_serializer_meta, 'model', None): + default_fields = ['id'] + return self.get_meta_option('index', default_fields) + + def get_meta_option(self, name, default=None): + meta_name = 'pandas_' + name + value = getattr(self.model_serializer_meta, meta_name, None) + + if value is None: + if default is not None: + return default + else: + raise ImproperlyConfigured( + "%s should be specified on %s.Meta" % + (meta_name, self.model_serializer.__name__) + ) + return value + + +class PandasUnstackedSerializer(PandasSerializer): + """ + Pivots dataframe so commonly-repeating values are across the top in a + multi-row header. Intended for use with e.g. time series data, where the + header includes metadata applicable to each time series. + (Use with wq/chart.js' timeSeries() function) + """ + index_none_value = '-' + + def get_index(self, dataframe): + """ + Include header fields in initial index for later unstacking + """ + return self.get_index_fields() + self.get_header_fields() + + def transform_dataframe(self, dataframe): + """ + Unstack the dataframe so header fields are across the top. + """ + dataframe.columns.name = "" + + for i in range(len(self.get_header_fields())): + dataframe = dataframe.unstack() + + # Remove blank rows / columns + dataframe = dataframe.dropna( + axis=0, how='all' + ).dropna( + axis=1, how='all' + ) + return dataframe + + def get_header_fields(self): + """ + Series metadata fields for header (first few rows) + """ + return self.get_meta_option('unstacked_header') + + +class PandasScatterSerializer(PandasSerializer): + """ + Pivots dataframe into a format suitable for plotting two series + against each other as x vs y on a scatter plot. + (Use with wq/chart.js' scatter() function) + """ + index_none_value = '-' + + def get_index(self, dataframe): + """ + Include scatter & header fields in initial index for later unstacking + """ + return ( + self.get_index_fields() + + self.get_header_fields() + + self.get_coord_fields() + ) + + def transform_dataframe(self, dataframe): + """ + Unstack the dataframe so header consists of a composite 'value' header + plus any other header fields. + """ + coord_fields = self.get_coord_fields() + header_fields = self.get_header_fields() + for i in range(len(header_fields) + len(coord_fields)): + dataframe = dataframe.unstack() + + # Compute new column headers + columns = [] + for i in range(len(header_fields) + 1): + columns.append([]) + + for col in dataframe.columns: + value_name = col[0] + coord_names = list(col[1:len(coord_fields) + 1]) + header_names = list(col[len(coord_fields) + 1:]) + coord_name = '' + for name in coord_names: + if name != self.index_none_value: + coord_name += name + '-' + coord_name += value_name + columns[0].append(coord_name) + for i, header_name in enumerate(header_names): + columns[1 + i].append(header_name) + + dataframe.columns = columns + dataframe.columns.names = [''] + header_fields + + # Remove blank columns + dataframe = dataframe.dropna(axis=1, how='all') + + # Remove any rows that don't have data for all columns (e.g. x & y) + dataframe = dataframe.dropna(axis=0, how='any') + return dataframe + + def get_coord_fields(self): + """ + Fields that will be collapsed into a single header with the name of + each coordinate. + """ + return self.get_meta_option('scatter_coord') + + def get_header_fields(self): + """ + Other header fields, if any + """ + return self.get_meta_option('scatter_header', []) + + +class PandasBoxplotSerializer(PandasSerializer): + """ + Compute boxplot statistics on dataframe columns, creating a new unstacked + dataframe where each row describes a boxplot. + (Use with wq/chart.js' boxplot() function) + """ + index_none_value = '-' + + def get_index(self, dataframe): + group_field = self.get_group_field() + date_field = self.get_date_field() + header_fields = self.get_header_fields() + + if date_field: + group_fields = [date_field, group_field] + else: + group_fields = [group_field] + return group_fields + header_fields + + def transform_dataframe(self, dataframe): + """ + Use matplotlib to compute boxplot statistics on e.g. timeseries data. + """ + grouping = self.get_grouping(dataframe) + group_field = self.get_group_field() + header_fields = self.get_header_fields() + + if "series" in grouping: + # Unstack so each series is a column + for i in range(len(header_fields) + 1): + dataframe = dataframe.unstack() + + groups = { + col: dataframe[col] + for col in dataframe.columns + } + + if "year" in grouping: + interval = "year" + elif "month" in grouping: + interval = "month" + else: + interval = None + + # Compute stats for each column, potentially grouped by year + all_stats = [] + for header, series in groups.items(): + if interval: + series_stats = self.boxplots_for_interval(series, interval) + else: + interval = None + series_stats = [self.compute_boxplot(series)] + + series_infos = [] + for series_stat in series_stats: + series_info = {} + if isinstance(header, tuple): + value_name = header[0] + col_values = header[1:] + else: + value_name = header + col_values = [] + col_names = zip(dataframe.columns.names[1:], col_values) + for col_name, value in col_names: + series_info[col_name] = value + for stat_name, val in series_stat.items(): + if stat_name == interval: + series_info[stat_name] = val + else: + series_info[value_name + '-' + stat_name] = val + series_infos.append(series_info) + all_stats += series_infos + + dataframe = DataFrame(all_stats) + if 'series' in grouping: + index = header_fields + [group_field] + unstack = len(header_fields) + if interval: + index = [interval] + index + unstack += 1 + else: + index = [interval] + unstack = 0 + + dataframe.set_index(index, inplace=True) + dataframe.columns.name = '' + for i in range(unstack): + dataframe = dataframe.unstack() + + # Remove blank columns + dataframe = dataframe.dropna(axis=1, how='all') + return dataframe + + def get_grouping(self, dataframe): + request = self.context.get('request', None) + datasets = len(dataframe.columns) + if request: + group = request.GET.get('group', None) + if group: + return group + # Heuristic for default grouping: + if datasets > 20 and self.get_date_field(): + # Group all data by year + return "year" + elif datasets > 10 or not self.get_date_field(): + # Compare series but don't break down by year + return "series" + else: + # 10 or fewer datasets, break down by both series and year + return "series-year" + + def boxplots_for_interval(self, series, interval): + def get_interval_name(date): + if isinstance(date, tuple): + date = date[0] + if hasattr(date, 'count') and date.count('-') == 2: + date = datetime.datetime.strptime(date, "%Y-%m-%d") + return getattr(date, interval) + + interval_stats = [] + groups = series.groupby(get_interval_name).groups + for interval_name, group in groups.items(): + stats = self.compute_boxplot(series[group]) + stats[interval] = interval_name + interval_stats.append(stats) + return interval_stats + + def compute_boxplot(self, series): + """ + Compute boxplot for given pandas Series. + """ + from matplotlib.cbook import boxplot_stats + series = series[series.notnull()] + if len(series.values) == 0: + return {} + stats = boxplot_stats(list(series.values))[0] + stats['count'] = len(series.values) + stats['fliers'] = "|".join(map(str, stats['fliers'])) + return stats + + def get_group_field(self): + """ + Categorical field to group datasets by. + """ + return self.get_meta_option('boxplot_group') + + def get_date_field(self): + """ + Date field to group datasets by year or month. + """ + return self.get_meta_option('boxplot_date', False) + + def get_header_fields(self): + """ + Additional series metadata for boxplot column headers + """ + return self.get_meta_option('boxplot_header', []) + class SimpleSerializer(serializers.Serializer): """ diff --git a/rest_pandas/test.py b/rest_pandas/test.py index a786875..8d7679c 100644 --- a/rest_pandas/test.py +++ b/rest_pandas/test.py @@ -9,16 +9,18 @@ def parse_csv(string): """ reader = csv.reader(StringIO(string)) val_cols = None + val_start = None id_cols = None for row in reader: - if row[0] == '': - val_cols = row[row.count(''):] + if row[0] == '' and not val_cols: + val_start = row.count('') + val_cols = row[val_start:] col_meta = [{} for v in val_cols] elif row[-1] != '' and val_cols and not id_cols: key = row[0] - for i, meta in enumerate(row[1:]): + for i, meta in enumerate(row[val_start:]): col_meta[i].update(**{key: meta}) - elif row[-1] == '': + elif row[-1] == '' and not id_cols: id_cols = row[:row.index('')] meta_index = {} meta_i = 0 @@ -51,8 +53,9 @@ def parse_csv(string): val = float(val) except ValueError: pass - data[val_cols[i]] = val - records[mi] = data + if val != '': + data[val_cols[i]] = val + records[mi] = data for mi, data in records.items(): datasets[mi]['data'].append(data) return datasets diff --git a/tests/test_complex.py b/tests/test_complex.py new file mode 100644 index 0000000..0275a4b --- /dev/null +++ b/tests/test_complex.py @@ -0,0 +1,259 @@ +from rest_framework.test import APITestCase +from tests.testapp.models import ComplexTimeSeries +from rest_pandas.test import parse_csv +from wq.io import load_string +import unittest +try: + from matplotlib.cbook import boxplot_stats +except ImportError: + boxplot_stats = None + + +class ComplexTestCase(APITestCase): + def setUp(self): + data = ( + ('site1', 'height', None, '2015-01-01', 'routine', 0.5, None), + ('site1', 'height', None, '2015-01-02', 'routine', 0.4, None), + ('site1', 'height', None, '2015-01-03', 'routine', 0.6, None), + ('site1', 'height', None, '2015-01-04', 'special', 0.2, None), + ('site1', 'height', None, '2015-01-05', 'routine', 0.1, None), + + ('site1', 'flow', 'cfs', '2015-01-01', 'special', 0.7, None), + ('site1', 'flow', 'cfs', '2015-01-02', 'routine', 0.8, None), + ('site1', 'flow', 'cfs', '2015-01-03', 'routine', 0.0, 'Q'), + ('site1', 'flow', 'cfs', '2015-01-04', 'routine', 0.9, None), + ('site1', 'flow', 'cfs', '2015-01-05', 'routine', 0.3, None), + + ('site2', 'flow', 'cfs', '2015-01-01', 'routine', 0.0, None), + ('site2', 'flow', 'cfs', '2015-01-02', 'routine', 0.7, None), + ('site2', 'flow', 'cfs', '2015-01-03', 'routine', 0.2, None), + ('site2', 'flow', 'cfs', '2015-01-04', 'routine', 0.3, None), + ('site2', 'flow', 'cfs', '2015-01-05', 'routine', 0.8, None), + ) + for site, parameter, units, date, type, value, flag in data: + ComplexTimeSeries.objects.create( + site=site, + parameter=parameter, + units=units, + date=date, + type=type, + value=value, + flag=flag, + ) + + def test_complex_series(self): + response = self.client.get("/complextimeseries.csv") + self.assertEqual( + """,,flag,value,value,value + units,,cfs,-,cfs,cfs + parameter,,flow,height,flow,flow + site,,site1,site1,site1,site2 + date,type,,,, + 2015-01-01,routine,,0.5,,0.0 + 2015-01-01,special,,,0.7, + 2015-01-02,routine,,0.4,0.8,0.7 + 2015-01-03,routine,Q,0.6,0.0,0.2 + 2015-01-04,routine,,,0.9,0.3 + 2015-01-04,special,,0.2,, + 2015-01-05,routine,,0.1,0.3,0.8 + """.replace(' ', ''), + response.content.decode('utf-8'), + ) + datasets = self.parse_unstacked_csv(response) + self.assertEqual(len(datasets), 3) + for dataset in datasets: + self.assertEqual(len(dataset['data']), 5) + + s1flow = None + s1height = None + s2flow = None + for dataset in datasets: + if dataset['site'] == "site1": + if dataset['parameter'] == "flow": + s1flow = dataset + else: + s1height = dataset + else: + s2flow = dataset + + d0 = s1height['data'][0] + self.assertEqual(d0['date'], '2015-01-01') + self.assertEqual(d0['value'], 0.5) + + d1 = s1flow['data'][2] + self.assertEqual(d1['date'], '2015-01-03') + self.assertEqual(d1['value'], 0.0) + self.assertEqual(d1['flag'], 'Q') + + d2 = s2flow['data'][4] + self.assertEqual(d2['date'], '2015-01-05') + self.assertEqual(d2['value'], 0.8) + + def test_complex_scatter(self): + response = self.client.get("/complexscatter.csv") + self.assertEqual( + """,,flow-cfs-value,flow-cfs-value,height-value + site,,site1,site2,site1 + date,type,,, + 2015-01-02,routine,0.8,0.7,0.4 + 2015-01-03,routine,0.0,0.2,0.6 + 2015-01-05,routine,0.3,0.8,0.1 + """.replace(' ', ''), + response.content.decode('utf-8') + ) + datasets = self.parse_unstacked_csv(response) + self.assertEqual([ + {'site': 'site1', 'data': [ + {'date': '2015-01-02', 'type': 'routine', + 'flow-cfs-value': 0.8, 'height-value': 0.4}, + {'date': '2015-01-03', 'type': 'routine', + 'flow-cfs-value': 0.0, 'height-value': 0.6}, + {'date': '2015-01-05', 'type': 'routine', + 'flow-cfs-value': 0.3, 'height-value': 0.1}, + ]}, + {'site': 'site2', 'data': [ + {'date': '2015-01-02', 'type': 'routine', + 'flow-cfs-value': 0.7}, + {'date': '2015-01-03', 'type': 'routine', + 'flow-cfs-value': 0.2}, + {'date': '2015-01-05', 'type': 'routine', + 'flow-cfs-value': 0.8}, + ]}, + ], datasets) + + @unittest.skipUnless(boxplot_stats, "test requires matplotlib 1.4+") + def test_complex_boxplot(self): + # Default group=series-year + response = self.client.get("/complexboxplot.csv") + datasets = self.parse_unstacked_csv(response) + + self.assertEqual(len(datasets), 3) + s1flow = None + s1height = None + s2flow = None + for dataset in datasets: + if dataset['site'] == "site1": + if dataset['parameter'] == "flow": + s1flow = dataset + else: + s1height = dataset + else: + s2flow = dataset + + self.assertEqual(len(s1height['data']), 1) + self.assertEqual(s1height['units'], '-') + stats = s1height['data'][0] + self.assertEqual(stats['year'], '2015') + self.assertEqual(stats['value-whislo'], 0.1) + self.assertEqual(stats['value-mean'], 0.36) + self.assertEqual(stats['value-whishi'], 0.6) + + self.assertEqual(s1flow['units'], 'cfs') + stats = s1flow['data'][0] + self.assertEqual(stats['year'], '2015') + self.assertEqual(stats['value-whislo'], 0.0) + self.assertEqual(round(stats['value-mean'], 8), 0.54) + self.assertEqual(stats['value-whishi'], 0.9) + + self.assertEqual(s2flow['units'], 'cfs') + stats = s2flow['data'][0] + self.assertEqual(stats['year'], '2015') + self.assertEqual(stats['value-whislo'], 0.0) + self.assertEqual(stats['value-mean'], 0.4) + self.assertEqual(stats['value-whishi'], 0.8) + + @unittest.skipUnless(boxplot_stats, "test requires matplotlib 1.4+") + def test_complex_boxplot_series(self): + response = self.client.get("/complexboxplot.csv?group=series") + datasets = self.parse_unstacked_csv(response) + s1flow = None + s1height = None + s2flow = None + for dataset in datasets: + if dataset['site'] == "site1": + if dataset['parameter'] == "flow": + s1flow = dataset + else: + s1height = dataset + else: + s2flow = dataset + + self.assertEqual(len(s1height['data']), 1) + stats = s1height['data'][0] + self.assertNotIn('year', stats) + self.assertEqual(stats['value-whislo'], 0.1) + self.assertEqual(stats['value-mean'], 0.36) + self.assertEqual(stats['value-whishi'], 0.6) + + stats = s1flow['data'][0] + self.assertNotIn('year', stats) + self.assertEqual(stats['value-whislo'], 0.0) + self.assertEqual(round(stats['value-mean'], 8), 0.54) + self.assertEqual(stats['value-whishi'], 0.9) + + self.assertEqual(len(s1flow['data']), 1) + stats = s2flow['data'][0] + self.assertNotIn('year', stats) + self.assertEqual(stats['value-whislo'], 0.0) + self.assertEqual(stats['value-mean'], 0.4) + self.assertEqual(stats['value-whishi'], 0.8) + + @unittest.skipUnless(boxplot_stats, "test requires matplotlib 1.4+") + def test_complex_boxplot_month_group(self): + response = self.client.get("/complexboxplot.csv?group=series-month") + datasets = self.parse_unstacked_csv(response) + s1flow = None + s1height = None + s2flow = None + for dataset in datasets: + if dataset['site'] == "site1": + if dataset['parameter'] == "flow": + s1flow = dataset + else: + s1height = dataset + else: + s2flow = dataset + + self.assertEqual(len(s1height['data']), 1) + stats = s1height['data'][0] + self.assertEqual(stats['month'], '1') + self.assertEqual(stats['value-whislo'], 0.1) + self.assertEqual(stats['value-mean'], 0.36) + self.assertEqual(stats['value-whishi'], 0.6) + + stats = s1flow['data'][0] + self.assertEqual(stats['month'], '1') + self.assertEqual(stats['value-whislo'], 0.0) + self.assertEqual(round(stats['value-mean'], 8), 0.54) + self.assertEqual(stats['value-whishi'], 0.9) + + self.assertEqual(len(s1flow['data']), 1) + stats = s2flow['data'][0] + self.assertEqual(stats['month'], '1') + self.assertEqual(stats['value-whislo'], 0.0) + self.assertEqual(stats['value-mean'], 0.4) + self.assertEqual(stats['value-whishi'], 0.8) + + @unittest.skipUnless(boxplot_stats, "test requires matplotlib 1.4+") + def test_complex_boxplot_year(self): + response = self.client.get("/complexboxplot.csv?group=year") + datasets = self.parse_plain_csv(response) + self.assertEqual(len(datasets), 1) + stats = datasets[0] + self.assertEqual(stats['year'], 2015) + self.assertEqual(stats['value-whislo'], 0.0) + self.assertEqual(round(stats['value-mean'], 5), 0.43333) + self.assertEqual(stats['value-whishi'], 0.9) + + def parse_unstacked_csv(self, response): + return parse_csv(response.content.decode('utf-8')) + + def parse_plain_csv(self, response): + data = load_string(response.content.decode('utf-8')).data + for row in data: + for key in row: + try: + row[key] = float(row[key]) + except ValueError: + pass + return data diff --git a/tests/test_multi.py b/tests/test_multi.py new file mode 100644 index 0000000..5e7a8cd --- /dev/null +++ b/tests/test_multi.py @@ -0,0 +1,185 @@ +from rest_framework.test import APITestCase +from tests.testapp.models import MultiTimeSeries +from tests.testapp.serializers import NotUnstackableSerializer +from rest_pandas.test import parse_csv +from wq.io import load_string +from django.core.exceptions import ImproperlyConfigured +import unittest +try: + from matplotlib.cbook import boxplot_stats +except ImportError: + boxplot_stats = None + + +class MultiTestCase(APITestCase): + def setUp(self): + data = ( + ('test1', '2015-01-01', 0.5), + ('test1', '2015-01-02', 0.4), + ('test1', '2015-01-03', 0.6), + ('test1', '2015-01-04', 0.2), + ('test1', '2015-01-05', 0.1), + + ('test2', '2015-01-01', 0.7), + ('test2', '2015-01-02', 0.8), + ('test2', '2015-01-03', 0.0), + ('test2', '2015-01-04', 0.9), + ('test2', '2015-01-05', 0.3), + ) + for series, date, value in data: + MultiTimeSeries.objects.create( + series=series, + date=date, + value=value + ) + + def test_multi_series(self): + response = self.client.get("/multitimeseries.csv") + self.assertEqual( + """,value,value + series,test1,test2 + date,, + 2015-01-01,0.5,0.7 + 2015-01-02,0.4,0.8 + 2015-01-03,0.6,0.0 + 2015-01-04,0.2,0.9 + 2015-01-05,0.1,0.3 + """.replace(' ', ''), + response.content.decode('utf-8'), + ) + datasets = self.parse_unstacked_csv(response) + self.assertEqual(len(datasets), 2) + for dataset in datasets: + self.assertEqual(len(dataset['data']), 5) + + if datasets[0]['series'] == "test1": + s1data, s2data = datasets[0], datasets[1] + else: + s2data, s1data = datasets[1], datasets[0] + + d0 = s1data['data'][0] + self.assertEqual(d0['date'], '2015-01-01') + self.assertEqual(d0['value'], 0.5) + + d0 = s2data['data'][4] + self.assertEqual(d0['date'], '2015-01-05') + self.assertEqual(d0['value'], 0.3) + + def test_multi_scatter(self): + response = self.client.get("/multiscatter.csv") + self.assertEqual( + """date,test1-value,test2-value + 2015-01-01,0.5,0.7 + 2015-01-02,0.4,0.8 + 2015-01-03,0.6,0.0 + 2015-01-04,0.2,0.9 + 2015-01-05,0.1,0.3 + """.replace(' ', ''), + response.content.decode('utf-8') + ) + + @unittest.skipUnless(boxplot_stats, "test requires matplotlib 1.4+") + def test_multi_boxplot(self): + # Default: group=series-year + response = self.client.get("/multiboxplot.csv") + + datasets = self.parse_unstacked_csv(response) + self.assertEqual(len(datasets), 2) + if datasets[0]['series'] == 'test1': + s1data, s2data = datasets + else: + s2data, s1data = datasets + + self.assertEqual(len(s1data['data']), 1) + stats = s1data['data'][0] + self.assertEqual(stats['year'], '2015') + self.assertEqual(stats['value-whislo'], 0.1) + self.assertEqual(stats['value-mean'], 0.36) + self.assertEqual(stats['value-whishi'], 0.6) + + stats = s2data['data'][0] + self.assertEqual(stats['year'], '2015') + self.assertEqual(stats['value-whislo'], 0.0) + self.assertEqual(round(stats['value-mean'], 8), 0.54) + self.assertEqual(stats['value-whishi'], 0.9) + + @unittest.skipUnless(boxplot_stats, "test requires matplotlib 1.4+") + def test_multi_boxplot_series(self): + response = self.client.get("/multiboxplot.csv?group=series") + datasets = self.parse_plain_csv(response) + self.assertEqual(len(datasets), 2) + if datasets[0]['series'] == 'test1': + s1data, s2data = datasets + else: + s2data, s1data = datasets + + stats = s1data + self.assertNotIn('year', stats) + self.assertEqual(stats['value-whislo'], 0.1) + self.assertEqual(stats['value-mean'], 0.36) + self.assertEqual(stats['value-whishi'], 0.6) + + stats = s2data + self.assertNotIn('year', stats) + self.assertEqual(stats['value-whislo'], 0.0) + self.assertEqual(round(stats['value-mean'], 8), 0.54) + self.assertEqual(stats['value-whishi'], 0.9) + + @unittest.skipUnless(boxplot_stats, "test requires matplotlib 1.4+") + def test_multi_boxplot_series_month(self): + response = self.client.get("/multiboxplot.csv?group=series-month") + + datasets = self.parse_unstacked_csv(response) + self.assertEqual(len(datasets), 2) + if datasets[0]['series'] == 'test1': + s1data, s2data = datasets + else: + s2data, s1data = datasets + + self.assertEqual(len(s1data['data']), 1) + stats = s1data['data'][0] + self.assertEqual(stats['month'], '1') + self.assertEqual(stats['value-whislo'], 0.1) + self.assertEqual(stats['value-mean'], 0.36) + self.assertEqual(stats['value-whishi'], 0.6) + + stats = s2data['data'][0] + self.assertEqual(stats['month'], '1') + self.assertEqual(stats['value-whislo'], 0.0) + self.assertEqual(round(stats['value-mean'], 8), 0.54) + self.assertEqual(stats['value-whishi'], 0.9) + + @unittest.skipUnless(boxplot_stats, "test requires matplotlib 1.4+") + def test_multi_boxplot_year(self): + response = self.client.get("/multiboxplot.csv?group=year") + + datasets = self.parse_plain_csv(response) + self.assertEqual(len(datasets), 1) + stats = datasets[0] + self.assertEqual(stats['year'], 2015) + self.assertEqual(stats['value-whislo'], 0.0) + self.assertEqual(stats['value-mean'], 0.45) + self.assertEqual(stats['value-whishi'], 0.9) + + def test_not_unstackable(self): + qs = MultiTimeSeries.objects.all() + with self.assertRaises(ImproperlyConfigured) as e: + NotUnstackableSerializer(qs, many=True).data + self.assertEqual( + e.exception.args[0], + "pandas_unstacked_header should be specified on " + "NotUnstackableSerializer.Meta" + ) + + def parse_unstacked_csv(self, response): + return parse_csv(response.content.decode('utf-8')) + + def parse_plain_csv(self, response): + data = load_string(response.content.decode('utf-8')).data + for row in data: + for key in row: + try: + row[key] = float(row[key]) + except ValueError: + pass + return data diff --git a/tests/test_views.py b/tests/test_views.py index 1fe28cd..2e3e048 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -41,7 +41,10 @@ def test_view_json_kwargs(self): self.assertEqual(date, datetime.datetime(2014, 1, 1)) def test_viewset(self): - response = self.client.get("/router/timeseries/.csv") + response = self.client.get("/router/timeseries.csv") + if response.status_code == 404: + # DRF before 3.2 required an extra / before format suffix + response = self.client.get('/router/timeseries/.csv') data = self.load_string(response) self.assertEqual(len(data), 5) self.assertEqual(data[0].value, '0.5') diff --git a/tests/testapp/models.py b/tests/testapp/models.py index bbd511c..015a680 100644 --- a/tests/testapp/models.py +++ b/tests/testapp/models.py @@ -4,3 +4,29 @@ class TimeSeries(models.Model): date = models.DateField() value = models.FloatField() + + +class MultiTimeSeries(models.Model): + # Header + series = models.CharField(max_length=5) + + # Index + date = models.DateField() + + # Values + value = models.FloatField() + + +class ComplexTimeSeries(models.Model): + # Header + site = models.CharField(max_length=5) + parameter = models.CharField(max_length=5) + units = models.CharField(max_length=5, null=True, blank=True) + + # Index + date = models.DateField() + type = models.CharField(max_length=10) + + # Values + value = models.FloatField() + flag = models.CharField(max_length=1, null=True, blank=True) diff --git a/tests/testapp/serializers.py b/tests/testapp/serializers.py index 601f002..70b16cf 100644 --- a/tests/testapp/serializers.py +++ b/tests/testapp/serializers.py @@ -1,7 +1,70 @@ -from rest_framework.serializers import ModelSerializer -from .models import TimeSeries +from rest_framework.serializers import ModelSerializer, DateField +from rest_pandas import PandasUnstackedSerializer +from .models import TimeSeries, MultiTimeSeries, ComplexTimeSeries +from rest_pandas import USE_LIST_SERIALIZERS + + +if not USE_LIST_SERIALIZERS: + # DRF 2.4 appended 00:00:00 to dates + class DateField(DateField): + def to_native(self, date): + return str(date) class TimeSeriesSerializer(ModelSerializer): class Meta: model = TimeSeries + + +class MultiTimeSeriesSerializer(ModelSerializer): + class Meta: + model = MultiTimeSeries + exclude = ['id'] + + pandas_index = ['date'] + pandas_unstacked_header = ['series'] + pandas_scatter_coord = ['series'] + pandas_boxplot_group = 'series' + pandas_boxplot_date = 'date' + + +class ComplexTimeSeriesSerializer(ModelSerializer): + date = DateField() + + class Meta: + model = ComplexTimeSeries + exclude = ['id'] + + pandas_index = ['date', 'type'] + pandas_unstacked_header = ['site', 'parameter', 'units'] + + +class ComplexScatterSerializer(ComplexTimeSeriesSerializer): + class Meta(ComplexTimeSeriesSerializer.Meta): + exclude = ['id', 'flag'] + + pandas_scatter_coord = ['units', 'parameter'] + pandas_scatter_header = ['site'] + + +class ComplexBoxplotSerializer(ComplexTimeSeriesSerializer): + class Meta(ComplexTimeSeriesSerializer.Meta): + exclude = ['id', 'flag', 'type'] + pandas_boxplot_group = 'site' + pandas_boxplot_date = 'date' + pandas_boxplot_header = ['units', 'parameter'] + + +if USE_LIST_SERIALIZERS: + class NotUnstackableSerializer(ModelSerializer): + class Meta: + model = MultiTimeSeries + list_serializer_class = PandasUnstackedSerializer + # pandas_unstacked_header = Missing + pandas_index = ['series'] +else: + class NotUnstackableSerializer(ModelSerializer, PandasUnstackedSerializer): + class Meta: + model = MultiTimeSeries + # pandas_unstacked_header = Missing + pandas_index = ['series'] diff --git a/tests/testapp/urls.py b/tests/testapp/urls.py index a7cdfca..be00d7f 100644 --- a/tests/testapp/urls.py +++ b/tests/testapp/urls.py @@ -2,7 +2,11 @@ from rest_framework.routers import DefaultRouter from rest_framework.urlpatterns import format_suffix_patterns -from .views import NoModelView, TimeSeriesView, TimeSeriesViewSet +from .views import ( + NoModelView, TimeSeriesView, TimeSeriesViewSet, + MultiTimeSeriesView, MultiScatterView, MultiBoxplotView, + ComplexTimeSeriesView, ComplexScatterView, ComplexBoxplotView, +) router = DefaultRouter() router.register('timeseries', TimeSeriesViewSet) @@ -10,6 +14,12 @@ urlpatterns = patterns('', url(r'^nomodel$', NoModelView.as_view()), # noqa url(r'^timeseries$', TimeSeriesView.as_view()), + url(r'^multitimeseries$', MultiTimeSeriesView.as_view()), + url(r'^multiscatter$', MultiScatterView.as_view()), + url(r'^multiboxplot$', MultiBoxplotView.as_view()), + url(r'^complextimeseries$', ComplexTimeSeriesView.as_view()), + url(r'^complexscatter$', ComplexScatterView.as_view()), + url(r'^complexboxplot$', ComplexBoxplotView.as_view()), ) urlpatterns = format_suffix_patterns(urlpatterns) urlpatterns += patterns('', diff --git a/tests/testapp/views.py b/tests/testapp/views.py index 7b18462..054f87e 100644 --- a/tests/testapp/views.py +++ b/tests/testapp/views.py @@ -1,6 +1,13 @@ -from rest_pandas import PandasSimpleView, PandasView, PandasViewSet -from .models import TimeSeries -from .serializers import TimeSeriesSerializer +from rest_pandas import ( + PandasSimpleView, PandasView, PandasViewSet, + PandasUnstackedSerializer, PandasScatterSerializer, PandasBoxplotSerializer +) +from .models import TimeSeries, MultiTimeSeries, ComplexTimeSeries +from .serializers import ( + TimeSeriesSerializer, MultiTimeSeriesSerializer, + ComplexTimeSeriesSerializer, ComplexScatterSerializer, + ComplexBoxplotSerializer, +) class NoModelView(PandasSimpleView): @@ -25,3 +32,39 @@ def transform_dataframe(self, df): class TimeSeriesViewSet(PandasViewSet): queryset = TimeSeries.objects.all() serializer_class = TimeSeriesSerializer + + +class MultiTimeSeriesView(PandasView): + queryset = MultiTimeSeries.objects.all() + serializer_class = MultiTimeSeriesSerializer + pandas_serializer_class = PandasUnstackedSerializer + + +class MultiScatterView(PandasView): + queryset = MultiTimeSeries.objects.all() + serializer_class = MultiTimeSeriesSerializer + pandas_serializer_class = PandasScatterSerializer + + +class MultiBoxplotView(PandasView): + queryset = MultiTimeSeries.objects.all() + serializer_class = MultiTimeSeriesSerializer + pandas_serializer_class = PandasBoxplotSerializer + + +class ComplexTimeSeriesView(PandasView): + queryset = ComplexTimeSeries.objects.all() + serializer_class = ComplexTimeSeriesSerializer + pandas_serializer_class = PandasUnstackedSerializer + + +class ComplexScatterView(PandasView): + queryset = ComplexTimeSeries.objects.all() + serializer_class = ComplexScatterSerializer + pandas_serializer_class = PandasScatterSerializer + + +class ComplexBoxplotView(PandasView): + queryset = ComplexTimeSeries.objects.all() + serializer_class = ComplexBoxplotSerializer + pandas_serializer_class = PandasBoxplotSerializer