From d164182c97d9b8805cad042264429f0cd1dfc282 Mon Sep 17 00:00:00 2001
From: trbs
Date: Fri, 15 Apr 2016 18:12:37 +0200
Subject: [PATCH 1/2] add support for specifying secondary indexes with to_sql

---
 pandas/core/generic.py      |  5 +--
 pandas/io/sql.py            | 37 ++++++++++++++-----
 pandas/io/tests/test_sql.py | 71 ++++++++++++++++++++++++++++++++++++-
 3 files changed, 101 insertions(+), 12 deletions(-)

diff --git a/pandas/core/generic.py b/pandas/core/generic.py
index 30252f7068424..63c1efc98db8e 100644
--- a/pandas/core/generic.py
+++ b/pandas/core/generic.py
@@ -1117,7 +1117,8 @@ def to_msgpack(self, path_or_buf=None, encoding='utf-8', **kwargs):
                              **kwargs)
 
     def to_sql(self, name, con, flavor='sqlite', schema=None, if_exists='fail',
-               index=True, index_label=None, chunksize=None, dtype=None):
+               index=True, index_label=None, chunksize=None, dtype=None,
+               indexes=None):
         """
         Write records stored in a DataFrame to a SQL database.
 
@@ -1157,7 +1158,7 @@ def to_sql(self, name, con, flavor='sqlite', schema=None, if_exists='fail',
         from pandas.io import sql
         sql.to_sql(self, name, con, flavor=flavor, schema=schema,
                    if_exists=if_exists, index=index, index_label=index_label,
-                   chunksize=chunksize, dtype=dtype)
+                   chunksize=chunksize, dtype=dtype, indexes=indexes)
 
     def to_pickle(self, path):
         """
diff --git a/pandas/io/sql.py b/pandas/io/sql.py
index 324988360c9fe..85d97c515608b 100644
--- a/pandas/io/sql.py
+++ b/pandas/io/sql.py
@@ -516,7 +516,8 @@ def read_sql(sql, con, index_col=None, coerce_float=True, params=None,
 
 
 def to_sql(frame, name, con, flavor='sqlite', schema=None, if_exists='fail',
-           index=True, index_label=None, chunksize=None, dtype=None):
+           index=True, index_label=None, chunksize=None, dtype=None,
+           indexes=None):
     """
     Write records stored in a DataFrame to a SQL database.
 
@@ -568,7 +569,7 @@ def to_sql(frame, name, con, flavor='sqlite', schema=None, if_exists='fail',
 
     pandas_sql.to_sql(frame, name, if_exists=if_exists, index=index,
                       index_label=index_label, schema=schema,
-                      chunksize=chunksize, dtype=dtype)
+                      chunksize=chunksize, dtype=dtype, indexes=indexes)
 
 
 def has_table(table_name, con, flavor='sqlite', schema=None):
@@ -653,12 +654,13 @@ class SQLTable(PandasObject):
 
     def __init__(self, name, pandas_sql_engine, frame=None, index=True,
                  if_exists='fail', prefix='pandas', index_label=None,
-                 schema=None, keys=None, dtype=None):
+                 schema=None, keys=None, dtype=None, indexes=None):
         self.name = name
         self.pd_sql = pandas_sql_engine
         self.prefix = prefix
         self.frame = frame
         self.index = self._index_name(index, index_label)
+        self.indexes = indexes
         self.schema = schema
         self.if_exists = if_exists
         self.keys = keys
@@ -849,18 +851,33 @@ def _index_name(self, index, index_label):
         else:
             return None
 
+    def _is_column_indexed(self, label):
+        if self.indexes is not None and label in self.indexes:
+            return True
+
+        if self.index is not None and label in self.index:
+            if self.keys is None:
+                return True
+
+            col_nr = self.index.index(label) + 1
+            if self.keys[:col_nr] != self.index[:col_nr]:
+                return True
+
+        return False
+
     def _get_column_names_and_types(self, dtype_mapper):
         column_names_and_types = []
         if self.index is not None:
             for i, idx_label in enumerate(self.index):
                 idx_type = dtype_mapper(
                     self.frame.index.get_level_values(i))
-                column_names_and_types.append((idx_label, idx_type, True))
+                indexed = self._is_column_indexed(idx_label)
+                column_names_and_types.append((idx_label, idx_type, indexed))
 
         column_names_and_types += [
             (text_type(self.frame.columns[i]),
              dtype_mapper(self.frame.iloc[:, i]),
-             False)
+             self._is_column_indexed(text_type(self.frame.columns[i])))
             for i in range(len(self.frame.columns))
             ]
 
@@ -1205,7 +1222,8 @@ def read_query(self, sql, index_col=None, coerce_float=True,
     read_sql = read_query
 
     def to_sql(self, frame, name, if_exists='fail', index=True,
-               index_label=None, schema=None, chunksize=None, dtype=None):
+               index_label=None, schema=None, chunksize=None, dtype=None,
+               indexes=None):
         """
         Write records stored in a DataFrame to a SQL database.
 
@@ -1245,7 +1263,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
 
         table = SQLTable(name, self, frame=frame, index=index,
                          if_exists=if_exists, index_label=index_label,
-                         schema=schema, dtype=dtype)
+                         schema=schema, dtype=dtype, indexes=indexes)
         table.create()
         table.insert(chunksize)
         if (not name.isdigit() and not name.islower()):
@@ -1620,7 +1638,8 @@ def _fetchall_as_list(self, cur):
         return result
 
     def to_sql(self, frame, name, if_exists='fail', index=True,
-               index_label=None, schema=None, chunksize=None, dtype=None):
+               index_label=None, schema=None, chunksize=None, dtype=None,
+               indexes=None):
         """
         Write records stored in a DataFrame to a SQL database.
 
@@ -1657,7 +1676,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
 
         table = SQLiteTable(name, self, frame=frame, index=index,
                             if_exists=if_exists, index_label=index_label,
-                            dtype=dtype)
+                            dtype=dtype, indexes=indexes)
         table.create()
         table.insert(chunksize)
 
diff --git a/pandas/io/tests/test_sql.py b/pandas/io/tests/test_sql.py
index b72258cbf588d..092eaa43036d6 100644
--- a/pandas/io/tests/test_sql.py
+++ b/pandas/io/tests/test_sql.py
@@ -309,6 +309,14 @@ def _load_test3_data(self):
 
         self.test_frame3 = DataFrame(data, columns=columns)
 
+    def _load_test4_data(self):
+        n = 10
+        colors = np.random.choice(['red', 'green'], size=n)
+        foods = np.random.choice(['eggs', 'ham'], size=n)
+        index = pd.MultiIndex.from_arrays([colors, foods],
+                                          names=['color', 'food'])
+        self.test_frame4 = DataFrame(np.random.randn(n, 2), index=index)
+
     def _load_raw_sql(self):
         self.drop_table('types_test_data')
         self._get_exec().execute(SQL_STRINGS['create_test_types'][self.flavor])
@@ -512,6 +520,7 @@ def setUp(self):
         self._load_test1_data()
         self._load_test2_data()
         self._load_test3_data()
+        self._load_test4_data()
         self._load_raw_sql()
 
     def test_read_sql_iris(self):
@@ -933,7 +942,7 @@ def test_warning_case_insensitive_table_name(self):
     def _get_index_columns(self, tbl_name):
         from sqlalchemy.engine import reflection
         insp = reflection.Inspector.from_engine(self.conn)
-        ixs = insp.get_indexes('test_index_saved')
+        ixs = insp.get_indexes(tbl_name)
         ixs = [i['column_names'] for i in ixs]
         return ixs
 
@@ -966,6 +975,66 @@ def test_to_sql_read_sql_with_database_uri(self):
             tm.assert_frame_equal(test_frame1, test_frame3)
             tm.assert_frame_equal(test_frame1, test_frame4)
 
+    def test_to_sql_column_indexes(self):
+        temp_frame = DataFrame({'col1': range(4), 'col2': range(4)})
+        sql.to_sql(temp_frame, 'test_to_sql_column_indexes', self.conn,
+                   index=False, if_exists='replace', indexes=['col1', 'col2'])
+        ix_cols = self._get_index_columns('test_to_sql_column_indexes')
+        self.assertEqual(sorted(ix_cols), [['col1'], ['col2']],
+                         "columns are not correctly indexed")
+
+    def test_sqltable_key_and_multiindex_no_pk(self):
+        db = sql.SQLDatabase(self.conn)
+        table = sql.SQLTable('test_sqltable_key_and_multiindex_no_pk', db,
+                             frame=self.test_frame4, index=True)
+        metadata = table.table.tometadata(table.pd_sql.meta)
+        indexed_columns = [e.columns.keys() for e in metadata.indexes]
+        primary_keys = metadata.primary_key.columns.keys()
+        self.assertListEqual([['color'], ['food']], sorted(indexed_columns),
+                             "Wrong secondary indexes")
+        self.assertListEqual([], primary_keys,
+                             "There should be no primary keys")
+
+    def test_sqltable_key_and_multiindex_one_pk(self):
+        db = sql.SQLDatabase(self.conn)
+        table = sql.SQLTable('test_sqltable_key_and_multiindex_one_pk', db,
+                             frame=self.test_frame4, index=True,
+                             keys=['color'])
+        metadata = table.table.tometadata(table.pd_sql.meta)
+        indexed_columns = [e.columns.keys() for e in metadata.indexes]
+        primary_keys = metadata.primary_key.columns.keys()
+        self.assertListEqual([['food']], indexed_columns,
+                             "Wrong secondary indexes")
+        self.assertListEqual(['color'], primary_keys,
+                             "Wrong primary keys")
+
+    def test_sqltable_key_and_multiindex_two_pk(self):
+        db = sql.SQLDatabase(self.conn)
+        table = sql.SQLTable('test_sqltable_key_and_multiindex_two_pk', db,
+                             frame=self.test_frame4, index=True,
+                             keys=['color', 'food'])
+        metadata = table.table.tometadata(table.pd_sql.meta)
+        indexed_columns = [e.columns.keys() for e in metadata.indexes]
+        primary_keys = metadata.primary_key.columns.keys()
+        self.assertListEqual([], indexed_columns,
+                             "There should be no secondary indexes")
+        self.assertListEqual(['color', 'food'], primary_keys,
+                             "Wrong primary keys")
+
+    def test_sqltable_no_double_key_and_index_index(self):
+        temp_frame = DataFrame({'col1': range(4), 'col2': range(4)})
+        db = sql.SQLDatabase(self.conn)
+        table = sql.SQLTable('test_sqltable_no_double_key_and_index_index', db,
+                             frame=temp_frame, index=True, index_label='id',
+                             keys=['id'], indexes=['col1', 'col2'])
+        table_metadata = table.table.tometadata(table.pd_sql.meta)
+        indexed_columns = [e.columns.keys() for e in table_metadata.indexes]
+        self.assertNotIn('id', indexed_columns,
+                         "Secondary index found for primary key")
+
+        self.assertListEqual(['id'], table_metadata.primary_key.columns.keys(),
+                             "Primary key missing from table")
+
     def _make_iris_table_metadata(self):
         sa = sqlalchemy
         metadata = sa.MetaData()

From 82a0118c8b18a1d745dba38a470cc54a3714fc84 Mon Sep 17 00:00:00 2001
From: trbs
Date: Sun, 11 Sep 2016 23:47:50 -0400
Subject: [PATCH 2/2] add comments and versionadded

---
 pandas/io/sql.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/pandas/io/sql.py b/pandas/io/sql.py
index 85d97c515608b..c868984d381b1 100644
--- a/pandas/io/sql.py
+++ b/pandas/io/sql.py
@@ -516,8 +516,8 @@ def read_sql(sql, con, index_col=None, coerce_float=True, params=None,
 
 
 def to_sql(frame, name, con, flavor='sqlite', schema=None, if_exists='fail',
-           index=True, index_label=None, chunksize=None, dtype=None,
-           indexes=None):
+           index=True, index_label=None, indexes=None, chunksize=None,
+           dtype=None):
     """
     Write records stored in a DataFrame to a SQL database.
 
@@ -548,6 +548,10 @@ def to_sql(frame, name, con, flavor='sqlite', schema=None, if_exists='fail',
         Column label for index column(s). If None is given (default) and
         `index` is True, then the index names are used.
         A sequence should be given if the DataFrame uses MultiIndex.
+    indexes : list of column name(s). Column names in this list will have
+        an index created for them in the database.
+
+        .. versionadded:: 0.18.2
     chunksize : int, default None
         If not None, then rows will be written in batches of this size at a
         time. If None, all rows will be written at once.
@@ -852,9 +856,13 @@ def _index_name(self, index, index_label):
             return None
 
     def _is_column_indexed(self, label):
+        # column is explicitly set to be indexed
        if self.indexes is not None and label in self.indexes:
            return True

+        # if df index is also a column it needs an index unless it's
+        # also a primary key (otherwise there would be two indexes).
+        # multi-index can use primary key if the left hand side matches.
        if self.index is not None and label in self.index:
            if self.keys is None:
                return True