Skip to content

Commit 1731d27

Browse files
authored
Fix and optimize fetching dict rows. (#458)
1 parent 329bae7 commit 1731d27

File tree

3 files changed

+136
-36
lines changed

3 files changed

+136
-36
lines changed

.github/workflows/windows.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ name: Build windows wheels
22

33
on:
44
push:
5+
branches:
6+
- master
57
workflow_dispatch:
68

79
jobs:

MySQLdb/_mysql.c

Lines changed: 95 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,7 +1194,8 @@ _mysql_field_to_python(
11941194
static PyObject *
11951195
_mysql_row_to_tuple(
11961196
_mysql_ResultObject *self,
1197-
MYSQL_ROW row)
1197+
MYSQL_ROW row,
1198+
PyObject *unused)
11981199
{
11991200
unsigned int n, i;
12001201
unsigned long *length;
@@ -1221,7 +1222,8 @@ _mysql_row_to_tuple(
12211222
static PyObject *
12221223
_mysql_row_to_dict(
12231224
_mysql_ResultObject *self,
1224-
MYSQL_ROW row)
1225+
MYSQL_ROW row,
1226+
PyObject *cache)
12251227
{
12261228
unsigned int n, i;
12271229
unsigned long *length;
@@ -1243,40 +1245,42 @@ _mysql_row_to_dict(
12431245
Py_DECREF(v);
12441246
goto error;
12451247
}
1246-
1247-
PyObject *tmp = PyDict_SetDefault(r, pyname, v);
1248-
Py_DECREF(pyname);
1249-
if (!tmp) {
1248+
int err = PyDict_Contains(r, pyname);
1249+
if (err < 0) { // error
12501250
Py_DECREF(v);
12511251
goto error;
12521252
}
1253-
if (tmp == v) {
1254-
Py_DECREF(v);
1255-
continue;
1253+
if (err) { // duplicate
1254+
Py_DECREF(pyname);
1255+
pyname = PyUnicode_FromFormat("%s.%s", fields[i].table, fields[i].name);
1256+
if (pyname == NULL) {
1257+
Py_DECREF(v);
1258+
goto error;
1259+
}
12561260
}
12571261

1258-
pyname = PyUnicode_FromFormat("%s.%s", fields[i].table, fields[i].name);
1259-
if (!pyname) {
1260-
Py_DECREF(v);
1261-
goto error;
1262+
err = PyDict_SetItem(r, pyname, v);
1263+
if (cache) {
1264+
PyTuple_SET_ITEM(cache, i, pyname);
1265+
} else {
1266+
Py_DECREF(pyname);
12621267
}
1263-
int err = PyDict_SetItem(r, pyname, v);
1264-
Py_DECREF(pyname);
12651268
Py_DECREF(v);
12661269
if (err) {
12671270
goto error;
12681271
}
12691272
}
12701273
return r;
1271-
error:
1272-
Py_XDECREF(r);
1274+
error:
1275+
Py_DECREF(r);
12731276
return NULL;
12741277
}
12751278

12761279
static PyObject *
12771280
_mysql_row_to_dict_old(
12781281
_mysql_ResultObject *self,
1279-
MYSQL_ROW row)
1282+
MYSQL_ROW row,
1283+
PyObject *cache)
12801284
{
12811285
unsigned int n, i;
12821286
unsigned long *length;
@@ -1302,8 +1306,12 @@ _mysql_row_to_dict_old(
13021306
pyname = PyUnicode_FromString(fields[i].name);
13031307
}
13041308
int err = PyDict_SetItem(r, pyname, v);
1305-
Py_DECREF(pyname);
13061309
Py_DECREF(v);
1310+
if (cache) {
1311+
PyTuple_SET_ITEM(cache, i, pyname);
1312+
} else {
1313+
Py_DECREF(pyname);
1314+
}
13071315
if (err) {
13081316
goto error;
13091317
}
@@ -1314,15 +1322,66 @@ _mysql_row_to_dict_old(
13141322
return NULL;
13151323
}
13161324

1317-
typedef PyObject *_PYFUNC(_mysql_ResultObject *, MYSQL_ROW);
1325+
static PyObject *
1326+
_mysql_row_to_dict_cached(
1327+
_mysql_ResultObject *self,
1328+
MYSQL_ROW row,
1329+
PyObject *cache)
1330+
{
1331+
PyObject *r = PyDict_New();
1332+
if (!r) {
1333+
return NULL;
1334+
}
1335+
1336+
unsigned int n = mysql_num_fields(self->result);
1337+
unsigned long *length = mysql_fetch_lengths(self->result);
1338+
MYSQL_FIELD *fields = mysql_fetch_fields(self->result);
1339+
1340+
for (unsigned int i=0; i<n; i++) {
1341+
PyObject *c = PyTuple_GET_ITEM(self->converter, i);
1342+
PyObject *v = _mysql_field_to_python(c, row[i], length[i], &fields[i], self->encoding);
1343+
if (!v) {
1344+
goto error;
1345+
}
1346+
1347+
PyObject *pyname = PyTuple_GET_ITEM(cache, i); // borrowed
1348+
int err = PyDict_SetItem(r, pyname, v);
1349+
Py_DECREF(v);
1350+
if (err) {
1351+
goto error;
1352+
}
1353+
}
1354+
return r;
1355+
error:
1356+
Py_XDECREF(r);
1357+
return NULL;
1358+
}
1359+
1360+
1361+
typedef PyObject *_convertfunc(_mysql_ResultObject *, MYSQL_ROW, PyObject *);
1362+
static _convertfunc * const row_converters[] = {
1363+
_mysql_row_to_tuple,
1364+
_mysql_row_to_dict,
1365+
_mysql_row_to_dict_old
1366+
};
13181367

13191368
Py_ssize_t
13201369
_mysql__fetch_row(
13211370
_mysql_ResultObject *self,
13221371
PyObject *r, /* list object */
13231372
Py_ssize_t maxrows,
1324-
_PYFUNC *convert_row)
1373+
int how)
13251374
{
1375+
_convertfunc *convert_row = row_converters[how];
1376+
1377+
PyObject *cache = NULL;
1378+
if (maxrows > 0 && how > 0) {
1379+
cache = PyTuple_New(mysql_num_fields(self->result));
1380+
if (!cache) {
1381+
return -1;
1382+
}
1383+
}
1384+
13261385
Py_ssize_t i;
13271386
for (i = 0; i < maxrows; i++) {
13281387
MYSQL_ROW row;
@@ -1335,20 +1394,29 @@ _mysql__fetch_row(
13351394
}
13361395
if (!row && mysql_errno(&(((_mysql_ConnectionObject *)(self->conn))->connection))) {
13371396
_mysql_Exception((_mysql_ConnectionObject *)self->conn);
1338-
return -1;
1397+
goto error;
13391398
}
13401399
if (!row) {
13411400
break;
13421401
}
1343-
PyObject *v = convert_row(self, row);
1344-
if (!v) return -1;
1402+
PyObject *v = convert_row(self, row, cache);
1403+
if (!v) {
1404+
goto error;
1405+
}
1406+
if (cache) {
1407+
convert_row = _mysql_row_to_dict_cached;
1408+
}
13451409
if (PyList_Append(r, v)) {
13461410
Py_DECREF(v);
1347-
return -1;
1411+
goto error;
13481412
}
13491413
Py_DECREF(v);
13501414
}
1415+
Py_XDECREF(cache);
13511416
return i;
1417+
error:
1418+
Py_XDECREF(cache);
1419+
return -1;
13521420
}
13531421

13541422
static char _mysql_ResultObject_fetch_row__doc__[] =
@@ -1366,15 +1434,7 @@ _mysql_ResultObject_fetch_row(
13661434
PyObject *args,
13671435
PyObject *kwargs)
13681436
{
1369-
typedef PyObject *_PYFUNC(_mysql_ResultObject *, MYSQL_ROW);
1370-
static char *kwlist[] = { "maxrows", "how", NULL };
1371-
static _PYFUNC *row_converters[] =
1372-
{
1373-
_mysql_row_to_tuple,
1374-
_mysql_row_to_dict,
1375-
_mysql_row_to_dict_old
1376-
};
1377-
_PYFUNC *convert_row;
1437+
static char *kwlist[] = {"maxrows", "how", NULL };
13781438
int maxrows=1, how=0;
13791439
PyObject *r=NULL;
13801440

@@ -1386,7 +1446,6 @@ _mysql_ResultObject_fetch_row(
13861446
PyErr_SetString(PyExc_ValueError, "how out of range");
13871447
return NULL;
13881448
}
1389-
convert_row = row_converters[how];
13901449
if (!maxrows) {
13911450
if (self->use) {
13921451
maxrows = INT_MAX;
@@ -1396,7 +1455,7 @@ _mysql_ResultObject_fetch_row(
13961455
}
13971456
}
13981457
if (!(r = PyList_New(0))) goto error;
1399-
Py_ssize_t rowsadded = _mysql__fetch_row(self, r, maxrows, convert_row);
1458+
Py_ssize_t rowsadded = _mysql__fetch_row(self, r, maxrows, how);
14001459
if (rowsadded == -1) goto error;
14011460

14021461
/* DB-API allows return rows as list.

tests/test_cursor.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,42 @@ def test_pyparam():
111111
assert cursor._executed == b"SELECT 1, 2"
112112
cursor.execute(b"SELECT %(a)s, %(b)s", {b"a": 3, b"b": 4})
113113
assert cursor._executed == b"SELECT 3, 4"
114+
115+
116+
def test_dictcursor():
117+
conn = connect()
118+
cursor = conn.cursor(MySQLdb.cursors.DictCursor)
119+
120+
cursor.execute("CREATE TABLE t1 (a int, b int, c int)")
121+
_tables.append("t1")
122+
cursor.execute("INSERT INTO t1 (a,b,c) VALUES (1,1,47), (2,2,47)")
123+
124+
cursor.execute("CREATE TABLE t2 (b int, c int)")
125+
_tables.append("t2")
126+
cursor.execute("INSERT INTO t2 (b,c) VALUES (1,1), (2,2)")
127+
128+
cursor.execute("SELECT * FROM t1 JOIN t2 ON t1.b=t2.b")
129+
rows = cursor.fetchall()
130+
131+
assert len(rows) == 2
132+
assert rows[0] == {"a": 1, "b": 1, "c": 47, "t2.b": 1, "t2.c": 1}
133+
assert rows[1] == {"a": 2, "b": 2, "c": 47, "t2.b": 2, "t2.c": 2}
134+
135+
names1 = sorted(rows[0])
136+
names2 = sorted(rows[1])
137+
for a, b in zip(names1, names2):
138+
assert a is b
139+
140+
# Old fetchtype
141+
cursor._fetch_type = 2
142+
cursor.execute("SELECT * FROM t1 JOIN t2 ON t1.b=t2.b")
143+
rows = cursor.fetchall()
144+
145+
assert len(rows) == 2
146+
assert rows[0] == {"t1.a": 1, "t1.b": 1, "t1.c": 47, "t2.b": 1, "t2.c": 1}
147+
assert rows[1] == {"t1.a": 2, "t1.b": 2, "t1.c": 47, "t2.b": 2, "t2.c": 2}
148+
149+
names1 = sorted(rows[0])
150+
names2 = sorted(rows[1])
151+
for a, b in zip(names1, names2):
152+
assert a is b

0 commit comments

Comments
 (0)