diff --git a/blob.go b/blob.go index bf3a275a..ea7caf9d 100644 --- a/blob.go +++ b/blob.go @@ -31,6 +31,10 @@ var _ io.ReadWriteSeeker = &Blob{} // // https://sqlite.org/c3ref/blob_open.html func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) { + if c.interrupt.Err() != nil { + return nil, INTERRUPT + } + defer c.arena.mark()() blobPtr := c.arena.new(ptrlen) dbPtr := c.arena.string(db) @@ -42,7 +46,6 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, flags = 1 } - c.checkInterrupt() rc := res_t(c.call("sqlite3_blob_open", stk_t(c.handle), stk_t(dbPtr), stk_t(tablePtr), stk_t(columnPtr), stk_t(row), stk_t(flags), stk_t(blobPtr))) @@ -253,7 +256,9 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) { // // https://sqlite.org/c3ref/blob_reopen.html func (b *Blob) Reopen(row int64) error { - b.c.checkInterrupt() + if b.c.interrupt.Err() != nil { + return INTERRUPT + } err := b.c.error(res_t(b.c.call("sqlite3_blob_reopen", stk_t(b.handle), stk_t(row)))) b.bytes = int64(int32(b.c.call("sqlite3_blob_bytes", stk_t(b.handle)))) b.offset = 0 diff --git a/config.go b/config.go index decea181..3921fe98 100644 --- a/config.go +++ b/config.go @@ -275,12 +275,14 @@ func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pAr // // https://sqlite.org/c3ref/wal_checkpoint_v2.html func (c *Conn) WALCheckpoint(schema string, mode CheckpointMode) (nLog, nCkpt int, err error) { + if c.interrupt.Err() != nil { + return 0, 0, INTERRUPT + } + defer c.arena.mark()() nLogPtr := c.arena.new(ptrlen) nCkptPtr := c.arena.new(ptrlen) schemaPtr := c.arena.string(schema) - - c.checkInterrupt() rc := res_t(c.call("sqlite3_wal_checkpoint_v2", stk_t(c.handle), stk_t(schemaPtr), stk_t(mode), stk_t(nLogPtr), stk_t(nCkptPtr))) diff --git a/conn.go b/conn.go index 06b6637e..10d3cd27 100644 --- a/conn.go +++ b/conn.go @@ -40,8 +40,6 @@ type Conn struct { busylst time.Time arena arena handle ptr_t - pending ptr_t - stepped bool gosched uint8 } @@ -167,9 +165,6 @@ func (c *Conn) Close() error { return nil } - c.call("sqlite3_finalize", stk_t(c.pending)) - c.pending = 0 - rc := res_t(c.call("sqlite3_close", stk_t(c.handle))) if err := c.error(rc); err != nil { return err @@ -184,10 +179,15 @@ func (c *Conn) Close() error { // // https://sqlite.org/c3ref/exec.html func (c *Conn) Exec(sql string) error { + if c.interrupt.Err() != nil { + return INTERRUPT + } + return c.exec(sql) +} + +func (c *Conn) exec(sql string) error { defer c.arena.mark()() textPtr := c.arena.string(sql) - - c.checkInterrupt() rc := res_t(c.call("sqlite3_exec", stk_t(c.handle), stk_t(textPtr), 0, 0, 0)) return c.error(rc, sql) } @@ -207,13 +207,15 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str if len(sql) > _MAX_SQL_LENGTH { return nil, "", TOOBIG } + if c.interrupt.Err() != nil { + return nil, "", INTERRUPT + } defer c.arena.mark()() stmtPtr := c.arena.new(ptrlen) tailPtr := c.arena.new(ptrlen) textPtr := c.arena.string(sql) - c.checkInterrupt() rc := res_t(c.call("sqlite3_prepare_v3", stk_t(c.handle), stk_t(textPtr), stk_t(len(sql)+1), stk_t(flags), stk_t(stmtPtr), stk_t(tailPtr))) @@ -343,42 +345,9 @@ func (c *Conn) GetInterrupt() context.Context { func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { old = c.interrupt c.interrupt = ctx - - if ctx == old { - return old - } - - // An active SQL statement prevents SQLite from ignoring an interrupt - // that comes before any other statements are started. - if c.pending == 0 { - defer c.arena.mark()() - stmtPtr := c.arena.new(ptrlen) - textPtr := c.arena.string(`SELECT 0 UNION ALL SELECT 0`) - c.call("sqlite3_prepare_v3", stk_t(c.handle), stk_t(textPtr), math.MaxUint64, - stk_t(PREPARE_PERSISTENT), stk_t(stmtPtr), 0) - c.pending = util.Read32[ptr_t](c.mod, stmtPtr) - } - - if c.stepped && ctx.Err() == nil { - c.call("sqlite3_reset", stk_t(c.pending)) - c.stepped = false - } else { - c.checkInterrupt() - } return old } -func (c *Conn) checkInterrupt() { - if c.interrupt.Err() == nil { - return - } - if !c.stepped { - c.call("sqlite3_step", stk_t(c.pending)) - c.stepped = true - } - c.call("sqlite3_interrupt", stk_t(c.handle)) -} - func progressCallback(ctx context.Context, mod api.Module, _ ptr_t) (interrupt int32) { if c, ok := ctx.Value(connKey{}).(*Conn); ok { if c.gosched++; c.gosched%16 == 0 { diff --git a/driver/driver_test.go b/driver/driver_test.go index 4ed030f4..fb839b87 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -199,6 +199,62 @@ func Test_BeginTx(t *testing.T) { } } +func Test_nested_context(t *testing.T) { + t.Parallel() + tmp := memdb.TestDB(t) + + db, err := sql.Open("sqlite3", tmp) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + outer, err := tx.Query(`SELECT value FROM generate_series(0)`) + if err != nil { + t.Fatal(err) + } + defer outer.Close() + + want := func(rows *sql.Rows, want int) { + t.Helper() + + var got int + rows.Next() + if err := rows.Scan(&got); err != nil { + t.Fatal(err) + } + if got != want { + t.Errorf("got %d, want %d", got, want) + } + } + + want(outer, 0) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + inner, err := tx.QueryContext(ctx, `SELECT value FROM generate_series(0)`) + if err != nil { + t.Fatal(err) + } + defer inner.Close() + + want(inner, 0) + cancel() + + if inner.Next() || !errors.Is(inner.Err(), sqlite3.INTERRUPT) { + t.Fatal(inner.Err()) + } + + want(outer, 1) +} + func Test_Prepare(t *testing.T) { t.Parallel() tmp := memdb.TestDB(t) diff --git a/stmt.go b/stmt.go index 5314595a..50dd5e32 100644 --- a/stmt.go +++ b/stmt.go @@ -106,7 +106,11 @@ func (s *Stmt) Busy() bool { // // https://sqlite.org/c3ref/step.html func (s *Stmt) Step() bool { - s.c.checkInterrupt() + if s.c.interrupt.Err() != nil { + s.err = INTERRUPT + return false + } + rc := res_t(s.c.call("sqlite3_step", stk_t(s.handle))) switch rc { case _ROW: diff --git a/txn.go b/txn.go index b24789f8..a21b99ac 100644 --- a/txn.go +++ b/txn.go @@ -2,7 +2,6 @@ package sqlite3 import ( "context" - "errors" "math/rand" "runtime" "strconv" @@ -25,7 +24,7 @@ type Txn struct { // https://sqlite.org/lang_transaction.html func (c *Conn) Begin() Txn { // BEGIN even if interrupted. - err := c.txnExecInterrupted(`BEGIN DEFERRED`) + err := c.exec(`BEGIN DEFERRED`) if err != nil { panic(err) } @@ -120,7 +119,7 @@ func (tx Txn) Commit() error { // // https://sqlite.org/lang_transaction.html func (tx Txn) Rollback() error { - return tx.c.txnExecInterrupted(`ROLLBACK`) + return tx.c.exec(`ROLLBACK`) } // Savepoint is a marker within a transaction @@ -143,7 +142,7 @@ func (c *Conn) Savepoint() Savepoint { // Names can be reused, but this makes catching bugs more likely. name = QuoteIdentifier(name + "_" + strconv.Itoa(int(rand.Int31()))) - err := c.txnExecInterrupted(`SAVEPOINT ` + name) + err := c.exec(`SAVEPOINT ` + name) if err != nil { panic(err) } @@ -199,7 +198,7 @@ func (s Savepoint) Release(errp *error) { return } // ROLLBACK and RELEASE even if interrupted. - err := s.c.txnExecInterrupted(`ROLLBACK TO ` + s.name + `; RELEASE ` + s.name) + err := s.c.exec(`ROLLBACK TO ` + s.name + `; RELEASE ` + s.name) if err != nil { panic(err) } @@ -212,17 +211,7 @@ func (s Savepoint) Release(errp *error) { // https://sqlite.org/lang_transaction.html func (s Savepoint) Rollback() error { // ROLLBACK even if interrupted. - return s.c.txnExecInterrupted(`ROLLBACK TO ` + s.name) -} - -func (c *Conn) txnExecInterrupted(sql string) error { - err := c.Exec(sql) - if errors.Is(err, INTERRUPT) { - old := c.SetInterrupt(context.Background()) - defer c.SetInterrupt(old) - err = c.Exec(sql) - } - return err + return s.c.exec(`ROLLBACK TO ` + s.name) } // TxnState determines the transaction state of a database. diff --git a/vtab.go b/vtab.go index 1db142c4..16ff2806 100644 --- a/vtab.go +++ b/vtab.go @@ -79,10 +79,11 @@ func implements[T any](typ reflect.Type) bool { // // https://sqlite.org/c3ref/declare_vtab.html func (c *Conn) DeclareVTab(sql string) error { + if c.interrupt.Err() != nil { + return INTERRUPT + } defer c.arena.mark()() textPtr := c.arena.string(sql) - - c.checkInterrupt() rc := res_t(c.call("sqlite3_declare_vtab", stk_t(c.handle), stk_t(textPtr))) return c.error(rc) }