Skip to content

Commit 0bc06e5

Browse files
committed
fix(sqlite): run sqlite3_reset() in StatementWorker
this avoids possible race conditions without using a mutex
1 parent 52868f3 commit 0bc06e5

File tree

5 files changed

+133
-97
lines changed

5 files changed

+133
-97
lines changed

sqlx-core/src/sqlite/connection/executor.rs

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::error::Error;
44
use crate::executor::{Execute, Executor};
55
use crate::logger::QueryLogger;
66
use crate::sqlite::connection::describe::describe;
7-
use crate::sqlite::statement::{StatementHandle, VirtualStatement};
7+
use crate::sqlite::statement::{StatementHandle, StatementWorker, VirtualStatement};
88
use crate::sqlite::{
99
Sqlite, SqliteArguments, SqliteConnection, SqliteQueryResult, SqliteRow, SqliteStatement,
1010
SqliteTypeInfo,
@@ -16,7 +16,8 @@ use libsqlite3_sys::sqlite3_last_insert_rowid;
1616
use std::borrow::Cow;
1717
use std::sync::Arc;
1818

19-
fn prepare<'a>(
19+
async fn prepare<'a>(
20+
worker: &mut StatementWorker,
2021
statements: &'a mut StatementCache<VirtualStatement>,
2122
statement: &'a mut Option<VirtualStatement>,
2223
query: &str,
@@ -39,7 +40,7 @@ fn prepare<'a>(
3940
if exists {
4041
// as this statement has been executed before, we reset before continuing
4142
// this also causes any rows that are from the statement to be inflated
42-
statement.reset();
43+
statement.reset(worker).await?;
4344
}
4445

4546
Ok(statement)
@@ -61,21 +62,25 @@ fn bind(
6162

6263
/// A structure holding sqlite statement handle and resetting the
6364
/// statement when it is dropped.
64-
struct StatementResetter {
65+
struct StatementResetter<'a> {
6566
handle: Arc<StatementHandle>,
67+
worker: &'a mut StatementWorker,
6668
}
6769

68-
impl StatementResetter {
69-
fn new(handle: &Arc<StatementHandle>) -> Self {
70+
impl<'a> StatementResetter<'a> {
71+
fn new(worker: &'a mut StatementWorker, handle: &Arc<StatementHandle>) -> Self {
7072
Self {
73+
worker,
7174
handle: Arc::clone(handle),
7275
}
7376
}
7477
}
7578

76-
impl Drop for StatementResetter {
79+
impl Drop for StatementResetter<'_> {
7780
fn drop(&mut self) {
78-
self.handle.reset();
81+
// this method is designed to eagerly send the reset command
82+
// so we don't need to await or spawn it
83+
let _ = self.worker.reset(&self.handle);
7984
}
8085
}
8186

@@ -105,7 +110,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
105110
} = self;
106111

107112
// prepare statement object (or checkout from cache)
108-
let stmt = prepare(statements, statement, sql, persistent)?;
113+
let stmt = prepare(worker, statements, statement, sql, persistent).await?;
109114

110115
// keep track of how many arguments we have bound
111116
let mut num_arguments = 0;
@@ -115,7 +120,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
115120
// is dropped. `StatementResetter` will reliably reset the
116121
// statement even if the stream returned from `fetch_many`
117122
// is dropped early.
118-
let _resetter = StatementResetter::new(stmt);
123+
let resetter = StatementResetter::new(worker, stmt);
119124

120125
// bind values to the statement
121126
num_arguments += bind(stmt, &arguments, num_arguments)?;
@@ -127,7 +132,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
127132

128133
// invoke [sqlite3_step] on the dedicated worker thread
129134
// this will move us forward one row or finish the statement
130-
let s = worker.step(stmt).await?;
135+
let s = resetter.worker.step(stmt).await?;
131136

132137
match s {
133138
Either::Left(changes) => {
@@ -190,7 +195,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
190195
} = self;
191196

192197
// prepare statement object (or checkout from cache)
193-
let virtual_stmt = prepare(statements, statement, sql, persistent)?;
198+
let virtual_stmt = prepare(worker, statements, statement, sql, persistent).await?;
194199

195200
// keep track of how many arguments we have bound
196201
let mut num_arguments = 0;
@@ -218,7 +223,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
218223

219224
logger.increment_rows();
220225

221-
virtual_stmt.reset();
226+
virtual_stmt.reset(worker).await?;
222227
return Ok(Some(row));
223228
}
224229
}
@@ -240,11 +245,12 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
240245
handle: ref mut conn,
241246
ref mut statements,
242247
ref mut statement,
248+
ref mut worker,
243249
..
244250
} = self;
245251

246252
// prepare statement object (or checkout from cache)
247-
let statement = prepare(statements, statement, sql, true)?;
253+
let statement = prepare(worker, statements, statement, sql, true).await?;
248254

249255
let mut parameters = 0;
250256
let mut columns = None;

sqlx-core/src/sqlite/statement/handle.rs

Lines changed: 8 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
use either::Either;
21
use std::ffi::c_void;
32
use std::ffi::CStr;
4-
use std::hint;
3+
54
use std::os::raw::{c_char, c_int};
65
use std::ptr;
76
use std::ptr::NonNull;
87
use std::slice::from_raw_parts;
98
use std::str::{from_utf8, from_utf8_unchecked};
10-
use std::sync::atomic::{AtomicU8, Ordering};
119

1210
use libsqlite3_sys::{
1311
sqlite3, sqlite3_bind_blob64, sqlite3_bind_double, sqlite3_bind_int, sqlite3_bind_int64,
@@ -27,7 +25,7 @@ use crate::sqlite::type_info::DataType;
2725
use crate::sqlite::{SqliteError, SqliteTypeInfo};
2826

2927
#[derive(Debug)]
30-
pub(crate) struct StatementHandle(NonNull<sqlite3_stmt>, Lock);
28+
pub(crate) struct StatementHandle(NonNull<sqlite3_stmt>);
3129

3230
// access to SQLite3 statement handles are safe to send and share between threads
3331
// as long as the `sqlite3_step` call is serialized.
@@ -37,7 +35,11 @@ unsafe impl Sync for StatementHandle {}
3735

3836
impl StatementHandle {
3937
pub(super) fn new(ptr: NonNull<sqlite3_stmt>) -> Self {
40-
Self(ptr, Lock::new())
38+
Self(ptr)
39+
}
40+
41+
pub(crate) fn as_ptr(&self) -> *mut sqlite3_stmt {
42+
self.0.as_ptr()
4143
}
4244

4345
#[inline]
@@ -288,41 +290,13 @@ impl StatementHandle {
288290
Ok(from_utf8(self.column_blob(index))?)
289291
}
290292

291-
pub(crate) fn step(&self) -> Result<Either<u64, ()>, Error> {
292-
self.1.enter_step();
293-
294-
let status = unsafe { sqlite3_step(self.0.as_ptr()) };
295-
let result = match status {
296-
SQLITE_ROW => Ok(Either::Right(())),
297-
SQLITE_DONE => Ok(Either::Left(self.changes())),
298-
_ => Err(self.last_error().into()),
299-
};
300-
301-
if self.1.exit_step() {
302-
unsafe { sqlite3_reset(self.0.as_ptr()) };
303-
self.1.exit_reset();
304-
}
305-
306-
result
307-
}
308-
309-
pub(crate) fn reset(&self) {
310-
if !self.1.enter_reset() {
311-
// reset or step already in progress
312-
return;
313-
}
314-
315-
unsafe { sqlite3_reset(self.0.as_ptr()) };
316-
317-
self.1.exit_reset();
318-
}
319-
320293
pub(crate) fn clear_bindings(&self) {
321294
unsafe { sqlite3_clear_bindings(self.0.as_ptr()) };
322295
}
323296
}
324297
impl Drop for StatementHandle {
325298
fn drop(&mut self) {
299+
// SAFETY: we have exclusive access to the `StatementHandle` here
326300
unsafe {
327301
// https://sqlite.org/c3ref/finalize.html
328302
let status = sqlite3_finalize(self.0.as_ptr());
@@ -338,44 +312,3 @@ impl Drop for StatementHandle {
338312
}
339313
}
340314
}
341-
342-
const RESET: u8 = 0b0000_0001;
343-
const STEP: u8 = 0b0000_0010;
344-
345-
// Lock to synchronize calls to `step` and `reset`.
346-
#[derive(Debug)]
347-
struct Lock(AtomicU8);
348-
349-
impl Lock {
350-
fn new() -> Self {
351-
Self(AtomicU8::new(0))
352-
}
353-
354-
// If this returns `true` reset can be performed, otherwise reset must be delayed until the
355-
// current step finishes and `exit_step` is called.
356-
fn enter_reset(&self) -> bool {
357-
self.0.fetch_or(RESET, Ordering::Acquire) == 0
358-
}
359-
360-
fn exit_reset(&self) {
361-
self.0.fetch_and(!RESET, Ordering::Release);
362-
}
363-
364-
fn enter_step(&self) {
365-
// NOTE: spin loop should be fine here as we are only waiting for a `reset` to finish which
366-
// should be quick.
367-
while self
368-
.0
369-
.compare_exchange(0, STEP, Ordering::Acquire, Ordering::Relaxed)
370-
.is_err()
371-
{
372-
hint::spin_loop();
373-
}
374-
}
375-
376-
// If this returns `true` it means a previous attempt to reset was delayed and must be
377-
// performed now (followed by `exit_reset`).
378-
fn exit_step(&self) -> bool {
379-
self.0.fetch_and(!STEP, Ordering::Release) & RESET != 0
380-
}
381-
}

sqlx-core/src/sqlite/statement/virtual.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
use crate::error::Error;
44
use crate::ext::ustr::UStr;
55
use crate::sqlite::connection::ConnectionHandle;
6-
use crate::sqlite::statement::StatementHandle;
6+
use crate::sqlite::statement::{StatementHandle, StatementWorker};
77
use crate::sqlite::{SqliteColumn, SqliteError, SqliteRow, SqliteValue};
88
use crate::HashMap;
99
use bytes::{Buf, Bytes};
@@ -176,7 +176,7 @@ impl VirtualStatement {
176176
)))
177177
}
178178

179-
pub(crate) fn reset(&mut self) {
179+
pub(crate) async fn reset(&mut self, worker: &mut StatementWorker) -> Result<(), Error> {
180180
self.index = 0;
181181

182182
for (i, handle) in self.handles.iter().enumerate() {
@@ -185,9 +185,11 @@ impl VirtualStatement {
185185
// Reset A Prepared Statement Object
186186
// https://www.sqlite.org/c3ref/reset.html
187187
// https://www.sqlite.org/c3ref/clear_bindings.html
188-
handle.reset();
188+
worker.reset(handle).await?;
189189
handle.clear_bindings();
190190
}
191+
192+
Ok(())
191193
}
192194
}
193195

sqlx-core/src/sqlite/statement/worker.rs

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ use futures_channel::oneshot;
66
use std::sync::{Arc, Weak};
77
use std::thread;
88

9+
use libsqlite3_sys::{sqlite3_reset, sqlite3_step, SQLITE_DONE, SQLITE_ROW};
10+
use std::future::Future;
11+
912
// Each SQLite connection has a dedicated thread.
1013

1114
// TODO: Tweak this so that we can use a thread pool per pool of SQLite3 connections to reduce
@@ -21,6 +24,10 @@ enum StatementWorkerCommand {
2124
statement: Weak<StatementHandle>,
2225
tx: oneshot::Sender<Result<Either<u64, ()>, Error>>,
2326
},
27+
Reset {
28+
statement: Weak<StatementHandle>,
29+
tx: oneshot::Sender<()>,
30+
},
2431
}
2532

2633
impl StatementWorker {
@@ -31,13 +38,37 @@ impl StatementWorker {
3138
for cmd in rx {
3239
match cmd {
3340
StatementWorkerCommand::Step { statement, tx } => {
34-
let resp = if let Some(statement) = statement.upgrade() {
35-
statement.step()
41+
let statement = if let Some(statement) = statement.upgrade() {
42+
statement
3643
} else {
37-
// Statement is already finalized.
38-
Err(Error::WorkerCrashed)
44+
// statement is already finalized, the sender shouldn't be expecting a response
45+
continue;
46+
};
47+
48+
// SAFETY: only the `StatementWorker` calls this function
49+
let status = unsafe { sqlite3_step(statement.as_ptr()) };
50+
let result = match status {
51+
SQLITE_ROW => Ok(Either::Right(())),
52+
SQLITE_DONE => Ok(Either::Left(statement.changes())),
53+
_ => Err(statement.last_error().into()),
3954
};
40-
let _ = tx.send(resp);
55+
56+
let _ = tx.send(result);
57+
}
58+
StatementWorkerCommand::Reset { statement, tx } => {
59+
if let Some(statement) = statement.upgrade() {
60+
// SAFETY: this must be the only place we call `sqlite3_reset`
61+
unsafe { sqlite3_reset(statement.as_ptr()) };
62+
63+
// `sqlite3_reset()` always returns either `SQLITE_OK`
64+
// or the last error code for the statement,
65+
// which should have already been handled;
66+
// so it's assumed the return value is safe to ignore.
67+
//
68+
// https://www.sqlite.org/c3ref/reset.html
69+
70+
let _ = tx.send(());
71+
}
4172
}
4273
}
4374
}
@@ -61,4 +92,34 @@ impl StatementWorker {
6192

6293
rx.await.map_err(|_| Error::WorkerCrashed)?
6394
}
95+
96+
/// Send a command to the worker to execute `sqlite3_reset()` next.
97+
///
98+
/// This method is written to execute the sending of the command eagerly so
99+
/// you do not need to await the returned future unless you want to.
100+
///
101+
/// The only error is `WorkerCrashed` as `sqlite3_reset()` returns the last error
102+
/// in the statement execution which should have already been handled from `step()`.
103+
pub(crate) fn reset(
104+
&mut self,
105+
statement: &Arc<StatementHandle>,
106+
) -> impl Future<Output = Result<(), Error>> {
107+
// execute the sending eagerly so we don't need to spawn the future
108+
let (tx, rx) = oneshot::channel();
109+
110+
let send_res = self
111+
.tx
112+
.send(StatementWorkerCommand::Reset {
113+
statement: Arc::downgrade(statement),
114+
tx,
115+
})
116+
.map_err(|_| Error::WorkerCrashed);
117+
118+
async move {
119+
send_res?;
120+
121+
// wait for the response
122+
rx.await.map_err(|_| Error::WorkerCrashed)
123+
}
124+
}
64125
}

0 commit comments

Comments
 (0)