Skip to content
This repository was archived by the owner on Sep 30, 2024. It is now read-only.

Commit d63d365

Browse files
committed
make concurrent transaction usage loud
1 parent ee784c5 commit d63d365

File tree

2 files changed

+159
-16
lines changed

2 files changed

+159
-16
lines changed

internal/database/basestore/handle.go

Lines changed: 90 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ import (
55
"database/sql"
66
"fmt"
77
"strings"
8+
"sync"
89

910
"github.com/google/uuid"
11+
"github.com/sourcegraph/log"
1012

1113
"github.com/sourcegraph/sourcegraph/internal/database/dbutil"
1214
"github.com/sourcegraph/sourcegraph/lib/errors"
@@ -48,7 +50,7 @@ func NewHandleWithDB(db *sql.DB, txOptions sql.TxOptions) TransactableHandle {
4850

4951
// NewHandleWithTx returns a new transactable database handle using the given transaction.
5052
func NewHandleWithTx(tx *sql.Tx, txOptions sql.TxOptions) TransactableHandle {
51-
return &txHandle{Tx: tx, txOptions: txOptions}
53+
return &txHandle{lockingTx: &lockingTx{tx: tx}, txOptions: txOptions}
5254
}
5355

5456
type dbHandle struct {
@@ -65,15 +67,15 @@ func (h *dbHandle) Transact(ctx context.Context) (TransactableHandle, error) {
6567
if err != nil {
6668
return nil, err
6769
}
68-
return &txHandle{Tx: tx, txOptions: h.txOptions}, nil
70+
return &txHandle{lockingTx: &lockingTx{tx: tx}, txOptions: h.txOptions}, nil
6971
}
7072

7173
func (h *dbHandle) Done(err error) error {
7274
return errors.Append(err, ErrNotInTransaction)
7375
}
7476

7577
type txHandle struct {
76-
*sql.Tx
78+
*lockingTx
7779
txOptions sql.TxOptions
7880
}
7981

@@ -82,23 +84,23 @@ func (h *txHandle) InTransaction() bool {
8284
}
8385

8486
func (h *txHandle) Transact(ctx context.Context) (TransactableHandle, error) {
85-
savepointID, err := newTxSavepoint(ctx, h.Tx)
87+
savepointID, err := newTxSavepoint(ctx, h.lockingTx)
8688
if err != nil {
8789
return nil, err
8890
}
8991

90-
return &savepointHandle{Tx: h.Tx, savepointID: savepointID}, nil
92+
return &savepointHandle{lockingTx: h.lockingTx, savepointID: savepointID}, nil
9193
}
9294

9395
func (h *txHandle) Done(err error) error {
9496
if err == nil {
95-
return h.Tx.Commit()
97+
return h.Commit()
9698
}
97-
return errors.Append(err, h.Tx.Rollback())
99+
return errors.Append(err, h.Rollback())
98100
}
99101

100102
type savepointHandle struct {
101-
*sql.Tx
103+
*lockingTx
102104
savepointID string
103105
}
104106

@@ -107,21 +109,21 @@ func (h *savepointHandle) InTransaction() bool {
107109
}
108110

109111
func (h *savepointHandle) Transact(ctx context.Context) (TransactableHandle, error) {
110-
savepointID, err := newTxSavepoint(ctx, h.Tx)
112+
savepointID, err := newTxSavepoint(ctx, h.lockingTx)
111113
if err != nil {
112114
return nil, err
113115
}
114116

115-
return &savepointHandle{Tx: h.Tx, savepointID: savepointID}, nil
117+
return &savepointHandle{lockingTx: h.lockingTx, savepointID: savepointID}, nil
116118
}
117119

118120
func (h *savepointHandle) Done(err error) error {
119121
if err == nil {
120-
_, execErr := h.Tx.Exec(fmt.Sprintf(commitSavepointQuery, h.savepointID))
122+
_, execErr := h.ExecContext(context.Background(), fmt.Sprintf(commitSavepointQuery, h.savepointID))
121123
return execErr
122124
}
123125

124-
_, execErr := h.Tx.Exec(fmt.Sprintf(rollbackSavepointQuery, h.savepointID))
126+
_, execErr := h.ExecContext(context.Background(), fmt.Sprintf(rollbackSavepointQuery, h.savepointID))
125127
return errors.Append(err, execErr)
126128
}
127129

@@ -131,7 +133,7 @@ const (
131133
rollbackSavepointQuery = "ROLLBACK TO %s"
132134
)
133135

134-
func newTxSavepoint(ctx context.Context, tx *sql.Tx) (string, error) {
136+
func newTxSavepoint(ctx context.Context, tx *lockingTx) (string, error) {
135137
savepointID, err := makeSavepointID()
136138
if err != nil {
137139
return "", err
@@ -153,3 +155,78 @@ func makeSavepointID() (string, error) {
153155

154156
return fmt.Sprintf("sp_%s", strings.ReplaceAll(id.String(), "-", "_")), nil
155157
}
158+
159+
var ErrConcurrentTransactions = errors.New("transaction used concurrently")
160+
161+
// lockingTx wraps a *sql.Tx with a mutex, and reports when a caller tries
162+
// to use the transaction concurrently. Since using a transaction concurrently
163+
// is unsafe, we want to catch these issues. Currently, lockingTx will just
164+
// log an error and serialize accesses to the wrapped *sql.Tx, but in the future
165+
// concurrent calls may be upgraded to an error.
166+
type lockingTx struct {
167+
tx *sql.Tx
168+
mu sync.Mutex
169+
}
170+
171+
func (t *lockingTx) lock() error {
172+
if !t.mu.TryLock() {
173+
err := errors.WithStack(ErrConcurrentTransactions)
174+
log.Scoped("internal", "database").Error("transaction used concurrently", log.Error(err))
175+
return err
176+
}
177+
return nil
178+
}
179+
180+
func (t *lockingTx) unlock() {
181+
t.mu.Unlock()
182+
}
183+
184+
func (t *lockingTx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
185+
if err := t.lock(); err != nil {
186+
return nil, err
187+
}
188+
defer t.unlock()
189+
190+
return t.tx.ExecContext(ctx, query, args...)
191+
}
192+
193+
func (t *lockingTx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
194+
if err := t.lock(); err != nil {
195+
return nil, err
196+
}
197+
defer t.unlock()
198+
199+
return t.tx.QueryContext(ctx, query, args...)
200+
}
201+
202+
func (t *lockingTx) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
203+
if err := t.lock(); err != nil {
204+
// XXX(camdencheek): There is no way to construct a row that has an
205+
// error, so in order for this to return an error, we'll need to have
206+
// it return an interface. Instead, we attempt to acquire the lock
207+
// to make this at least be safer if we can't report the error other
208+
// than through logs.
209+
t.mu.Lock()
210+
}
211+
defer t.unlock()
212+
213+
return t.tx.QueryRowContext(ctx, query, args...)
214+
}
215+
216+
func (t *lockingTx) Commit() error {
217+
if err := t.lock(); err != nil {
218+
return err
219+
}
220+
defer t.unlock()
221+
222+
return t.tx.Commit()
223+
}
224+
225+
func (t *lockingTx) Rollback() error {
226+
if err := t.lock(); err != nil {
227+
return err
228+
}
229+
defer t.unlock()
230+
231+
return t.tx.Rollback()
232+
}

internal/database/basestore/store_test.go

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@ import (
88

99
"github.com/google/go-cmp/cmp"
1010
"github.com/keegancsmith/sqlf"
11+
"github.com/stretchr/testify/require"
12+
"golang.org/x/sync/errgroup"
1113

1214
"github.com/sourcegraph/sourcegraph/internal/database/dbtest"
1315
"github.com/sourcegraph/sourcegraph/internal/database/dbutil"
1416
"github.com/sourcegraph/sourcegraph/lib/errors"
1517
)
1618

1719
func TestTransaction(t *testing.T) {
18-
db := dbtest.NewDB(t)
20+
db := dbtest.NewRawDB(t)
1921
setupStoreTest(t, db)
2022
store := testStore(db)
2123

@@ -61,8 +63,72 @@ func TestTransaction(t *testing.T) {
6163
assertCounts(t, db, map[int]int{1: 42, 3: 44})
6264
}
6365

66+
func TestConcurrentTransactions(t *testing.T) {
67+
db := dbtest.NewRawDB(t)
68+
setupStoreTest(t, db)
69+
store := testStore(db)
70+
ctx := context.Background()
71+
72+
t.Run("creating transactions concurrently does not fail", func(t *testing.T) {
73+
var g errgroup.Group
74+
for i := 0; i < 100; i++ {
75+
i := i
76+
g.Go(func() (err error) {
77+
tx, err := store.Transact(ctx)
78+
if err != nil {
79+
return err
80+
}
81+
defer func() { err = tx.Done(err) }()
82+
83+
return tx.Exec(ctx, sqlf.Sprintf(`INSERT INTO store_counts_test VALUES (%s, 42)`, i))
84+
})
85+
}
86+
require.NoError(t, g.Wait())
87+
})
88+
89+
t.Run("creating concurrent savepoints on a single transaction fails", func(t *testing.T) {
90+
tx, err := store.Transact(ctx)
91+
if err != nil {
92+
t.Fatal(err)
93+
}
94+
95+
var g errgroup.Group
96+
for i := 0; i < 100; i++ {
97+
i := i
98+
g.Go(func() (err error) {
99+
txNested, err := tx.Transact(ctx)
100+
if err != nil {
101+
return err
102+
}
103+
defer func() { err = txNested.Done(err) }()
104+
105+
return txNested.Exec(ctx, sqlf.Sprintf(`INSERT INTO store_counts_test VALUES (%s, 42)`, i))
106+
})
107+
}
108+
err = g.Wait()
109+
require.ErrorIs(t, err, ErrConcurrentTransactions)
110+
})
111+
112+
t.Run("parallel insertions on a single transaction fails", func(t *testing.T) {
113+
tx, err := store.Transact(ctx)
114+
if err != nil {
115+
t.Fatal(err)
116+
}
117+
118+
var g errgroup.Group
119+
for i := 0; i < 100; i++ {
120+
i := i
121+
g.Go(func() (err error) {
122+
return tx.Exec(ctx, sqlf.Sprintf(`INSERT INTO store_counts_test VALUES (%s, 42)`, i))
123+
})
124+
}
125+
err = g.Wait()
126+
require.ErrorIs(t, err, ErrConcurrentTransactions)
127+
})
128+
}
129+
64130
func TestSavepoints(t *testing.T) {
65-
db := dbtest.NewDB(t)
131+
db := dbtest.NewRawDB(t)
66132
setupStoreTest(t, db)
67133

68134
NumSavepointTests := 10
@@ -88,7 +154,7 @@ func TestSavepoints(t *testing.T) {
88154
}
89155

90156
func TestSetLocal(t *testing.T) {
91-
db := dbtest.NewDB(t)
157+
db := dbtest.NewRawDB(t)
92158
setupStoreTest(t, db)
93159
store := testStore(db)
94160

0 commit comments

Comments
 (0)