From d63d3652636fdbd1154726178ec98f927382fcc6 Mon Sep 17 00:00:00 2001 From: Camden Cheek Date: Mon, 13 Jun 2022 16:18:30 -0600 Subject: [PATCH 1/7] make concurrent transaction usage loud --- internal/database/basestore/handle.go | 103 +++++++++++++++++++--- internal/database/basestore/store_test.go | 72 ++++++++++++++- 2 files changed, 159 insertions(+), 16 deletions(-) diff --git a/internal/database/basestore/handle.go b/internal/database/basestore/handle.go index bcc7623c9779..31a05a42a805 100644 --- a/internal/database/basestore/handle.go +++ b/internal/database/basestore/handle.go @@ -5,8 +5,10 @@ import ( "database/sql" "fmt" "strings" + "sync" "github.com/google/uuid" + "github.com/sourcegraph/log" "github.com/sourcegraph/sourcegraph/internal/database/dbutil" "github.com/sourcegraph/sourcegraph/lib/errors" @@ -48,7 +50,7 @@ func NewHandleWithDB(db *sql.DB, txOptions sql.TxOptions) TransactableHandle { // NewHandleWithTx returns a new transactable database handle using the given transaction. func NewHandleWithTx(tx *sql.Tx, txOptions sql.TxOptions) TransactableHandle { - return &txHandle{Tx: tx, txOptions: txOptions} + return &txHandle{lockingTx: &lockingTx{tx: tx}, txOptions: txOptions} } type dbHandle struct { @@ -65,7 +67,7 @@ func (h *dbHandle) Transact(ctx context.Context) (TransactableHandle, error) { if err != nil { return nil, err } - return &txHandle{Tx: tx, txOptions: h.txOptions}, nil + return &txHandle{lockingTx: &lockingTx{tx: tx}, txOptions: h.txOptions}, nil } func (h *dbHandle) Done(err error) error { @@ -73,7 +75,7 @@ func (h *dbHandle) Done(err error) error { } type txHandle struct { - *sql.Tx + *lockingTx txOptions sql.TxOptions } @@ -82,23 +84,23 @@ func (h *txHandle) InTransaction() bool { } func (h *txHandle) Transact(ctx context.Context) (TransactableHandle, error) { - savepointID, err := newTxSavepoint(ctx, h.Tx) + savepointID, err := newTxSavepoint(ctx, h.lockingTx) if err != nil { return nil, err } - return &savepointHandle{Tx: h.Tx, savepointID: savepointID}, nil + return &savepointHandle{lockingTx: h.lockingTx, savepointID: savepointID}, nil } func (h *txHandle) Done(err error) error { if err == nil { - return h.Tx.Commit() + return h.Commit() } - return errors.Append(err, h.Tx.Rollback()) + return errors.Append(err, h.Rollback()) } type savepointHandle struct { - *sql.Tx + *lockingTx savepointID string } @@ -107,21 +109,21 @@ func (h *savepointHandle) InTransaction() bool { } func (h *savepointHandle) Transact(ctx context.Context) (TransactableHandle, error) { - savepointID, err := newTxSavepoint(ctx, h.Tx) + savepointID, err := newTxSavepoint(ctx, h.lockingTx) if err != nil { return nil, err } - return &savepointHandle{Tx: h.Tx, savepointID: savepointID}, nil + return &savepointHandle{lockingTx: h.lockingTx, savepointID: savepointID}, nil } func (h *savepointHandle) Done(err error) error { if err == nil { - _, execErr := h.Tx.Exec(fmt.Sprintf(commitSavepointQuery, h.savepointID)) + _, execErr := h.ExecContext(context.Background(), fmt.Sprintf(commitSavepointQuery, h.savepointID)) return execErr } - _, execErr := h.Tx.Exec(fmt.Sprintf(rollbackSavepointQuery, h.savepointID)) + _, execErr := h.ExecContext(context.Background(), fmt.Sprintf(rollbackSavepointQuery, h.savepointID)) return errors.Append(err, execErr) } @@ -131,7 +133,7 @@ const ( rollbackSavepointQuery = "ROLLBACK TO %s" ) -func newTxSavepoint(ctx context.Context, tx *sql.Tx) (string, error) { +func newTxSavepoint(ctx context.Context, tx *lockingTx) (string, error) { savepointID, err := makeSavepointID() if err != nil { return "", err @@ -153,3 +155,78 @@ func makeSavepointID() (string, error) { return fmt.Sprintf("sp_%s", strings.ReplaceAll(id.String(), "-", "_")), nil } + +var ErrConcurrentTransactions = errors.New("transaction used concurrently") + +// lockingTx wraps a *sql.Tx with a mutex, and reports when a caller tries +// to use the transaction concurrently. Since using a transaction concurrently +// is unsafe, we want to catch these issues. Currently, lockingTx will just +// log an error and serialize accesses to the wrapped *sql.Tx, but in the future +// concurrent calls may be upgraded to an error. +type lockingTx struct { + tx *sql.Tx + mu sync.Mutex +} + +func (t *lockingTx) lock() error { + if !t.mu.TryLock() { + err := errors.WithStack(ErrConcurrentTransactions) + log.Scoped("internal", "database").Error("transaction used concurrently", log.Error(err)) + return err + } + return nil +} + +func (t *lockingTx) unlock() { + t.mu.Unlock() +} + +func (t *lockingTx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + if err := t.lock(); err != nil { + return nil, err + } + defer t.unlock() + + return t.tx.ExecContext(ctx, query, args...) +} + +func (t *lockingTx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + if err := t.lock(); err != nil { + return nil, err + } + defer t.unlock() + + return t.tx.QueryContext(ctx, query, args...) +} + +func (t *lockingTx) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { + if err := t.lock(); err != nil { + // XXX(camdencheek): There is no way to construct a row that has an + // error, so in order for this to return an error, we'll need to have + // it return an interface. Instead, we attempt to acquire the lock + // to make this at least be safer if we can't report the error other + // than through logs. + t.mu.Lock() + } + defer t.unlock() + + return t.tx.QueryRowContext(ctx, query, args...) +} + +func (t *lockingTx) Commit() error { + if err := t.lock(); err != nil { + return err + } + defer t.unlock() + + return t.tx.Commit() +} + +func (t *lockingTx) Rollback() error { + if err := t.lock(); err != nil { + return err + } + defer t.unlock() + + return t.tx.Rollback() +} diff --git a/internal/database/basestore/store_test.go b/internal/database/basestore/store_test.go index 719ec3941e94..e6635c22a6dd 100644 --- a/internal/database/basestore/store_test.go +++ b/internal/database/basestore/store_test.go @@ -8,6 +8,8 @@ import ( "github.com/google/go-cmp/cmp" "github.com/keegancsmith/sqlf" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" "github.com/sourcegraph/sourcegraph/internal/database/dbtest" "github.com/sourcegraph/sourcegraph/internal/database/dbutil" @@ -15,7 +17,7 @@ import ( ) func TestTransaction(t *testing.T) { - db := dbtest.NewDB(t) + db := dbtest.NewRawDB(t) setupStoreTest(t, db) store := testStore(db) @@ -61,8 +63,72 @@ func TestTransaction(t *testing.T) { assertCounts(t, db, map[int]int{1: 42, 3: 44}) } +func TestConcurrentTransactions(t *testing.T) { + db := dbtest.NewRawDB(t) + setupStoreTest(t, db) + store := testStore(db) + ctx := context.Background() + + t.Run("creating transactions concurrently does not fail", func(t *testing.T) { + var g errgroup.Group + for i := 0; i < 100; i++ { + i := i + g.Go(func() (err error) { + tx, err := store.Transact(ctx) + if err != nil { + return err + } + defer func() { err = tx.Done(err) }() + + return tx.Exec(ctx, sqlf.Sprintf(`INSERT INTO store_counts_test VALUES (%s, 42)`, i)) + }) + } + require.NoError(t, g.Wait()) + }) + + t.Run("creating concurrent savepoints on a single transaction fails", func(t *testing.T) { + tx, err := store.Transact(ctx) + if err != nil { + t.Fatal(err) + } + + var g errgroup.Group + for i := 0; i < 100; i++ { + i := i + g.Go(func() (err error) { + txNested, err := tx.Transact(ctx) + if err != nil { + return err + } + defer func() { err = txNested.Done(err) }() + + return txNested.Exec(ctx, sqlf.Sprintf(`INSERT INTO store_counts_test VALUES (%s, 42)`, i)) + }) + } + err = g.Wait() + require.ErrorIs(t, err, ErrConcurrentTransactions) + }) + + t.Run("parallel insertions on a single transaction fails", func(t *testing.T) { + tx, err := store.Transact(ctx) + if err != nil { + t.Fatal(err) + } + + var g errgroup.Group + for i := 0; i < 100; i++ { + i := i + g.Go(func() (err error) { + return tx.Exec(ctx, sqlf.Sprintf(`INSERT INTO store_counts_test VALUES (%s, 42)`, i)) + }) + } + err = g.Wait() + require.ErrorIs(t, err, ErrConcurrentTransactions) + }) +} + func TestSavepoints(t *testing.T) { - db := dbtest.NewDB(t) + db := dbtest.NewRawDB(t) setupStoreTest(t, db) NumSavepointTests := 10 @@ -88,7 +154,7 @@ func TestSavepoints(t *testing.T) { } func TestSetLocal(t *testing.T) { - db := dbtest.NewDB(t) + db := dbtest.NewRawDB(t) setupStoreTest(t, db) store := testStore(db) From 3eb6b8b0c07f3b9d75c1ca6266e536e10a928386 Mon Sep 17 00:00:00 2001 From: Camden Cheek Date: Mon, 13 Jun 2022 16:22:56 -0600 Subject: [PATCH 2/7] update comment --- internal/database/basestore/handle.go | 14 +++++++------- internal/database/basestore/store_test.go | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/internal/database/basestore/handle.go b/internal/database/basestore/handle.go index 31a05a42a805..bd000c85d566 100644 --- a/internal/database/basestore/handle.go +++ b/internal/database/basestore/handle.go @@ -156,13 +156,13 @@ func makeSavepointID() (string, error) { return fmt.Sprintf("sp_%s", strings.ReplaceAll(id.String(), "-", "_")), nil } -var ErrConcurrentTransactions = errors.New("transaction used concurrently") +var ErrConcurrentTransactionAccess = errors.New("transaction used concurrently") -// lockingTx wraps a *sql.Tx with a mutex, and reports when a caller tries -// to use the transaction concurrently. Since using a transaction concurrently -// is unsafe, we want to catch these issues. Currently, lockingTx will just -// log an error and serialize accesses to the wrapped *sql.Tx, but in the future -// concurrent calls may be upgraded to an error. +// lockingTx wraps a *sql.Tx with a mutex, and reports when a caller tries to +// use the transaction concurrently. Since using a transaction concurrently is +// unsafe, we want to catch these issues. If lockingTx detects that a +// transaction is being used concurrently, it will return +// ErrConcurrentTransactionAccess and log an error. type lockingTx struct { tx *sql.Tx mu sync.Mutex @@ -170,7 +170,7 @@ type lockingTx struct { func (t *lockingTx) lock() error { if !t.mu.TryLock() { - err := errors.WithStack(ErrConcurrentTransactions) + err := errors.WithStack(ErrConcurrentTransactionAccess) log.Scoped("internal", "database").Error("transaction used concurrently", log.Error(err)) return err } diff --git a/internal/database/basestore/store_test.go b/internal/database/basestore/store_test.go index e6635c22a6dd..ecf224c86b1f 100644 --- a/internal/database/basestore/store_test.go +++ b/internal/database/basestore/store_test.go @@ -106,7 +106,7 @@ func TestConcurrentTransactions(t *testing.T) { }) } err = g.Wait() - require.ErrorIs(t, err, ErrConcurrentTransactions) + require.ErrorIs(t, err, ErrConcurrentTransactionAccess) }) t.Run("parallel insertions on a single transaction fails", func(t *testing.T) { @@ -123,7 +123,7 @@ func TestConcurrentTransactions(t *testing.T) { }) } err = g.Wait() - require.ErrorIs(t, err, ErrConcurrentTransactions) + require.ErrorIs(t, err, ErrConcurrentTransactionAccess) }) } From f6b62b2dab03fb6602941b0f8130c3b7e0afa844 Mon Sep 17 00:00:00 2001 From: Camden Cheek Date: Tue, 14 Jun 2022 13:30:01 -0600 Subject: [PATCH 3/7] demote error to log message and update tests --- internal/database/basestore/handle.go | 63 ++++++++++++----------- internal/database/basestore/store_test.go | 34 ++++-------- 2 files changed, 42 insertions(+), 55 deletions(-) diff --git a/internal/database/basestore/handle.go b/internal/database/basestore/handle.go index bd000c85d566..f930ab38becc 100644 --- a/internal/database/basestore/handle.go +++ b/internal/database/basestore/handle.go @@ -45,17 +45,28 @@ var ( // NewHandleWithDB returns a new transactable database handle using the given database connection. func NewHandleWithDB(db *sql.DB, txOptions sql.TxOptions) TransactableHandle { - return &dbHandle{DB: db, txOptions: txOptions} + return &dbHandle{ + DB: db, + logger: log.Scoped("internal", "database"), + txOptions: txOptions, + } } // NewHandleWithTx returns a new transactable database handle using the given transaction. func NewHandleWithTx(tx *sql.Tx, txOptions sql.TxOptions) TransactableHandle { - return &txHandle{lockingTx: &lockingTx{tx: tx}, txOptions: txOptions} + return &txHandle{ + lockingTx: &lockingTx{ + tx: tx, + logger: log.Scoped("internal", "database"), + }, + txOptions: txOptions, + } } type dbHandle struct { *sql.DB txOptions sql.TxOptions + logger log.Logger } func (h *dbHandle) InTransaction() bool { @@ -161,20 +172,27 @@ var ErrConcurrentTransactionAccess = errors.New("transaction used concurrently") // lockingTx wraps a *sql.Tx with a mutex, and reports when a caller tries to // use the transaction concurrently. Since using a transaction concurrently is // unsafe, we want to catch these issues. If lockingTx detects that a -// transaction is being used concurrently, it will return -// ErrConcurrentTransactionAccess and log an error. +// transaction is being used concurrently, it will log an error and attempt to +// serialize the transaction accesses. +// NOTE: this is not foolproof. Interleaving savepoints, accessing rows while +// sending another query, etc. will still fail, so the logged error is a +// notification that something needs fixed, not a notification that the locking +// successfully prevented an issue. In the future, this will likely be upgraded +// to a hard error. type lockingTx struct { - tx *sql.Tx - mu sync.Mutex + tx *sql.Tx + mu sync.Mutex + logger log.Logger } -func (t *lockingTx) lock() error { +func (t *lockingTx) lock() { if !t.mu.TryLock() { + // For now, log an error, but try to serialize access anyways to try to + // keep things slightly safer. err := errors.WithStack(ErrConcurrentTransactionAccess) - log.Scoped("internal", "database").Error("transaction used concurrently", log.Error(err)) - return err + t.logger.Error("transaction used concurrently", log.Error(err)) + t.mu.Lock() } - return nil } func (t *lockingTx) unlock() { @@ -182,50 +200,35 @@ func (t *lockingTx) unlock() { } func (t *lockingTx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { - if err := t.lock(); err != nil { - return nil, err - } + t.lock() defer t.unlock() return t.tx.ExecContext(ctx, query, args...) } func (t *lockingTx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { - if err := t.lock(); err != nil { - return nil, err - } + t.lock() defer t.unlock() return t.tx.QueryContext(ctx, query, args...) } func (t *lockingTx) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { - if err := t.lock(); err != nil { - // XXX(camdencheek): There is no way to construct a row that has an - // error, so in order for this to return an error, we'll need to have - // it return an interface. Instead, we attempt to acquire the lock - // to make this at least be safer if we can't report the error other - // than through logs. - t.mu.Lock() - } + t.lock() defer t.unlock() return t.tx.QueryRowContext(ctx, query, args...) } func (t *lockingTx) Commit() error { - if err := t.lock(); err != nil { - return err - } + t.lock() defer t.unlock() return t.tx.Commit() } func (t *lockingTx) Rollback() error { - if err := t.lock(); err != nil { - return err - } + t.lock() defer t.unlock() return t.tx.Rollback() diff --git a/internal/database/basestore/store_test.go b/internal/database/basestore/store_test.go index ecf224c86b1f..4f3903a22e72 100644 --- a/internal/database/basestore/store_test.go +++ b/internal/database/basestore/store_test.go @@ -8,6 +8,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/keegancsmith/sqlf" + "github.com/sourcegraph/log/logtest" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" @@ -86,34 +87,13 @@ func TestConcurrentTransactions(t *testing.T) { require.NoError(t, g.Wait()) }) - t.Run("creating concurrent savepoints on a single transaction fails", func(t *testing.T) { - tx, err := store.Transact(ctx) - if err != nil { - t.Fatal(err) - } - - var g errgroup.Group - for i := 0; i < 100; i++ { - i := i - g.Go(func() (err error) { - txNested, err := tx.Transact(ctx) - if err != nil { - return err - } - defer func() { err = txNested.Done(err) }() - - return txNested.Exec(ctx, sqlf.Sprintf(`INSERT INTO store_counts_test VALUES (%s, 42)`, i)) - }) - } - err = g.Wait() - require.ErrorIs(t, err, ErrConcurrentTransactionAccess) - }) - - t.Run("parallel insertions on a single transaction fails", func(t *testing.T) { + t.Run("parallel insertion on a single transaction does not fail but logs an error", func(t *testing.T) { tx, err := store.Transact(ctx) if err != nil { t.Fatal(err) } + capturingLogger, export := logtest.Captured(t) + tx.handle.(*txHandle).logger = capturingLogger var g errgroup.Group for i := 0; i < 100; i++ { @@ -123,7 +103,11 @@ func TestConcurrentTransactions(t *testing.T) { }) } err = g.Wait() - require.ErrorIs(t, err, ErrConcurrentTransactionAccess) + require.NoError(t, err) + + captured := export() + require.Greater(t, len(captured), 0) + require.Equal(t, "transaction used concurrently", captured[0].Message) }) } From ba7ea7a8e8163e041fa776b0036be678c7eb5885 Mon Sep 17 00:00:00 2001 From: Camden Cheek Date: Tue, 14 Jun 2022 13:34:24 -0600 Subject: [PATCH 4/7] add witty aphorism --- internal/database/basestore/handle.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/database/basestore/handle.go b/internal/database/basestore/handle.go index f930ab38becc..40276859b841 100644 --- a/internal/database/basestore/handle.go +++ b/internal/database/basestore/handle.go @@ -178,7 +178,7 @@ var ErrConcurrentTransactionAccess = errors.New("transaction used concurrently") // sending another query, etc. will still fail, so the logged error is a // notification that something needs fixed, not a notification that the locking // successfully prevented an issue. In the future, this will likely be upgraded -// to a hard error. +// to a hard error. Think of this like the race detector, not a race protector. type lockingTx struct { tx *sql.Tx mu sync.Mutex From 23b28b6e6a4a814dbc7f381acc9a29b0fbd1bc23 Mon Sep 17 00:00:00 2001 From: Camden Cheek Date: Tue, 14 Jun 2022 14:05:14 -0600 Subject: [PATCH 5/7] fix unset logger --- internal/database/basestore/handle.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/database/basestore/handle.go b/internal/database/basestore/handle.go index 40276859b841..e87290eecb60 100644 --- a/internal/database/basestore/handle.go +++ b/internal/database/basestore/handle.go @@ -78,7 +78,7 @@ func (h *dbHandle) Transact(ctx context.Context) (TransactableHandle, error) { if err != nil { return nil, err } - return &txHandle{lockingTx: &lockingTx{tx: tx}, txOptions: h.txOptions}, nil + return &txHandle{lockingTx: &lockingTx{tx: tx, logger: h.logger}, txOptions: h.txOptions}, nil } func (h *dbHandle) Done(err error) error { From 61b0a38bd76a7ecbabb4697c887889c48301901e Mon Sep 17 00:00:00 2001 From: Camden Cheek Date: Tue, 14 Jun 2022 21:08:21 -0600 Subject: [PATCH 6/7] Update internal/database/basestore/handle.go Co-authored-by: Joe Chen --- internal/database/basestore/handle.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/database/basestore/handle.go b/internal/database/basestore/handle.go index e87290eecb60..d9653bd203c5 100644 --- a/internal/database/basestore/handle.go +++ b/internal/database/basestore/handle.go @@ -174,6 +174,7 @@ var ErrConcurrentTransactionAccess = errors.New("transaction used concurrently") // unsafe, we want to catch these issues. If lockingTx detects that a // transaction is being used concurrently, it will log an error and attempt to // serialize the transaction accesses. +// // NOTE: this is not foolproof. Interleaving savepoints, accessing rows while // sending another query, etc. will still fail, so the logged error is a // notification that something needs fixed, not a notification that the locking From a5d2f3c95d9b5dad1ec402f61f01b95802b3055c Mon Sep 17 00:00:00 2001 From: Camden Cheek Date: Wed, 15 Jun 2022 08:32:57 -0600 Subject: [PATCH 7/7] use sleeps instead of 100 goroutines --- internal/database/basestore/store_test.go | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/internal/database/basestore/store_test.go b/internal/database/basestore/store_test.go index 4f3903a22e72..d14ec4bf0888 100644 --- a/internal/database/basestore/store_test.go +++ b/internal/database/basestore/store_test.go @@ -72,8 +72,7 @@ func TestConcurrentTransactions(t *testing.T) { t.Run("creating transactions concurrently does not fail", func(t *testing.T) { var g errgroup.Group - for i := 0; i < 100; i++ { - i := i + for i := 0; i < 2; i++ { g.Go(func() (err error) { tx, err := store.Transact(ctx) if err != nil { @@ -81,7 +80,7 @@ func TestConcurrentTransactions(t *testing.T) { } defer func() { err = tx.Done(err) }() - return tx.Exec(ctx, sqlf.Sprintf(`INSERT INTO store_counts_test VALUES (%s, 42)`, i)) + return tx.Exec(ctx, sqlf.Sprintf(`select pg_sleep(0.1)`)) }) } require.NoError(t, g.Wait()) @@ -96,10 +95,9 @@ func TestConcurrentTransactions(t *testing.T) { tx.handle.(*txHandle).logger = capturingLogger var g errgroup.Group - for i := 0; i < 100; i++ { - i := i + for i := 0; i < 2; i++ { g.Go(func() (err error) { - return tx.Exec(ctx, sqlf.Sprintf(`INSERT INTO store_counts_test VALUES (%s, 42)`, i)) + return tx.Exec(ctx, sqlf.Sprintf(`select pg_sleep(0.1)`)) }) } err = g.Wait()