Skip to content

Commit 9d6215a

Browse files
author
Erlend Egeberg Aasland
authored
bpo-45126: Harden sqlite3 connection initialisation (GH-28227)
1 parent 6a84d61 commit 9d6215a

File tree

3 files changed

+110
-62
lines changed

3 files changed

+110
-62
lines changed

Lib/test/test_sqlite3/test_dbapi.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,44 @@ def test_connection_init_good_isolation_levels(self):
523523
with memory_database(isolation_level=level) as cx:
524524
cx.execute("select 'ok'")
525525

526+
def test_connection_reinit(self):
527+
db = ":memory:"
528+
cx = sqlite.connect(db)
529+
cx.text_factory = bytes
530+
cx.row_factory = sqlite.Row
531+
cu = cx.cursor()
532+
cu.execute("create table foo (bar)")
533+
cu.executemany("insert into foo (bar) values (?)",
534+
((str(v),) for v in range(4)))
535+
cu.execute("select bar from foo")
536+
537+
rows = [r for r in cu.fetchmany(2)]
538+
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
539+
self.assertEqual([r[0] for r in rows], [b"0", b"1"])
540+
541+
cx.__init__(db)
542+
cx.execute("create table foo (bar)")
543+
cx.executemany("insert into foo (bar) values (?)",
544+
((v,) for v in ("a", "b", "c", "d")))
545+
546+
# This uses the old database, old row factory, but new text factory
547+
rows = [r for r in cu.fetchall()]
548+
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
549+
self.assertEqual([r[0] for r in rows], ["2", "3"])
550+
551+
def test_connection_bad_reinit(self):
552+
cx = sqlite.connect(":memory:")
553+
with cx:
554+
cx.execute("create table t(t)")
555+
with temp_dir() as db:
556+
self.assertRaisesRegex(sqlite.OperationalError,
557+
"unable to open database file",
558+
cx.__init__, db)
559+
self.assertRaisesRegex(sqlite.ProgrammingError,
560+
"Base Connection.__init__ not called",
561+
cx.executemany, "insert into t values(?)",
562+
((v,) for v in range(3)))
563+
526564

527565
class UninitialisedConnectionTests(unittest.TestCase):
528566
def setUp(self):

Modules/_sqlite/clinic/connection.c.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
77
const char *database, double timeout,
88
int detect_types, const char *isolation_level,
99
int check_same_thread, PyObject *factory,
10-
int cached_statements, int uri);
10+
int cache_size, int uri);
1111

1212
static int
1313
pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
@@ -25,7 +25,7 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
2525
const char *isolation_level = "";
2626
int check_same_thread = 1;
2727
PyObject *factory = (PyObject*)clinic_state()->ConnectionType;
28-
int cached_statements = 128;
28+
int cache_size = 128;
2929
int uri = 0;
3030

3131
fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, 1, 8, 0, argsbuf);
@@ -101,8 +101,8 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
101101
}
102102
}
103103
if (fastargs[6]) {
104-
cached_statements = _PyLong_AsInt(fastargs[6]);
105-
if (cached_statements == -1 && PyErr_Occurred()) {
104+
cache_size = _PyLong_AsInt(fastargs[6]);
105+
if (cache_size == -1 && PyErr_Occurred()) {
106106
goto exit;
107107
}
108108
if (!--noptargs) {
@@ -114,7 +114,7 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
114114
goto exit;
115115
}
116116
skip_optional_pos:
117-
return_value = pysqlite_connection_init_impl((pysqlite_Connection *)self, database, timeout, detect_types, isolation_level, check_same_thread, factory, cached_statements, uri);
117+
return_value = pysqlite_connection_init_impl((pysqlite_Connection *)self, database, timeout, detect_types, isolation_level, check_same_thread, factory, cache_size, uri);
118118

119119
exit:
120120
/* Cleanup for database */
@@ -851,4 +851,4 @@ getlimit(pysqlite_Connection *self, PyObject *arg)
851851
#ifndef PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF
852852
#define PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF
853853
#endif /* !defined(PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF) */
854-
/*[clinic end generated code: output=663b1e9e71128f19 input=a9049054013a1b77]*/
854+
/*[clinic end generated code: output=6f267f20e77f92d0 input=a9049054013a1b77]*/

Modules/_sqlite/connection.c

Lines changed: 66 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,17 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self);
8383
static void free_callback_context(callback_context *ctx);
8484
static void set_callback_context(callback_context **ctx_pp,
8585
callback_context *ctx);
86+
static void connection_close(pysqlite_Connection *self);
8687

8788
static PyObject *
88-
new_statement_cache(pysqlite_Connection *self, int maxsize)
89+
new_statement_cache(pysqlite_Connection *self, pysqlite_state *state,
90+
int maxsize)
8991
{
9092
PyObject *args[] = { NULL, PyLong_FromLong(maxsize), };
9193
if (args[1] == NULL) {
9294
return NULL;
9395
}
94-
PyObject *lru_cache = self->state->lru_cache;
96+
PyObject *lru_cache = state->lru_cache;
9597
size_t nargsf = 1 | PY_VECTORCALL_ARGUMENTS_OFFSET;
9698
PyObject *inner = PyObject_Vectorcall(lru_cache, args + 1, nargsf, NULL);
9799
Py_DECREF(args[1]);
@@ -153,7 +155,7 @@ _sqlite3.Connection.__init__ as pysqlite_connection_init
153155
isolation_level: str(accept={str, NoneType}) = ""
154156
check_same_thread: bool(accept={int}) = True
155157
factory: object(c_default='(PyObject*)clinic_state()->ConnectionType') = ConnectionType
156-
cached_statements: int = 128
158+
cached_statements as cache_size: int = 128
157159
uri: bool = False
158160
[clinic start generated code]*/
159161

@@ -162,78 +164,82 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
162164
const char *database, double timeout,
163165
int detect_types, const char *isolation_level,
164166
int check_same_thread, PyObject *factory,
165-
int cached_statements, int uri)
166-
/*[clinic end generated code: output=d8c37afc46d318b0 input=adfb29ac461f9e61]*/
167+
int cache_size, int uri)
168+
/*[clinic end generated code: output=7d640ae1d83abfd4 input=35e316f66d9f70fd]*/
167169
{
168-
int rc;
169-
170170
if (PySys_Audit("sqlite3.connect", "s", database) < 0) {
171171
return -1;
172172
}
173173

174-
pysqlite_state *state = pysqlite_get_state_by_type(Py_TYPE(self));
175-
self->state = state;
176-
177-
Py_CLEAR(self->statement_cache);
178-
Py_CLEAR(self->cursors);
179-
180-
Py_INCREF(Py_None);
181-
Py_XSETREF(self->row_factory, Py_None);
182-
183-
Py_INCREF(&PyUnicode_Type);
184-
Py_XSETREF(self->text_factory, (PyObject*)&PyUnicode_Type);
174+
if (self->initialized) {
175+
PyTypeObject *tp = Py_TYPE(self);
176+
tp->tp_clear((PyObject *)self);
177+
connection_close(self);
178+
self->initialized = 0;
179+
}
185180

181+
// Create and configure SQLite database object.
182+
sqlite3 *db;
183+
int rc;
186184
Py_BEGIN_ALLOW_THREADS
187-
rc = sqlite3_open_v2(database, &self->db,
185+
rc = sqlite3_open_v2(database, &db,
188186
SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE |
189187
(uri ? SQLITE_OPEN_URI : 0), NULL);
188+
if (rc == SQLITE_OK) {
189+
(void)sqlite3_busy_timeout(db, (int)(timeout*1000));
190+
}
190191
Py_END_ALLOW_THREADS
191192

192-
if (self->db == NULL && rc == SQLITE_NOMEM) {
193+
if (db == NULL && rc == SQLITE_NOMEM) {
193194
PyErr_NoMemory();
194195
return -1;
195196
}
197+
198+
pysqlite_state *state = pysqlite_get_state_by_type(Py_TYPE(self));
196199
if (rc != SQLITE_OK) {
197-
_pysqlite_seterror(state, self->db);
200+
_pysqlite_seterror(state, db);
198201
return -1;
199202
}
200203

201-
if (isolation_level) {
202-
const char *stmt = get_begin_statement(isolation_level);
203-
if (stmt == NULL) {
204+
// Convert isolation level to begin statement.
205+
const char *begin_statement = NULL;
206+
if (isolation_level != NULL) {
207+
begin_statement = get_begin_statement(isolation_level);
208+
if (begin_statement == NULL) {
204209
return -1;
205210
}
206-
self->begin_statement = stmt;
207-
}
208-
else {
209-
self->begin_statement = NULL;
210211
}
211212

212-
self->statement_cache = new_statement_cache(self, cached_statements);
213-
if (self->statement_cache == NULL) {
214-
return -1;
215-
}
216-
if (PyErr_Occurred()) {
213+
// Create LRU statement cache; returns a new reference.
214+
PyObject *statement_cache = new_statement_cache(self, state, cache_size);
215+
if (statement_cache == NULL) {
217216
return -1;
218217
}
219218

220-
self->created_cursors = 0;
221-
222-
/* Create list of weak references to cursors */
223-
self->cursors = PyList_New(0);
224-
if (self->cursors == NULL) {
219+
// Create list of weak references to cursors.
220+
PyObject *cursors = PyList_New(0);
221+
if (cursors == NULL) {
222+
Py_DECREF(statement_cache);
225223
return -1;
226224
}
227225

226+
// Init connection state members.
227+
self->db = db;
228+
self->state = state;
228229
self->detect_types = detect_types;
229-
(void)sqlite3_busy_timeout(self->db, (int)(timeout*1000));
230-
self->thread_ident = PyThread_get_thread_ident();
230+
self->begin_statement = begin_statement;
231231
self->check_same_thread = check_same_thread;
232+
self->thread_ident = PyThread_get_thread_ident();
233+
self->statement_cache = statement_cache;
234+
self->cursors = cursors;
235+
self->created_cursors = 0;
236+
self->row_factory = Py_NewRef(Py_None);
237+
self->text_factory = Py_NewRef(&PyUnicode_Type);
238+
self->trace_ctx = NULL;
239+
self->progress_ctx = NULL;
240+
self->authorizer_ctx = NULL;
232241

233-
set_callback_context(&self->trace_ctx, NULL);
234-
set_callback_context(&self->progress_ctx, NULL);
235-
set_callback_context(&self->authorizer_ctx, NULL);
236-
242+
// Borrowed refs
237243
self->Warning = state->Warning;
238244
self->Error = state->Error;
239245
self->InterfaceError = state->InterfaceError;
@@ -250,7 +256,6 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
250256
}
251257

252258
self->initialized = 1;
253-
254259
return 0;
255260
}
256261

@@ -321,16 +326,6 @@ connection_clear(pysqlite_Connection *self)
321326
return 0;
322327
}
323328

324-
static void
325-
connection_close(pysqlite_Connection *self)
326-
{
327-
if (self->db) {
328-
int rc = sqlite3_close_v2(self->db);
329-
assert(rc == SQLITE_OK), (void)rc;
330-
self->db = NULL;
331-
}
332-
}
333-
334329
static void
335330
free_callback_contexts(pysqlite_Connection *self)
336331
{
@@ -339,6 +334,22 @@ free_callback_contexts(pysqlite_Connection *self)
339334
set_callback_context(&self->authorizer_ctx, NULL);
340335
}
341336

337+
static void
338+
connection_close(pysqlite_Connection *self)
339+
{
340+
if (self->db) {
341+
free_callback_contexts(self);
342+
343+
sqlite3 *db = self->db;
344+
self->db = NULL;
345+
346+
Py_BEGIN_ALLOW_THREADS
347+
int rc = sqlite3_close_v2(db);
348+
assert(rc == SQLITE_OK), (void)rc;
349+
Py_END_ALLOW_THREADS
350+
}
351+
}
352+
342353
static void
343354
connection_dealloc(pysqlite_Connection *self)
344355
{
@@ -348,7 +359,6 @@ connection_dealloc(pysqlite_Connection *self)
348359

349360
/* Clean up if user has not called .close() explicitly. */
350361
connection_close(self);
351-
free_callback_contexts(self);
352362

353363
tp->tp_free(self);
354364
Py_DECREF(tp);

0 commit comments

Comments
 (0)