Skip to content

Commit 1126d14

Browse files
committed
database/sql: ensure all driver interfaces are called under single lock
Russ pointed out in a previous CL golang.org/cl/65731 that not only was the locking incomplete, previous changes did not correctly lock driver calls in other sections. After inspecting driverConn, driverStmt, driverResult, Tx, and Rows structs where driver interfaces are stored, I discovered a few more places that failed to lock driver calls. The largest of these was the parameter type converter "driverArgs". driverArgs was typically called right before another call to the driver in a locked region, so I made the entire driverArgs expect a locked driver mutex and combined the region. This should not be a problem because the connection is pulled out of the connection pool either way so there shouldn't be contention. Fixes #21117 Change-Id: I88d46f74dca25fb11a30f0bf8e79785a73133d23 Reviewed-on: https://go-review.googlesource.com/71433 Run-TryBot: Daniel Theophanes <[email protected]> TryBot-Result: Gobot Gobot <[email protected]> Reviewed-by: Russ Cox <[email protected]>
1 parent 9865821 commit 1126d14

File tree

5 files changed

+55
-39
lines changed

5 files changed

+55
-39
lines changed

src/database/sql/convert.go

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"fmt"
1313
"reflect"
1414
"strconv"
15-
"sync"
1615
"time"
1716
"unicode"
1817
"unicode/utf8"
@@ -38,17 +37,10 @@ func validateNamedValueName(name string) error {
3837
return fmt.Errorf("name %q does not begin with a letter", name)
3938
}
4039

41-
func driverNumInput(ds *driverStmt) int {
42-
ds.Lock()
43-
defer ds.Unlock() // in case NumInput panics
44-
return ds.si.NumInput()
45-
}
46-
4740
// ccChecker wraps the driver.ColumnConverter and allows it to be used
4841
// as if it were a NamedValueChecker. If the driver ColumnConverter
4942
// is not present then the NamedValueChecker will return driver.ErrSkip.
5043
type ccChecker struct {
51-
sync.Locker
5244
cci driver.ColumnConverter
5345
want int
5446
}
@@ -88,9 +80,7 @@ func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
8880
// same error.
8981
var err error
9082
arg := nv.Value
91-
c.Lock()
9283
nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
93-
c.Unlock()
9484
if err != nil {
9585
return err
9686
}
@@ -112,7 +102,7 @@ func defaultCheckNamedValue(nv *driver.NamedValue) (err error) {
112102
// Stmt.Query into driver Values.
113103
//
114104
// The statement ds may be nil, if no statement is available.
115-
func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
105+
func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
116106
nvargs := make([]driver.NamedValue, len(args))
117107

118108
// -1 means the driver doesn't know how to count the number of
@@ -124,8 +114,7 @@ func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.Na
124114
var cc ccChecker
125115
if ds != nil {
126116
si = ds.si
127-
want = driverNumInput(ds)
128-
cc.Locker = ds.Locker
117+
want = ds.si.NumInput()
129118
cc.want = want
130119
}
131120

src/database/sql/convert_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ func TestDriverArgs(t *testing.T) {
481481
}
482482
for i, tt := range tests {
483483
ds := &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{nil}}
484-
got, err := driverArgs(nil, ds, tt.args)
484+
got, err := driverArgsConnLocked(nil, ds, tt.args)
485485
if err != nil {
486486
t.Errorf("test[%d]: %v", i, err)
487487
continue

src/database/sql/fakedb_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,6 +1005,7 @@ type rowsCursor struct {
10051005
}
10061006

10071007
func (rc *rowsCursor) touchMem() {
1008+
rc.parentMem.touchMem()
10081009
rc.line++
10091010
}
10101011

src/database/sql/sql.go

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,12 +1368,12 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q
13681368
}
13691369
if ok {
13701370
var nvdargs []driver.NamedValue
1371-
nvdargs, err = driverArgs(dc.ci, nil, args)
1372-
if err != nil {
1373-
return nil, err
1374-
}
13751371
var resi driver.Result
13761372
withLock(dc, func() {
1373+
nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
1374+
if err != nil {
1375+
return
1376+
}
13771377
resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs)
13781378
})
13791379
if err != driver.ErrSkip {
@@ -1439,13 +1439,14 @@ func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn fu
14391439
queryer, ok = dc.ci.(driver.Queryer)
14401440
}
14411441
if ok {
1442-
nvdargs, err := driverArgs(dc.ci, nil, args)
1443-
if err != nil {
1444-
releaseConn(err)
1445-
return nil, err
1446-
}
1442+
var nvdargs []driver.NamedValue
14471443
var rowsi driver.Rows
1444+
var err error
14481445
withLock(dc, func() {
1446+
nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
1447+
if err != nil {
1448+
return
1449+
}
14491450
rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
14501451
})
14511452
if err != driver.ErrSkip {
@@ -2034,11 +2035,14 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
20342035
stmt.mu.Unlock()
20352036

20362037
if si == nil {
2037-
cs, err := stmt.prepareOnConnLocked(ctx, dc)
2038+
withLock(dc, func() {
2039+
var ds *driverStmt
2040+
ds, err = stmt.prepareOnConnLocked(ctx, dc)
2041+
si = ds.si
2042+
})
20382043
if err != nil {
20392044
return &Stmt{stickyErr: err}
20402045
}
2041-
si = cs.si
20422046
}
20432047
parentStmt = stmt
20442048
}
@@ -2230,14 +2234,14 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
22302234
}
22312235

22322236
func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (Result, error) {
2233-
dargs, err := driverArgs(ci, ds, args)
2237+
ds.Lock()
2238+
defer ds.Unlock()
2239+
2240+
dargs, err := driverArgsConnLocked(ci, ds, args)
22342241
if err != nil {
22352242
return nil, err
22362243
}
22372244

2238-
ds.Lock()
2239-
defer ds.Unlock()
2240-
22412245
resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
22422246
if err != nil {
22432247
return nil, err
@@ -2401,10 +2405,10 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
24012405
}
24022406

24032407
func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
2404-
var want int
2405-
withLock(ds, func() {
2406-
want = ds.si.NumInput()
2407-
})
2408+
ds.Lock()
2409+
defer ds.Unlock()
2410+
2411+
want := ds.si.NumInput()
24082412

24092413
// -1 means the driver doesn't know how to count the number of
24102414
// placeholders, so we won't sanity check input here and instead let the
@@ -2413,14 +2417,11 @@ func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, arg
24132417
return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args))
24142418
}
24152419

2416-
dargs, err := driverArgs(ci, ds, args)
2420+
dargs, err := driverArgsConnLocked(ci, ds, args)
24172421
if err != nil {
24182422
return nil, err
24192423
}
24202424

2421-
ds.Lock()
2422-
defer ds.Unlock()
2423-
24242425
rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs)
24252426
if err != nil {
24262427
return nil, err
@@ -2583,9 +2584,16 @@ func (rs *Rows) nextLocked() (doClose, ok bool) {
25832584
if rs.closed {
25842585
return false, false
25852586
}
2587+
2588+
// Lock the driver connection before calling the driver interface
2589+
// rowsi to prevent a Tx from rolling back the connection at the same time.
2590+
rs.dc.Lock()
2591+
defer rs.dc.Unlock()
2592+
25862593
if rs.lastcols == nil {
25872594
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
25882595
}
2596+
25892597
rs.lasterr = rs.rowsi.Next(rs.lastcols)
25902598
if rs.lasterr != nil {
25912599
// Close the connection if there is a driver error.
@@ -2635,6 +2643,12 @@ func (rs *Rows) NextResultSet() bool {
26352643
doClose = true
26362644
return false
26372645
}
2646+
2647+
// Lock the driver connection before calling the driver interface
2648+
// rowsi to prevent a Tx from rolling back the connection at the same time.
2649+
rs.dc.Lock()
2650+
defer rs.dc.Unlock()
2651+
26382652
rs.lasterr = nextResultSet.NextResultSet()
26392653
if rs.lasterr != nil {
26402654
doClose = true
@@ -2666,6 +2680,9 @@ func (rs *Rows) Columns() ([]string, error) {
26662680
if rs.rowsi == nil {
26672681
return nil, errors.New("sql: no Rows available")
26682682
}
2683+
rs.dc.Lock()
2684+
defer rs.dc.Unlock()
2685+
26692686
return rs.rowsi.Columns(), nil
26702687
}
26712688

@@ -2680,7 +2697,10 @@ func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
26802697
if rs.rowsi == nil {
26812698
return nil, errors.New("sql: no Rows available")
26822699
}
2683-
return rowsColumnInfoSetup(rs.rowsi), nil
2700+
rs.dc.Lock()
2701+
defer rs.dc.Unlock()
2702+
2703+
return rowsColumnInfoSetupConnLocked(rs.rowsi), nil
26842704
}
26852705

26862706
// ColumnType contains the name and type of a column.
@@ -2741,7 +2761,7 @@ func (ci *ColumnType) DatabaseTypeName() string {
27412761
return ci.databaseType
27422762
}
27432763

2744-
func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType {
2764+
func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
27452765
names := rowsi.Columns()
27462766

27472767
list := make([]*ColumnType, len(names))

src/database/sql/sql_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3157,6 +3157,9 @@ func TestIssue6081(t *testing.T) {
31573157
// In the test, a context is canceled while the query is in process so
31583158
// the internal rollback will run concurrently with the explicitly called
31593159
// Tx.Rollback.
3160+
//
3161+
// The addition of calling rows.Next also tests
3162+
// Issue 21117.
31603163
func TestIssue18429(t *testing.T) {
31613164
db := newTestDB(t, "people")
31623165
defer closeDB(t, db)
@@ -3189,6 +3192,9 @@ func TestIssue18429(t *testing.T) {
31893192
// reported.
31903193
rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
31913194
if rows != nil {
3195+
// Call Next to test Issue 21117 and check for races.
3196+
for rows.Next() {
3197+
}
31923198
rows.Close()
31933199
}
31943200
// This call will race with the context cancel rollback to complete

0 commit comments

Comments
 (0)