@@ -1368,12 +1368,12 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q
1368
1368
}
1369
1369
if ok {
1370
1370
var nvdargs []driver.NamedValue
1371
- nvdargs , err = driverArgs (dc .ci , nil , args )
1372
- if err != nil {
1373
- return nil , err
1374
- }
1375
1371
var resi driver.Result
1376
1372
withLock (dc , func () {
1373
+ nvdargs , err = driverArgsConnLocked (dc .ci , nil , args )
1374
+ if err != nil {
1375
+ return
1376
+ }
1377
1377
resi , err = ctxDriverExec (ctx , execerCtx , execer , query , nvdargs )
1378
1378
})
1379
1379
if err != driver .ErrSkip {
@@ -1439,13 +1439,14 @@ func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn fu
1439
1439
queryer , ok = dc .ci .(driver.Queryer )
1440
1440
}
1441
1441
if ok {
1442
- nvdargs , err := driverArgs (dc .ci , nil , args )
1443
- if err != nil {
1444
- releaseConn (err )
1445
- return nil , err
1446
- }
1442
+ var nvdargs []driver.NamedValue
1447
1443
var rowsi driver.Rows
1444
+ var err error
1448
1445
withLock (dc , func () {
1446
+ nvdargs , err = driverArgsConnLocked (dc .ci , nil , args )
1447
+ if err != nil {
1448
+ return
1449
+ }
1449
1450
rowsi , err = ctxDriverQuery (ctx , queryerCtx , queryer , query , nvdargs )
1450
1451
})
1451
1452
if err != driver .ErrSkip {
@@ -2034,11 +2035,14 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
2034
2035
stmt .mu .Unlock ()
2035
2036
2036
2037
if si == nil {
2037
- cs , err := stmt .prepareOnConnLocked (ctx , dc )
2038
+ withLock (dc , func () {
2039
+ var ds * driverStmt
2040
+ ds , err = stmt .prepareOnConnLocked (ctx , dc )
2041
+ si = ds .si
2042
+ })
2038
2043
if err != nil {
2039
2044
return & Stmt {stickyErr : err }
2040
2045
}
2041
- si = cs .si
2042
2046
}
2043
2047
parentStmt = stmt
2044
2048
}
@@ -2230,14 +2234,14 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
2230
2234
}
2231
2235
2232
2236
func resultFromStatement (ctx context.Context , ci driver.Conn , ds * driverStmt , args ... interface {}) (Result , error ) {
2233
- dargs , err := driverArgs (ci , ds , args )
2237
+ ds .Lock ()
2238
+ defer ds .Unlock ()
2239
+
2240
+ dargs , err := driverArgsConnLocked (ci , ds , args )
2234
2241
if err != nil {
2235
2242
return nil , err
2236
2243
}
2237
2244
2238
- ds .Lock ()
2239
- defer ds .Unlock ()
2240
-
2241
2245
resi , err := ctxDriverStmtExec (ctx , ds .si , dargs )
2242
2246
if err != nil {
2243
2247
return nil , err
@@ -2401,10 +2405,10 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
2401
2405
}
2402
2406
2403
2407
func rowsiFromStatement (ctx context.Context , ci driver.Conn , ds * driverStmt , args ... interface {}) (driver.Rows , error ) {
2404
- var want int
2405
- withLock ( ds , func () {
2406
- want = ds . si . NumInput ()
2407
- } )
2408
+ ds . Lock ()
2409
+ defer ds . Unlock ()
2410
+
2411
+ want := ds . si . NumInput ( )
2408
2412
2409
2413
// -1 means the driver doesn't know how to count the number of
2410
2414
// placeholders, so we won't sanity check input here and instead let the
@@ -2413,14 +2417,11 @@ func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, arg
2413
2417
return nil , fmt .Errorf ("sql: statement expects %d inputs; got %d" , want , len (args ))
2414
2418
}
2415
2419
2416
- dargs , err := driverArgs (ci , ds , args )
2420
+ dargs , err := driverArgsConnLocked (ci , ds , args )
2417
2421
if err != nil {
2418
2422
return nil , err
2419
2423
}
2420
2424
2421
- ds .Lock ()
2422
- defer ds .Unlock ()
2423
-
2424
2425
rowsi , err := ctxDriverStmtQuery (ctx , ds .si , dargs )
2425
2426
if err != nil {
2426
2427
return nil , err
@@ -2583,9 +2584,16 @@ func (rs *Rows) nextLocked() (doClose, ok bool) {
2583
2584
if rs .closed {
2584
2585
return false , false
2585
2586
}
2587
+
2588
+ // Lock the driver connection before calling the driver interface
2589
+ // rowsi to prevent a Tx from rolling back the connection at the same time.
2590
+ rs .dc .Lock ()
2591
+ defer rs .dc .Unlock ()
2592
+
2586
2593
if rs .lastcols == nil {
2587
2594
rs .lastcols = make ([]driver.Value , len (rs .rowsi .Columns ()))
2588
2595
}
2596
+
2589
2597
rs .lasterr = rs .rowsi .Next (rs .lastcols )
2590
2598
if rs .lasterr != nil {
2591
2599
// Close the connection if there is a driver error.
@@ -2635,6 +2643,12 @@ func (rs *Rows) NextResultSet() bool {
2635
2643
doClose = true
2636
2644
return false
2637
2645
}
2646
+
2647
+ // Lock the driver connection before calling the driver interface
2648
+ // rowsi to prevent a Tx from rolling back the connection at the same time.
2649
+ rs .dc .Lock ()
2650
+ defer rs .dc .Unlock ()
2651
+
2638
2652
rs .lasterr = nextResultSet .NextResultSet ()
2639
2653
if rs .lasterr != nil {
2640
2654
doClose = true
@@ -2666,6 +2680,9 @@ func (rs *Rows) Columns() ([]string, error) {
2666
2680
if rs .rowsi == nil {
2667
2681
return nil , errors .New ("sql: no Rows available" )
2668
2682
}
2683
+ rs .dc .Lock ()
2684
+ defer rs .dc .Unlock ()
2685
+
2669
2686
return rs .rowsi .Columns (), nil
2670
2687
}
2671
2688
@@ -2680,7 +2697,10 @@ func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
2680
2697
if rs .rowsi == nil {
2681
2698
return nil , errors .New ("sql: no Rows available" )
2682
2699
}
2683
- return rowsColumnInfoSetup (rs .rowsi ), nil
2700
+ rs .dc .Lock ()
2701
+ defer rs .dc .Unlock ()
2702
+
2703
+ return rowsColumnInfoSetupConnLocked (rs .rowsi ), nil
2684
2704
}
2685
2705
2686
2706
// ColumnType contains the name and type of a column.
@@ -2741,7 +2761,7 @@ func (ci *ColumnType) DatabaseTypeName() string {
2741
2761
return ci .databaseType
2742
2762
}
2743
2763
2744
- func rowsColumnInfoSetup (rowsi driver.Rows ) []* ColumnType {
2764
+ func rowsColumnInfoSetupConnLocked (rowsi driver.Rows ) []* ColumnType {
2745
2765
names := rowsi .Columns ()
2746
2766
2747
2767
list := make ([]* ColumnType , len (names ))
0 commit comments