@@ -5,8 +5,10 @@ import (
5
5
"database/sql"
6
6
"fmt"
7
7
"strings"
8
+ "sync"
8
9
9
10
"github.com/google/uuid"
11
+ "github.com/sourcegraph/log"
10
12
11
13
"github.com/sourcegraph/sourcegraph/internal/database/dbutil"
12
14
"github.com/sourcegraph/sourcegraph/lib/errors"
@@ -48,7 +50,7 @@ func NewHandleWithDB(db *sql.DB, txOptions sql.TxOptions) TransactableHandle {
48
50
49
51
// NewHandleWithTx returns a new transactable database handle using the given transaction.
50
52
func NewHandleWithTx (tx * sql.Tx , txOptions sql.TxOptions ) TransactableHandle {
51
- return & txHandle {Tx : tx , txOptions : txOptions }
53
+ return & txHandle {lockingTx : & lockingTx { tx : tx } , txOptions : txOptions }
52
54
}
53
55
54
56
type dbHandle struct {
@@ -65,15 +67,15 @@ func (h *dbHandle) Transact(ctx context.Context) (TransactableHandle, error) {
65
67
if err != nil {
66
68
return nil , err
67
69
}
68
- return & txHandle {Tx : tx , txOptions : h .txOptions }, nil
70
+ return & txHandle {lockingTx : & lockingTx { tx : tx } , txOptions : h .txOptions }, nil
69
71
}
70
72
71
73
func (h * dbHandle ) Done (err error ) error {
72
74
return errors .Append (err , ErrNotInTransaction )
73
75
}
74
76
75
77
type txHandle struct {
76
- * sql. Tx
78
+ * lockingTx
77
79
txOptions sql.TxOptions
78
80
}
79
81
@@ -82,23 +84,23 @@ func (h *txHandle) InTransaction() bool {
82
84
}
83
85
84
86
func (h * txHandle ) Transact (ctx context.Context ) (TransactableHandle , error ) {
85
- savepointID , err := newTxSavepoint (ctx , h .Tx )
87
+ savepointID , err := newTxSavepoint (ctx , h .lockingTx )
86
88
if err != nil {
87
89
return nil , err
88
90
}
89
91
90
- return & savepointHandle {Tx : h .Tx , savepointID : savepointID }, nil
92
+ return & savepointHandle {lockingTx : h .lockingTx , savepointID : savepointID }, nil
91
93
}
92
94
93
95
func (h * txHandle ) Done (err error ) error {
94
96
if err == nil {
95
- return h .Tx . Commit ()
97
+ return h .Commit ()
96
98
}
97
- return errors .Append (err , h .Tx . Rollback ())
99
+ return errors .Append (err , h .Rollback ())
98
100
}
99
101
100
102
type savepointHandle struct {
101
- * sql. Tx
103
+ * lockingTx
102
104
savepointID string
103
105
}
104
106
@@ -107,21 +109,21 @@ func (h *savepointHandle) InTransaction() bool {
107
109
}
108
110
109
111
func (h * savepointHandle ) Transact (ctx context.Context ) (TransactableHandle , error ) {
110
- savepointID , err := newTxSavepoint (ctx , h .Tx )
112
+ savepointID , err := newTxSavepoint (ctx , h .lockingTx )
111
113
if err != nil {
112
114
return nil , err
113
115
}
114
116
115
- return & savepointHandle {Tx : h .Tx , savepointID : savepointID }, nil
117
+ return & savepointHandle {lockingTx : h .lockingTx , savepointID : savepointID }, nil
116
118
}
117
119
118
120
func (h * savepointHandle ) Done (err error ) error {
119
121
if err == nil {
120
- _ , execErr := h .Tx . Exec ( fmt .Sprintf (commitSavepointQuery , h .savepointID ))
122
+ _ , execErr := h .ExecContext ( context . Background (), fmt .Sprintf (commitSavepointQuery , h .savepointID ))
121
123
return execErr
122
124
}
123
125
124
- _ , execErr := h .Tx . Exec ( fmt .Sprintf (rollbackSavepointQuery , h .savepointID ))
126
+ _ , execErr := h .ExecContext ( context . Background (), fmt .Sprintf (rollbackSavepointQuery , h .savepointID ))
125
127
return errors .Append (err , execErr )
126
128
}
127
129
@@ -131,7 +133,7 @@ const (
131
133
rollbackSavepointQuery = "ROLLBACK TO %s"
132
134
)
133
135
134
- func newTxSavepoint (ctx context.Context , tx * sql. Tx ) (string , error ) {
136
+ func newTxSavepoint (ctx context.Context , tx * lockingTx ) (string , error ) {
135
137
savepointID , err := makeSavepointID ()
136
138
if err != nil {
137
139
return "" , err
@@ -153,3 +155,78 @@ func makeSavepointID() (string, error) {
153
155
154
156
return fmt .Sprintf ("sp_%s" , strings .ReplaceAll (id .String (), "-" , "_" )), nil
155
157
}
158
+
159
+ var ErrConcurrentTransactions = errors .New ("transaction used concurrently" )
160
+
161
+ // lockingTx wraps a *sql.Tx with a mutex, and reports when a caller tries
162
+ // to use the transaction concurrently. Since using a transaction concurrently
163
+ // is unsafe, we want to catch these issues. Currently, lockingTx will just
164
+ // log an error and serialize accesses to the wrapped *sql.Tx, but in the future
165
+ // concurrent calls may be upgraded to an error.
166
+ type lockingTx struct {
167
+ tx * sql.Tx
168
+ mu sync.Mutex
169
+ }
170
+
171
+ func (t * lockingTx ) lock () error {
172
+ if ! t .mu .TryLock () {
173
+ err := errors .WithStack (ErrConcurrentTransactions )
174
+ log .Scoped ("internal" , "database" ).Error ("transaction used concurrently" , log .Error (err ))
175
+ return err
176
+ }
177
+ return nil
178
+ }
179
+
180
+ func (t * lockingTx ) unlock () {
181
+ t .mu .Unlock ()
182
+ }
183
+
184
+ func (t * lockingTx ) ExecContext (ctx context.Context , query string , args ... any ) (sql.Result , error ) {
185
+ if err := t .lock (); err != nil {
186
+ return nil , err
187
+ }
188
+ defer t .unlock ()
189
+
190
+ return t .tx .ExecContext (ctx , query , args ... )
191
+ }
192
+
193
+ func (t * lockingTx ) QueryContext (ctx context.Context , query string , args ... any ) (* sql.Rows , error ) {
194
+ if err := t .lock (); err != nil {
195
+ return nil , err
196
+ }
197
+ defer t .unlock ()
198
+
199
+ return t .tx .QueryContext (ctx , query , args ... )
200
+ }
201
+
202
+ func (t * lockingTx ) QueryRowContext (ctx context.Context , query string , args ... any ) * sql.Row {
203
+ if err := t .lock (); err != nil {
204
+ // XXX(camdencheek): There is no way to construct a row that has an
205
+ // error, so in order for this to return an error, we'll need to have
206
+ // it return an interface. Instead, we attempt to acquire the lock
207
+ // to make this at least be safer if we can't report the error other
208
+ // than through logs.
209
+ t .mu .Lock ()
210
+ }
211
+ defer t .unlock ()
212
+
213
+ return t .tx .QueryRowContext (ctx , query , args ... )
214
+ }
215
+
216
+ func (t * lockingTx ) Commit () error {
217
+ if err := t .lock (); err != nil {
218
+ return err
219
+ }
220
+ defer t .unlock ()
221
+
222
+ return t .tx .Commit ()
223
+ }
224
+
225
+ func (t * lockingTx ) Rollback () error {
226
+ if err := t .lock (); err != nil {
227
+ return err
228
+ }
229
+ defer t .unlock ()
230
+
231
+ return t .tx .Rollback ()
232
+ }
0 commit comments