Skip to content
This repository was archived by the owner on Sep 30, 2024. It is now read-only.

Commit 291a3df

Browse files
authored
DB Backend: report explicit error when transactions are used concurrently (#37172)
This updates our transaction wrapper to report an explicit error whenever a transaction is used concurrently. Concurrent transaction access is a form of race condition that causes (in my experiments) either a conn busy error, a bad connection error, or a panic. So, instead of getting these errors that are very difficult to debug, this logs an error with a stack trace that actually describes what's wrong. These errors should get pushed to Sentry, which will allow us to track where this is happening.
1 parent cff846b commit 291a3df

File tree

2 files changed

+146
-17
lines changed

2 files changed

+146
-17
lines changed

internal/database/basestore/handle.go

Lines changed: 95 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ import (
55
"database/sql"
66
"fmt"
77
"strings"
8+
"sync"
89

910
"github.com/google/uuid"
11+
"github.com/sourcegraph/log"
1012

1113
"github.com/sourcegraph/sourcegraph/internal/database/dbutil"
1214
"github.com/sourcegraph/sourcegraph/lib/errors"
@@ -43,17 +45,28 @@ var (
4345

4446
// NewHandleWithDB returns a new transactable database handle using the given database connection.
4547
func NewHandleWithDB(db *sql.DB, txOptions sql.TxOptions) TransactableHandle {
46-
return &dbHandle{DB: db, txOptions: txOptions}
48+
return &dbHandle{
49+
DB: db,
50+
logger: log.Scoped("internal", "database"),
51+
txOptions: txOptions,
52+
}
4753
}
4854

4955
// NewHandleWithTx returns a new transactable database handle using the given transaction.
5056
func NewHandleWithTx(tx *sql.Tx, txOptions sql.TxOptions) TransactableHandle {
51-
return &txHandle{Tx: tx, txOptions: txOptions}
57+
return &txHandle{
58+
lockingTx: &lockingTx{
59+
tx: tx,
60+
logger: log.Scoped("internal", "database"),
61+
},
62+
txOptions: txOptions,
63+
}
5264
}
5365

5466
type dbHandle struct {
5567
*sql.DB
5668
txOptions sql.TxOptions
69+
logger log.Logger
5770
}
5871

5972
func (h *dbHandle) InTransaction() bool {
@@ -65,15 +78,15 @@ func (h *dbHandle) Transact(ctx context.Context) (TransactableHandle, error) {
6578
if err != nil {
6679
return nil, err
6780
}
68-
return &txHandle{Tx: tx, txOptions: h.txOptions}, nil
81+
return &txHandle{lockingTx: &lockingTx{tx: tx, logger: h.logger}, txOptions: h.txOptions}, nil
6982
}
7083

7184
func (h *dbHandle) Done(err error) error {
7285
return errors.Append(err, ErrNotInTransaction)
7386
}
7487

7588
type txHandle struct {
76-
*sql.Tx
89+
*lockingTx
7790
txOptions sql.TxOptions
7891
}
7992

@@ -82,23 +95,23 @@ func (h *txHandle) InTransaction() bool {
8295
}
8396

8497
func (h *txHandle) Transact(ctx context.Context) (TransactableHandle, error) {
85-
savepointID, err := newTxSavepoint(ctx, h.Tx)
98+
savepointID, err := newTxSavepoint(ctx, h.lockingTx)
8699
if err != nil {
87100
return nil, err
88101
}
89102

90-
return &savepointHandle{Tx: h.Tx, savepointID: savepointID}, nil
103+
return &savepointHandle{lockingTx: h.lockingTx, savepointID: savepointID}, nil
91104
}
92105

93106
func (h *txHandle) Done(err error) error {
94107
if err == nil {
95-
return h.Tx.Commit()
108+
return h.Commit()
96109
}
97-
return errors.Append(err, h.Tx.Rollback())
110+
return errors.Append(err, h.Rollback())
98111
}
99112

100113
type savepointHandle struct {
101-
*sql.Tx
114+
*lockingTx
102115
savepointID string
103116
}
104117

@@ -107,21 +120,21 @@ func (h *savepointHandle) InTransaction() bool {
107120
}
108121

109122
func (h *savepointHandle) Transact(ctx context.Context) (TransactableHandle, error) {
110-
savepointID, err := newTxSavepoint(ctx, h.Tx)
123+
savepointID, err := newTxSavepoint(ctx, h.lockingTx)
111124
if err != nil {
112125
return nil, err
113126
}
114127

115-
return &savepointHandle{Tx: h.Tx, savepointID: savepointID}, nil
128+
return &savepointHandle{lockingTx: h.lockingTx, savepointID: savepointID}, nil
116129
}
117130

118131
func (h *savepointHandle) Done(err error) error {
119132
if err == nil {
120-
_, execErr := h.Tx.Exec(fmt.Sprintf(commitSavepointQuery, h.savepointID))
133+
_, execErr := h.ExecContext(context.Background(), fmt.Sprintf(commitSavepointQuery, h.savepointID))
121134
return execErr
122135
}
123136

124-
_, execErr := h.Tx.Exec(fmt.Sprintf(rollbackSavepointQuery, h.savepointID))
137+
_, execErr := h.ExecContext(context.Background(), fmt.Sprintf(rollbackSavepointQuery, h.savepointID))
125138
return errors.Append(err, execErr)
126139
}
127140

@@ -131,7 +144,7 @@ const (
131144
rollbackSavepointQuery = "ROLLBACK TO %s"
132145
)
133146

134-
func newTxSavepoint(ctx context.Context, tx *sql.Tx) (string, error) {
147+
func newTxSavepoint(ctx context.Context, tx *lockingTx) (string, error) {
135148
savepointID, err := makeSavepointID()
136149
if err != nil {
137150
return "", err
@@ -153,3 +166,71 @@ func makeSavepointID() (string, error) {
153166

154167
return fmt.Sprintf("sp_%s", strings.ReplaceAll(id.String(), "-", "_")), nil
155168
}
169+
170+
var ErrConcurrentTransactionAccess = errors.New("transaction used concurrently")
171+
172+
// lockingTx wraps a *sql.Tx with a mutex, and reports when a caller tries to
173+
// use the transaction concurrently. Since using a transaction concurrently is
174+
// unsafe, we want to catch these issues. If lockingTx detects that a
175+
// transaction is being used concurrently, it will log an error and attempt to
176+
// serialize the transaction accesses.
177+
//
178+
// NOTE: this is not foolproof. Interleaving savepoints, accessing rows while
179+
// sending another query, etc. will still fail, so the logged error is a
180+
// notification that something needs to be fixed, not a notification that the locking
181+
// successfully prevented an issue. In the future, this will likely be upgraded
182+
// to a hard error. Think of this like the race detector, not a race protector.
183+
type lockingTx struct {
184+
tx *sql.Tx
185+
mu sync.Mutex
186+
logger log.Logger
187+
}
188+
189+
func (t *lockingTx) lock() {
190+
if !t.mu.TryLock() {
191+
// For now, log an error, but try to serialize access anyways to try to
192+
// keep things slightly safer.
193+
err := errors.WithStack(ErrConcurrentTransactionAccess)
194+
t.logger.Error("transaction used concurrently", log.Error(err))
195+
t.mu.Lock()
196+
}
197+
}
198+
199+
func (t *lockingTx) unlock() {
200+
t.mu.Unlock()
201+
}
202+
203+
func (t *lockingTx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
204+
t.lock()
205+
defer t.unlock()
206+
207+
return t.tx.ExecContext(ctx, query, args...)
208+
}
209+
210+
func (t *lockingTx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
211+
t.lock()
212+
defer t.unlock()
213+
214+
return t.tx.QueryContext(ctx, query, args...)
215+
}
216+
217+
func (t *lockingTx) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
218+
t.lock()
219+
defer t.unlock()
220+
221+
return t.tx.QueryRowContext(ctx, query, args...)
222+
}
223+
224+
func (t *lockingTx) Commit() error {
225+
t.lock()
226+
defer t.unlock()
227+
228+
return t.tx.Commit()
229+
}
230+
231+
func (t *lockingTx) Rollback() error {
232+
t.lock()
233+
defer t.unlock()
234+
235+
return t.tx.Rollback()
236+
}

internal/database/basestore/store_test.go

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,17 @@ import (
88

99
"github.com/google/go-cmp/cmp"
1010
"github.com/keegancsmith/sqlf"
11+
"github.com/sourcegraph/log/logtest"
12+
"github.com/stretchr/testify/require"
13+
"golang.org/x/sync/errgroup"
1114

1215
"github.com/sourcegraph/sourcegraph/internal/database/dbtest"
1316
"github.com/sourcegraph/sourcegraph/internal/database/dbutil"
1417
"github.com/sourcegraph/sourcegraph/lib/errors"
1518
)
1619

1720
func TestTransaction(t *testing.T) {
18-
db := dbtest.NewDB(t)
21+
db := dbtest.NewRawDB(t)
1922
setupStoreTest(t, db)
2023
store := testStore(db)
2124

@@ -61,8 +64,53 @@ func TestTransaction(t *testing.T) {
6164
assertCounts(t, db, map[int]int{1: 42, 3: 44})
6265
}
6366

67+
func TestConcurrentTransactions(t *testing.T) {
68+
db := dbtest.NewRawDB(t)
69+
setupStoreTest(t, db)
70+
store := testStore(db)
71+
ctx := context.Background()
72+
73+
t.Run("creating transactions concurrently does not fail", func(t *testing.T) {
74+
var g errgroup.Group
75+
for i := 0; i < 2; i++ {
76+
g.Go(func() (err error) {
77+
tx, err := store.Transact(ctx)
78+
if err != nil {
79+
return err
80+
}
81+
defer func() { err = tx.Done(err) }()
82+
83+
return tx.Exec(ctx, sqlf.Sprintf(`select pg_sleep(0.1)`))
84+
})
85+
}
86+
require.NoError(t, g.Wait())
87+
})
88+
89+
t.Run("parallel insertion on a single transaction does not fail but logs an error", func(t *testing.T) {
90+
tx, err := store.Transact(ctx)
91+
if err != nil {
92+
t.Fatal(err)
93+
}
94+
capturingLogger, export := logtest.Captured(t)
95+
tx.handle.(*txHandle).logger = capturingLogger
96+
97+
var g errgroup.Group
98+
for i := 0; i < 2; i++ {
99+
g.Go(func() (err error) {
100+
return tx.Exec(ctx, sqlf.Sprintf(`select pg_sleep(0.1)`))
101+
})
102+
}
103+
err = g.Wait()
104+
require.NoError(t, err)
105+
106+
captured := export()
107+
require.Greater(t, len(captured), 0)
108+
require.Equal(t, "transaction used concurrently", captured[0].Message)
109+
})
110+
}
111+
64112
func TestSavepoints(t *testing.T) {
65-
db := dbtest.NewDB(t)
113+
db := dbtest.NewRawDB(t)
66114
setupStoreTest(t, db)
67115

68116
NumSavepointTests := 10
@@ -88,7 +136,7 @@ func TestSavepoints(t *testing.T) {
88136
}
89137

90138
func TestSetLocal(t *testing.T) {
91-
db := dbtest.NewDB(t)
139+
db := dbtest.NewRawDB(t)
92140
setupStoreTest(t, db)
93141
store := testStore(db)
94142

0 commit comments

Comments
 (0)