Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions blob.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ var _ io.ReadWriteSeeker = &Blob{}
//
// https://sqlite.org/c3ref/blob_open.html
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
if c.interrupt.Err() != nil {
return nil, INTERRUPT
}

defer c.arena.mark()()
blobPtr := c.arena.new(ptrlen)
dbPtr := c.arena.string(db)
Expand All @@ -42,7 +46,6 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob,
flags = 1
}

c.checkInterrupt()
rc := res_t(c.call("sqlite3_blob_open", stk_t(c.handle),
stk_t(dbPtr), stk_t(tablePtr), stk_t(columnPtr),
stk_t(row), stk_t(flags), stk_t(blobPtr)))
Expand Down Expand Up @@ -253,7 +256,9 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) {
//
// https://sqlite.org/c3ref/blob_reopen.html
func (b *Blob) Reopen(row int64) error {
b.c.checkInterrupt()
if b.c.interrupt.Err() != nil {
return INTERRUPT
}
err := b.c.error(res_t(b.c.call("sqlite3_blob_reopen", stk_t(b.handle), stk_t(row))))
b.bytes = int64(int32(b.c.call("sqlite3_blob_bytes", stk_t(b.handle))))
b.offset = 0
Expand Down
6 changes: 4 additions & 2 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,14 @@ func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pAr
//
// https://sqlite.org/c3ref/wal_checkpoint_v2.html
func (c *Conn) WALCheckpoint(schema string, mode CheckpointMode) (nLog, nCkpt int, err error) {
if c.interrupt.Err() != nil {
return 0, 0, INTERRUPT
}

defer c.arena.mark()()
nLogPtr := c.arena.new(ptrlen)
nCkptPtr := c.arena.new(ptrlen)
schemaPtr := c.arena.string(schema)

c.checkInterrupt()
rc := res_t(c.call("sqlite3_wal_checkpoint_v2",
stk_t(c.handle), stk_t(schemaPtr), stk_t(mode),
stk_t(nLogPtr), stk_t(nCkptPtr)))
Expand Down
51 changes: 10 additions & 41 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ type Conn struct {
busylst time.Time
arena arena
handle ptr_t
pending ptr_t
stepped bool
gosched uint8
}

Expand Down Expand Up @@ -167,9 +165,6 @@ func (c *Conn) Close() error {
return nil
}

c.call("sqlite3_finalize", stk_t(c.pending))
c.pending = 0

rc := res_t(c.call("sqlite3_close", stk_t(c.handle)))
if err := c.error(rc); err != nil {
return err
Expand All @@ -184,10 +179,15 @@ func (c *Conn) Close() error {
//
// https://sqlite.org/c3ref/exec.html
func (c *Conn) Exec(sql string) error {
if c.interrupt.Err() != nil {
return INTERRUPT
}
return c.exec(sql)
}

func (c *Conn) exec(sql string) error {
defer c.arena.mark()()
textPtr := c.arena.string(sql)

c.checkInterrupt()
rc := res_t(c.call("sqlite3_exec", stk_t(c.handle), stk_t(textPtr), 0, 0, 0))
return c.error(rc, sql)
}
Expand All @@ -207,13 +207,15 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
if len(sql) > _MAX_SQL_LENGTH {
return nil, "", TOOBIG
}
if c.interrupt.Err() != nil {
return nil, "", INTERRUPT
}

defer c.arena.mark()()
stmtPtr := c.arena.new(ptrlen)
tailPtr := c.arena.new(ptrlen)
textPtr := c.arena.string(sql)

c.checkInterrupt()
rc := res_t(c.call("sqlite3_prepare_v3", stk_t(c.handle),
stk_t(textPtr), stk_t(len(sql)+1), stk_t(flags),
stk_t(stmtPtr), stk_t(tailPtr)))
Expand Down Expand Up @@ -343,42 +345,9 @@ func (c *Conn) GetInterrupt() context.Context {
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
old = c.interrupt
c.interrupt = ctx

if ctx == old {
return old
}

// An active SQL statement prevents SQLite from ignoring an interrupt
// that comes before any other statements are started.
if c.pending == 0 {
defer c.arena.mark()()
stmtPtr := c.arena.new(ptrlen)
textPtr := c.arena.string(`SELECT 0 UNION ALL SELECT 0`)
c.call("sqlite3_prepare_v3", stk_t(c.handle), stk_t(textPtr), math.MaxUint64,
stk_t(PREPARE_PERSISTENT), stk_t(stmtPtr), 0)
c.pending = util.Read32[ptr_t](c.mod, stmtPtr)
}

if c.stepped && ctx.Err() == nil {
c.call("sqlite3_reset", stk_t(c.pending))
c.stepped = false
} else {
c.checkInterrupt()
}
return old
}

func (c *Conn) checkInterrupt() {
if c.interrupt.Err() == nil {
return
}
if !c.stepped {
c.call("sqlite3_step", stk_t(c.pending))
c.stepped = true
}
c.call("sqlite3_interrupt", stk_t(c.handle))
}

func progressCallback(ctx context.Context, mod api.Module, _ ptr_t) (interrupt int32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok {
if c.gosched++; c.gosched%16 == 0 {
Expand Down
56 changes: 56 additions & 0 deletions driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,62 @@ func Test_BeginTx(t *testing.T) {
}
}

func Test_nested_context(t *testing.T) {
t.Parallel()
tmp := memdb.TestDB(t)

db, err := sql.Open("sqlite3", tmp)
if err != nil {
t.Fatal(err)
}
defer db.Close()

tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Rollback()

outer, err := tx.Query(`SELECT value FROM generate_series(0)`)
if err != nil {
t.Fatal(err)
}
defer outer.Close()

want := func(rows *sql.Rows, want int) {
t.Helper()

var got int
rows.Next()
if err := rows.Scan(&got); err != nil {
t.Fatal(err)
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}

want(outer, 0)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

inner, err := tx.QueryContext(ctx, `SELECT value FROM generate_series(0)`)
if err != nil {
t.Fatal(err)
}
defer inner.Close()

want(inner, 0)
cancel()

if inner.Next() || !errors.Is(inner.Err(), sqlite3.INTERRUPT) {
t.Fatal(inner.Err())
}

want(outer, 1)
}

func Test_Prepare(t *testing.T) {
t.Parallel()
tmp := memdb.TestDB(t)
Expand Down
6 changes: 5 additions & 1 deletion stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ func (s *Stmt) Busy() bool {
//
// https://sqlite.org/c3ref/step.html
func (s *Stmt) Step() bool {
s.c.checkInterrupt()
if s.c.interrupt.Err() != nil {
s.err = INTERRUPT
return false
}

rc := res_t(s.c.call("sqlite3_step", stk_t(s.handle)))
switch rc {
case _ROW:
Expand Down
21 changes: 5 additions & 16 deletions txn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package sqlite3

import (
"context"
"errors"
"math/rand"
"runtime"
"strconv"
Expand All @@ -25,7 +24,7 @@ type Txn struct {
// https://sqlite.org/lang_transaction.html
func (c *Conn) Begin() Txn {
// BEGIN even if interrupted.
err := c.txnExecInterrupted(`BEGIN DEFERRED`)
err := c.exec(`BEGIN DEFERRED`)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -120,7 +119,7 @@ func (tx Txn) Commit() error {
//
// https://sqlite.org/lang_transaction.html
func (tx Txn) Rollback() error {
return tx.c.txnExecInterrupted(`ROLLBACK`)
return tx.c.exec(`ROLLBACK`)
}

// Savepoint is a marker within a transaction
Expand All @@ -143,7 +142,7 @@ func (c *Conn) Savepoint() Savepoint {
// Names can be reused, but this makes catching bugs more likely.
name = QuoteIdentifier(name + "_" + strconv.Itoa(int(rand.Int31())))

err := c.txnExecInterrupted(`SAVEPOINT ` + name)
err := c.exec(`SAVEPOINT ` + name)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -199,7 +198,7 @@ func (s Savepoint) Release(errp *error) {
return
}
// ROLLBACK and RELEASE even if interrupted.
err := s.c.txnExecInterrupted(`ROLLBACK TO ` + s.name + `; RELEASE ` + s.name)
err := s.c.exec(`ROLLBACK TO ` + s.name + `; RELEASE ` + s.name)
if err != nil {
panic(err)
}
Expand All @@ -212,17 +211,7 @@ func (s Savepoint) Release(errp *error) {
// https://sqlite.org/lang_transaction.html
func (s Savepoint) Rollback() error {
// ROLLBACK even if interrupted.
return s.c.txnExecInterrupted(`ROLLBACK TO ` + s.name)
}

func (c *Conn) txnExecInterrupted(sql string) error {
err := c.Exec(sql)
if errors.Is(err, INTERRUPT) {
old := c.SetInterrupt(context.Background())
defer c.SetInterrupt(old)
err = c.Exec(sql)
}
return err
return s.c.exec(`ROLLBACK TO ` + s.name)
}

// TxnState determines the transaction state of a database.
Expand Down
5 changes: 3 additions & 2 deletions vtab.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,11 @@ func implements[T any](typ reflect.Type) bool {
//
// https://sqlite.org/c3ref/declare_vtab.html
func (c *Conn) DeclareVTab(sql string) error {
if c.interrupt.Err() != nil {
return INTERRUPT
}
defer c.arena.mark()()
textPtr := c.arena.string(sql)

c.checkInterrupt()
rc := res_t(c.call("sqlite3_declare_vtab", stk_t(c.handle), stk_t(textPtr)))
return c.error(rc)
}
Expand Down