@@ -5,8 +5,10 @@ import (
5
5
"database/sql"
6
6
"fmt"
7
7
"strings"
8
+ "sync"
8
9
9
10
"github.com/google/uuid"
11
+ "github.com/sourcegraph/log"
10
12
11
13
"github.com/sourcegraph/sourcegraph/internal/database/dbutil"
12
14
"github.com/sourcegraph/sourcegraph/lib/errors"
@@ -43,17 +45,28 @@ var (
43
45
44
46
// NewHandleWithDB returns a new transactable database handle using the given database connection.
45
47
func NewHandleWithDB (db * sql.DB , txOptions sql.TxOptions ) TransactableHandle {
46
- return & dbHandle {DB : db , txOptions : txOptions }
48
+ return & dbHandle {
49
+ DB : db ,
50
+ logger : log .Scoped ("internal" , "database" ),
51
+ txOptions : txOptions ,
52
+ }
47
53
}
48
54
49
55
// NewHandleWithTx returns a new transactable database handle using the given transaction.
50
56
func NewHandleWithTx (tx * sql.Tx , txOptions sql.TxOptions ) TransactableHandle {
51
- return & txHandle {Tx : tx , txOptions : txOptions }
57
+ return & txHandle {
58
+ lockingTx : & lockingTx {
59
+ tx : tx ,
60
+ logger : log .Scoped ("internal" , "database" ),
61
+ },
62
+ txOptions : txOptions ,
63
+ }
52
64
}
53
65
54
66
type dbHandle struct {
55
67
* sql.DB
56
68
txOptions sql.TxOptions
69
+ logger log.Logger
57
70
}
58
71
59
72
func (h * dbHandle ) InTransaction () bool {
@@ -65,15 +78,15 @@ func (h *dbHandle) Transact(ctx context.Context) (TransactableHandle, error) {
65
78
if err != nil {
66
79
return nil , err
67
80
}
68
- return & txHandle {Tx : tx , txOptions : h .txOptions }, nil
81
+ return & txHandle {lockingTx : & lockingTx { tx : tx , logger : h . logger } , txOptions : h .txOptions }, nil
69
82
}
70
83
71
84
func (h * dbHandle ) Done (err error ) error {
72
85
return errors .Append (err , ErrNotInTransaction )
73
86
}
74
87
75
88
type txHandle struct {
76
- * sql. Tx
89
+ * lockingTx
77
90
txOptions sql.TxOptions
78
91
}
79
92
@@ -82,23 +95,23 @@ func (h *txHandle) InTransaction() bool {
82
95
}
83
96
84
97
func (h * txHandle ) Transact (ctx context.Context ) (TransactableHandle , error ) {
85
- savepointID , err := newTxSavepoint (ctx , h .Tx )
98
+ savepointID , err := newTxSavepoint (ctx , h .lockingTx )
86
99
if err != nil {
87
100
return nil , err
88
101
}
89
102
90
- return & savepointHandle {Tx : h .Tx , savepointID : savepointID }, nil
103
+ return & savepointHandle {lockingTx : h .lockingTx , savepointID : savepointID }, nil
91
104
}
92
105
93
106
func (h * txHandle ) Done (err error ) error {
94
107
if err == nil {
95
- return h .Tx . Commit ()
108
+ return h .Commit ()
96
109
}
97
- return errors .Append (err , h .Tx . Rollback ())
110
+ return errors .Append (err , h .Rollback ())
98
111
}
99
112
100
113
type savepointHandle struct {
101
- * sql. Tx
114
+ * lockingTx
102
115
savepointID string
103
116
}
104
117
@@ -107,21 +120,21 @@ func (h *savepointHandle) InTransaction() bool {
107
120
}
108
121
109
122
func (h * savepointHandle ) Transact (ctx context.Context ) (TransactableHandle , error ) {
110
- savepointID , err := newTxSavepoint (ctx , h .Tx )
123
+ savepointID , err := newTxSavepoint (ctx , h .lockingTx )
111
124
if err != nil {
112
125
return nil , err
113
126
}
114
127
115
- return & savepointHandle {Tx : h .Tx , savepointID : savepointID }, nil
128
+ return & savepointHandle {lockingTx : h .lockingTx , savepointID : savepointID }, nil
116
129
}
117
130
118
131
func (h * savepointHandle ) Done (err error ) error {
119
132
if err == nil {
120
- _ , execErr := h .Tx . Exec ( fmt .Sprintf (commitSavepointQuery , h .savepointID ))
133
+ _ , execErr := h .ExecContext ( context . Background (), fmt .Sprintf (commitSavepointQuery , h .savepointID ))
121
134
return execErr
122
135
}
123
136
124
- _ , execErr := h .Tx . Exec ( fmt .Sprintf (rollbackSavepointQuery , h .savepointID ))
137
+ _ , execErr := h .ExecContext ( context . Background (), fmt .Sprintf (rollbackSavepointQuery , h .savepointID ))
125
138
return errors .Append (err , execErr )
126
139
}
127
140
@@ -131,7 +144,7 @@ const (
131
144
rollbackSavepointQuery = "ROLLBACK TO %s"
132
145
)
133
146
134
- func newTxSavepoint (ctx context.Context , tx * sql. Tx ) (string , error ) {
147
+ func newTxSavepoint (ctx context.Context , tx * lockingTx ) (string , error ) {
135
148
savepointID , err := makeSavepointID ()
136
149
if err != nil {
137
150
return "" , err
@@ -153,3 +166,71 @@ func makeSavepointID() (string, error) {
153
166
154
167
return fmt .Sprintf ("sp_%s" , strings .ReplaceAll (id .String (), "-" , "_" )), nil
155
168
}
169
+
170
+ var ErrConcurrentTransactionAccess = errors .New ("transaction used concurrently" )
171
+
172
+ // lockingTx wraps a *sql.Tx with a mutex, and reports when a caller tries to
173
+ // use the transaction concurrently. Since using a transaction concurrently is
174
+ // unsafe, we want to catch these issues. If lockingTx detects that a
175
+ // transaction is being used concurrently, it will log an error and attempt to
176
+ // serialize the transaction accesses.
177
+ //
178
+ // NOTE: this is not foolproof. Interleaving savepoints, accessing rows while
179
+ // sending another query, etc. will still fail, so the logged error is a
180
+ // notification that something needs fixed, not a notification that the locking
181
+ // successfully prevented an issue. In the future, this will likely be upgraded
182
+ // to a hard error. Think of this like the race detector, not a race protector.
183
+ type lockingTx struct {
184
+ tx * sql.Tx
185
+ mu sync.Mutex
186
+ logger log.Logger
187
+ }
188
+
189
+ func (t * lockingTx ) lock () {
190
+ if ! t .mu .TryLock () {
191
+ // For now, log an error, but try to serialize access anyways to try to
192
+ // keep things slightly safer.
193
+ err := errors .WithStack (ErrConcurrentTransactionAccess )
194
+ t .logger .Error ("transaction used concurrently" , log .Error (err ))
195
+ t .mu .Lock ()
196
+ }
197
+ }
198
+
199
+ func (t * lockingTx ) unlock () {
200
+ t .mu .Unlock ()
201
+ }
202
+
203
+ func (t * lockingTx ) ExecContext (ctx context.Context , query string , args ... any ) (sql.Result , error ) {
204
+ t .lock ()
205
+ defer t .unlock ()
206
+
207
+ return t .tx .ExecContext (ctx , query , args ... )
208
+ }
209
+
210
+ func (t * lockingTx ) QueryContext (ctx context.Context , query string , args ... any ) (* sql.Rows , error ) {
211
+ t .lock ()
212
+ defer t .unlock ()
213
+
214
+ return t .tx .QueryContext (ctx , query , args ... )
215
+ }
216
+
217
+ func (t * lockingTx ) QueryRowContext (ctx context.Context , query string , args ... any ) * sql.Row {
218
+ t .lock ()
219
+ defer t .unlock ()
220
+
221
+ return t .tx .QueryRowContext (ctx , query , args ... )
222
+ }
223
+
224
+ func (t * lockingTx ) Commit () error {
225
+ t .lock ()
226
+ defer t .unlock ()
227
+
228
+ return t .tx .Commit ()
229
+ }
230
+
231
+ func (t * lockingTx ) Rollback () error {
232
+ t .lock ()
233
+ defer t .unlock ()
234
+
235
+ return t .tx .Rollback ()
236
+ }
0 commit comments