diff --git a/internal/database/basestore/handle.go b/internal/database/basestore/handle.go index bcc7623c9779..d9653bd203c5 100644 --- a/internal/database/basestore/handle.go +++ b/internal/database/basestore/handle.go @@ -5,8 +5,10 @@ import ( "database/sql" "fmt" "strings" + "sync" "github.com/google/uuid" + "github.com/sourcegraph/log" "github.com/sourcegraph/sourcegraph/internal/database/dbutil" "github.com/sourcegraph/sourcegraph/lib/errors" @@ -43,17 +45,28 @@ var ( // NewHandleWithDB returns a new transactable database handle using the given database connection. func NewHandleWithDB(db *sql.DB, txOptions sql.TxOptions) TransactableHandle { - return &dbHandle{DB: db, txOptions: txOptions} + return &dbHandle{ + DB: db, + logger: log.Scoped("internal", "database"), + txOptions: txOptions, + } } // NewHandleWithTx returns a new transactable database handle using the given transaction. func NewHandleWithTx(tx *sql.Tx, txOptions sql.TxOptions) TransactableHandle { - return &txHandle{Tx: tx, txOptions: txOptions} + return &txHandle{ + lockingTx: &lockingTx{ + tx: tx, + logger: log.Scoped("internal", "database"), + }, + txOptions: txOptions, + } } type dbHandle struct { *sql.DB txOptions sql.TxOptions + logger log.Logger } func (h *dbHandle) InTransaction() bool { @@ -65,7 +78,7 @@ func (h *dbHandle) Transact(ctx context.Context) (TransactableHandle, error) { if err != nil { return nil, err } - return &txHandle{Tx: tx, txOptions: h.txOptions}, nil + return &txHandle{lockingTx: &lockingTx{tx: tx, logger: h.logger}, txOptions: h.txOptions}, nil } func (h *dbHandle) Done(err error) error { @@ -73,7 +86,7 @@ func (h *dbHandle) Done(err error) error { } type txHandle struct { - *sql.Tx + *lockingTx txOptions sql.TxOptions } @@ -82,23 +95,23 @@ func (h *txHandle) InTransaction() bool { } func (h *txHandle) Transact(ctx context.Context) (TransactableHandle, error) { - savepointID, err := newTxSavepoint(ctx, h.Tx) + savepointID, err := newTxSavepoint(ctx, h.lockingTx) if err != nil { return nil, err } - return &savepointHandle{Tx: h.Tx, savepointID: savepointID}, nil + return &savepointHandle{lockingTx: h.lockingTx, savepointID: savepointID}, nil } func (h *txHandle) Done(err error) error { if err == nil { - return h.Tx.Commit() + return h.Commit() } - return errors.Append(err, h.Tx.Rollback()) + return errors.Append(err, h.Rollback()) } type savepointHandle struct { - *sql.Tx + *lockingTx savepointID string } @@ -107,21 +120,21 @@ func (h *savepointHandle) InTransaction() bool { } func (h *savepointHandle) Transact(ctx context.Context) (TransactableHandle, error) { - savepointID, err := newTxSavepoint(ctx, h.Tx) + savepointID, err := newTxSavepoint(ctx, h.lockingTx) if err != nil { return nil, err } - return &savepointHandle{Tx: h.Tx, savepointID: savepointID}, nil + return &savepointHandle{lockingTx: h.lockingTx, savepointID: savepointID}, nil } func (h *savepointHandle) Done(err error) error { if err == nil { - _, execErr := h.Tx.Exec(fmt.Sprintf(commitSavepointQuery, h.savepointID)) + _, execErr := h.ExecContext(context.Background(), fmt.Sprintf(commitSavepointQuery, h.savepointID)) return execErr } - _, execErr := h.Tx.Exec(fmt.Sprintf(rollbackSavepointQuery, h.savepointID)) + _, execErr := h.ExecContext(context.Background(), fmt.Sprintf(rollbackSavepointQuery, h.savepointID)) return errors.Append(err, execErr) } @@ -131,7 +144,7 @@ const ( rollbackSavepointQuery = "ROLLBACK TO %s" ) -func newTxSavepoint(ctx context.Context, tx *sql.Tx) (string, error) { +func newTxSavepoint(ctx context.Context, tx *lockingTx) (string, error) { savepointID, err := makeSavepointID() if err != nil { return "", err @@ -153,3 +166,71 @@ func makeSavepointID() (string, error) { return fmt.Sprintf("sp_%s", strings.ReplaceAll(id.String(), "-", "_")), nil } + +var ErrConcurrentTransactionAccess = errors.New("transaction used concurrently") + +// lockingTx wraps a *sql.Tx with a mutex, and reports when a caller tries to +// use the transaction concurrently. Since using a transaction concurrently is +// unsafe, we want to catch these issues. If lockingTx detects that a +// transaction is being used concurrently, it will log an error and attempt to +// serialize the transaction accesses. +// +// NOTE: this is not foolproof. Interleaving savepoints, accessing rows while +// sending another query, etc. will still fail, so the logged error is a +// notification that something needs fixed, not a notification that the locking +// successfully prevented an issue. In the future, this will likely be upgraded +// to a hard error. Think of this like the race detector, not a race protector. +type lockingTx struct { + tx *sql.Tx + mu sync.Mutex + logger log.Logger +} + +func (t *lockingTx) lock() { + if !t.mu.TryLock() { + // For now, log an error, but try to serialize access anyways to try to + // keep things slightly safer. + err := errors.WithStack(ErrConcurrentTransactionAccess) + t.logger.Error("transaction used concurrently", log.Error(err)) + t.mu.Lock() + } +} + +func (t *lockingTx) unlock() { + t.mu.Unlock() +} + +func (t *lockingTx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + t.lock() + defer t.unlock() + + return t.tx.ExecContext(ctx, query, args...) +} + +func (t *lockingTx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + t.lock() + defer t.unlock() + + return t.tx.QueryContext(ctx, query, args...) +} + +func (t *lockingTx) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { + t.lock() + defer t.unlock() + + return t.tx.QueryRowContext(ctx, query, args...) +} + +func (t *lockingTx) Commit() error { + t.lock() + defer t.unlock() + + return t.tx.Commit() +} + +func (t *lockingTx) Rollback() error { + t.lock() + defer t.unlock() + + return t.tx.Rollback() +} diff --git a/internal/database/basestore/store_test.go b/internal/database/basestore/store_test.go index 719ec3941e94..d14ec4bf0888 100644 --- a/internal/database/basestore/store_test.go +++ b/internal/database/basestore/store_test.go @@ -8,6 +8,9 @@ import ( "github.com/google/go-cmp/cmp" "github.com/keegancsmith/sqlf" + "github.com/sourcegraph/log/logtest" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" "github.com/sourcegraph/sourcegraph/internal/database/dbtest" "github.com/sourcegraph/sourcegraph/internal/database/dbutil" @@ -15,7 +18,7 @@ import ( ) func TestTransaction(t *testing.T) { - db := dbtest.NewDB(t) + db := dbtest.NewRawDB(t) setupStoreTest(t, db) store := testStore(db) @@ -61,8 +64,53 @@ func TestTransaction(t *testing.T) { assertCounts(t, db, map[int]int{1: 42, 3: 44}) } +func TestConcurrentTransactions(t *testing.T) { + db := dbtest.NewRawDB(t) + setupStoreTest(t, db) + store := testStore(db) + ctx := context.Background() + + t.Run("creating transactions concurrently does not fail", func(t *testing.T) { + var g errgroup.Group + for i := 0; i < 2; i++ { + g.Go(func() (err error) { + tx, err := store.Transact(ctx) + if err != nil { + return err + } + defer func() { err = tx.Done(err) }() + + return tx.Exec(ctx, sqlf.Sprintf(`select pg_sleep(0.1)`)) + }) + } + require.NoError(t, g.Wait()) + }) + + t.Run("parallel insertion on a single transaction does not fail but logs an error", func(t *testing.T) { + tx, err := store.Transact(ctx) + if err != nil { + t.Fatal(err) + } + capturingLogger, export := logtest.Captured(t) + tx.handle.(*txHandle).logger = capturingLogger + + var g errgroup.Group + for i := 0; i < 2; i++ { + g.Go(func() (err error) { + return tx.Exec(ctx, sqlf.Sprintf(`select pg_sleep(0.1)`)) + }) + } + err = g.Wait() + require.NoError(t, err) + + captured := export() + require.Greater(t, len(captured), 0) + require.Equal(t, "transaction used concurrently", captured[0].Message) + }) +} + func TestSavepoints(t *testing.T) { - db := dbtest.NewDB(t) + db := dbtest.NewRawDB(t) setupStoreTest(t, db) NumSavepointTests := 10 @@ -88,7 +136,7 @@ func TestSavepoints(t *testing.T) { } func TestSetLocal(t *testing.T) { - db := dbtest.NewDB(t) + db := dbtest.NewRawDB(t) setupStoreTest(t, db) store := testStore(db)