Skip to content

Commit 337b94d

Browse files
author
Artemy Kolchinsky
committed
BUG: Dynamically created table names allow SQL injection
Cleanup doc Check for empty identifiers Tests fix Tests pass Doc update Error catching
1 parent 3030bba commit 337b94d

File tree

3 files changed

+105
-31
lines changed

3 files changed

+105
-31
lines changed

doc/source/whatsnew/v0.16.0.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ Enhancements
106106
- ``tseries.frequencies.to_offset()`` now accepts ``Timedelta`` as input (:issue:`9064`)
107107

108108
- ``Timedelta`` will now accept nanoseconds keyword in constructor (:issue:`9273`)
109+
- SQL code now safely escapes table and column names (:issue:`8986`)
109110

110111
Performance
111112
~~~~~~~~~~~

pandas/io/sql.py

Lines changed: 71 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,18 +1239,58 @@ def _create_sql_schema(self, frame, table_name, keys=None):
12391239
}
12401240

12411241

1242+
def _get_unicode_name(name):
1243+
try:
1244+
uname = name.encode("utf-8", "strict").decode("utf-8")
1245+
except UnicodeError:
1246+
raise ValueError("Cannot convert identifier to UTF-8: '%s'" % name)
1247+
return uname
1248+
1249+
def _get_valid_mysql_name(name):
    """Validate ``name`` as a MySQL identifier and return it backtick-quoted.

    Only characters permitted in *unquoted* MySQL identifiers are accepted,
    so the returned string never contains a backtick or any other character
    that could terminate the quoting.

    Parameters
    ----------
    name : str
        Table or column name.

    Returns
    -------
    The identifier wrapped in backticks.

    Raises
    ------
    ValueError
        If the name is empty, cannot be converted to UTF-8, contains a
        character outside the permitted set, or consists solely of digits.
    """
    # Filter for unquoted identifiers
    # See http://dev.mysql.com/doc/refman/5.0/en/identifiers.html
    uname = _get_unicode_name(name)
    if not len(uname):
        raise ValueError("Empty table or column name specified")

    # Basic set: ASCII letters, digits, dollar and underscore.  The MySQL
    # manual writes this as "[0-9,a-z,A-Z$_]"; the commas there are list
    # separators, not permitted characters, so they are not in the class.
    basere = re.compile(r'[0-9a-zA-Z$_]')
    for c in uname:
        if not basere.match(c):
            # Outside the basic set only the extended range
            # U+0080 .. U+FFFF (both ends inclusive) is permitted.
            if not (0x80 <= ord(c) <= 0xFFFF):
                raise ValueError("Invalid MySQL identifier '%s'" % uname)
    # Identifiers may begin with a digit but may not consist solely of
    # digits, so the whole string has to be inspected, not just its start.
    if re.match(r'[0-9]+\Z', uname):
        raise ValueError('MySQL identifier cannot be entirely numeric')

    return '`' + uname + '`'
1265+
1266+
1267+
def _get_valid_sqlite_name(name):
    """Validate ``name`` as an SQLite identifier and return it double-quoted.

    Parameters
    ----------
    name : str
        Table or column name.

    Returns
    -------
    The identifier wrapped in double quotes, with every embedded double
    quote doubled.

    Raises
    ------
    ValueError
        If the name is empty, cannot be converted to UTF-8, or contains a
        NUL character.
    """
    # Quoting recipe, see
    # http://stackoverflow.com/questions/6514274/how-do-you-escape-strings-for-sqlite-table-column-names-in-python
    #   1. the string must be encodable as UTF-8
    #   2. the string must not include any NUL characters
    #   3. every " is replaced with ""
    #   4. the result is wrapped in double quotes
    uname = _get_unicode_name(name)
    if len(uname) == 0:
        raise ValueError("Empty table or column name specified")

    if "\x00" in uname:
        raise ValueError('SQLite identifier cannot contain NULs')

    escaped = uname.replace('"', '""')
    return '"%s"' % escaped
1282+
1283+
12421284
# SQL enquote and wildcard symbols
1243-
_SQL_SYMB = {
1244-
'mysql': {
1245-
'br_l': '`',
1246-
'br_r': '`',
1247-
'wld': '%s'
1248-
},
1249-
'sqlite': {
1250-
'br_l': '[',
1251-
'br_r': ']',
1252-
'wld': '?'
1253-
}
1285+
_SQL_WILDCARD = {
1286+
'mysql': '%s',
1287+
'sqlite': '?'
1288+
}
1289+
1290+
# Validate and return escaped identifier
1291+
_SQL_GET_IDENTIFIER = {
1292+
'mysql': _get_valid_mysql_name,
1293+
'sqlite': _get_valid_sqlite_name,
12541294
}
12551295

12561296

@@ -1276,18 +1316,17 @@ def _execute_create(self):
12761316
def insert_statement(self):
12771317
names = list(map(str, self.frame.columns))
12781318
flv = self.pd_sql.flavor
1279-
br_l = _SQL_SYMB[flv]['br_l'] # left val quote char
1280-
br_r = _SQL_SYMB[flv]['br_r'] # right val quote char
1281-
wld = _SQL_SYMB[flv]['wld'] # wildcard char
1319+
wld = _SQL_WILDCARD[flv] # wildcard char
1320+
escape = _SQL_GET_IDENTIFIER[flv]
12821321

12831322
if self.index is not None:
12841323
[names.insert(0, idx) for idx in self.index[::-1]]
12851324

1286-
bracketed_names = [br_l + column + br_r for column in names]
1325+
bracketed_names = [escape(column) for column in names]
12871326
col_names = ','.join(bracketed_names)
12881327
wildcards = ','.join([wld] * len(names))
12891328
insert_statement = 'INSERT INTO %s (%s) VALUES (%s)' % (
1290-
self.name, col_names, wildcards)
1329+
escape(self.name), col_names, wildcards)
12911330
return insert_statement
12921331

12931332
def _execute_insert(self, conn, keys, data_iter):
@@ -1309,29 +1348,28 @@ def _create_table_setup(self):
13091348
warnings.warn(_SAFE_NAMES_WARNING)
13101349

13111350
flv = self.pd_sql.flavor
1351+
escape = _SQL_GET_IDENTIFIER[flv]
13121352

1313-
br_l = _SQL_SYMB[flv]['br_l'] # left val quote char
1314-
br_r = _SQL_SYMB[flv]['br_r'] # right val quote char
1353+
create_tbl_stmts = [escape(cname) + ' ' + ctype
1354+
for cname, ctype, _ in column_names_and_types]
13151355

1316-
create_tbl_stmts = [(br_l + '%s' + br_r + ' %s') % (cname, col_type)
1317-
for cname, col_type, _ in column_names_and_types]
13181356
if self.keys is not None and len(self.keys):
1319-
cnames_br = ",".join([br_l + c + br_r for c in self.keys])
1357+
cnames_br = ",".join([escape(c) for c in self.keys])
13201358
create_tbl_stmts.append(
13211359
"CONSTRAINT {tbl}_pk PRIMARY KEY ({cnames_br})".format(
13221360
tbl=self.name, cnames_br=cnames_br))
13231361

1324-
create_stmts = ["CREATE TABLE " + self.name + " (\n" +
1362+
create_stmts = ["CREATE TABLE " + escape(self.name) + " (\n" +
13251363
',\n '.join(create_tbl_stmts) + "\n)"]
13261364

13271365
ix_cols = [cname for cname, _, is_index in column_names_and_types
13281366
if is_index]
13291367
if len(ix_cols):
13301368
cnames = "_".join(ix_cols)
1331-
cnames_br = ",".join([br_l + c + br_r for c in ix_cols])
1369+
cnames_br = ",".join([escape(c) for c in ix_cols])
13321370
create_stmts.append(
1333-
"CREATE INDEX ix_{tbl}_{cnames} ON {tbl} ({cnames_br})".format(
1334-
tbl=self.name, cnames=cnames, cnames_br=cnames_br))
1371+
"CREATE INDEX " + escape("ix_"+self.name+"_"+cnames) +
1372+
"ON " + escape(self.name) + " (" + cnames_br + ")")
13351373

13361374
return create_stmts
13371375

@@ -1505,19 +1543,23 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
15051543
table.insert(chunksize)
15061544

15071545
def has_table(self, name, schema=None):
1546+
escape = _SQL_GET_IDENTIFIER[self.flavor]
1547+
esc_name = escape(name)
1548+
wld = _SQL_WILDCARD[self.flavor]
15081549
flavor_map = {
15091550
'sqlite': ("SELECT name FROM sqlite_master "
1510-
"WHERE type='table' AND name='%s';") % name,
1511-
'mysql': "SHOW TABLES LIKE '%s'" % name}
1551+
"WHERE type='table' AND name=%s;") % wld,
1552+
'mysql': "SHOW TABLES LIKE %s" % wld}
15121553
query = flavor_map.get(self.flavor)
15131554

1514-
return len(self.execute(query).fetchall()) > 0
1555+
return len(self.execute(query, [name,]).fetchall()) > 0
15151556

15161557
def get_table(self, table_name, schema=None):
15171558
return None # not supported in fallback mode
15181559

15191560
def drop_table(self, name, schema=None):
1520-
drop_sql = "DROP TABLE %s" % name
1561+
escape = _SQL_GET_IDENTIFIER[self.flavor]
1562+
drop_sql = "DROP TABLE %s" % escape(name)
15211563
self.execute(drop_sql)
15221564

15231565
def _create_sql_schema(self, frame, table_name, keys=None):

pandas/io/tests/test_sql.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -865,7 +865,7 @@ def test_uquery(self):
865865
def _get_sqlite_column_type(self, schema, column):
866866

867867
for col in schema.split('\n'):
868-
if col.split()[0].strip('[]') == column:
868+
if col.split()[0].strip('""') == column:
869869
return col.split()[1]
870870
raise ValueError('Column %s not found' % (column))
871871

@@ -1630,6 +1630,24 @@ def test_notnull_dtype(self):
16301630
self.assertEqual(self._get_sqlite_column_type(tbl, 'Int'), 'INTEGER')
16311631
self.assertEqual(self._get_sqlite_column_type(tbl, 'Float'), 'REAL')
16321632

1633+
def test_illegal_names(self):
    """Table and column names with quote/bracket characters work on SQLite.

    An empty table name must raise ``ValueError``; every other name in the
    list must be written successfully and be reported by
    ``sql.table_exists`` afterwards.
    """
    # For sqlite, these should work fine
    df = DataFrame([[1, 2], [3, 4]], columns=['a', 'b'])

    # Raise error on blank
    self.assertRaises(ValueError, df.to_sql, "", self.conn,
                      flavor=self.flavor)

    weird_names = ['test_weird_name]', 'test_weird_name[',
                   'test_weird_name`', 'test_weird_name"',
                   'test_weird_name\'']
    for ndx, weird_name in enumerate(weird_names):
        # Weird *table* name.
        df.to_sql(weird_name, self.conn, flavor=self.flavor)
        self.assertTrue(sql.table_exists(weird_name, self.conn))

        # Weird *column* name: write df2 (not df), otherwise the column
        # escaping path is never exercised.
        df2 = DataFrame([[1, 2], [3, 4]], columns=['a', weird_name])
        c_tbl = 'test_weird_col_name%d' % ndx
        df2.to_sql(c_tbl, self.conn, flavor=self.flavor)
        self.assertTrue(sql.table_exists(c_tbl, self.conn))
1650+
16331651

16341652
class TestMySQLLegacy(TestSQLiteFallback):
16351653
"""
@@ -1721,6 +1739,19 @@ def test_to_sql_save_index(self):
17211739
def test_to_sql_save_index(self):
17221740
self._to_sql_save_index()
17231741

1742+
def test_illegal_names(self):
    """Illegal table and column names must raise ``ValueError`` on MySQL."""
    # For MySQL, these should raise ValueError
    df = DataFrame([[1, 2], [3, 4]], columns=['a', 'b'])
    illegal_names = ['test_illegal_name]', 'test_illegal_name[',
                     'test_illegal_name`', 'test_illegal_name"',
                     'test_illegal_name\'', '']
    for ndx, illegal_name in enumerate(illegal_names):
        # Illegal *table* name.
        self.assertRaises(ValueError, df.to_sql, illegal_name, self.conn,
                          flavor=self.flavor, index=False)

        # Illegal *column* name.  Use a distinct table per iteration so a
        # "table already exists" ValueError left over from an earlier
        # iteration cannot satisfy the assertion for the wrong reason.
        df2 = DataFrame([[1, 2], [3, 4]], columns=['a', illegal_name])
        c_tbl = 'test_illegal_col_name%d' % ndx
        self.assertRaises(ValueError, df2.to_sql, c_tbl,
                          self.conn, flavor=self.flavor, index=False)
1754+
17241755

17251756
#------------------------------------------------------------------------------
17261757
#--- Old tests from 0.13.1 (before refactor using sqlalchemy)
@@ -1817,7 +1848,7 @@ def test_schema(self):
18171848
frame = tm.makeTimeDataFrame()
18181849
create_sql = sql.get_schema(frame, 'test', 'sqlite', keys=['A', 'B'],)
18191850
lines = create_sql.splitlines()
1820-
self.assertTrue('PRIMARY KEY ([A],[B])' in create_sql)
1851+
self.assertTrue('PRIMARY KEY ("A","B")' in create_sql)
18211852
cur = self.db.cursor()
18221853
cur.execute(create_sql)
18231854

0 commit comments

Comments
 (0)