Skip to content

Commit c026845

Browse files
kardianosrsc
authored andcommitted
database/sql: record the context error in Rows if canceled
Previously it was intended that Rows.Scan would return an error and Rows.Err would return nil. This was problematic because drivers could not differentiate between a normal Rows.Close or a context cancel close. The alternative is to require drivers to return a Scan to return an error if the driver is closed while there are still rows to be read. This is currently not how several drivers currently work and may be difficult to detect when there are additional rows. At the same time guard the the Rows.lasterr and prevent a close while a Rows operation is active. For the drivers that do not have Context methods, do not check for context cancelation after the operation, but before for any operation that may modify the database state. Fixes #18961 Change-Id: I49a25318ecd9f97a35d5b50540ecd850c01cfa5e Reviewed-on: https://go-review.googlesource.com/36485 Reviewed-by: Russ Cox <[email protected]> Run-TryBot: Russ Cox <[email protected]> TryBot-Result: Gobot Gobot <[email protected]>
1 parent 0c9325e commit c026845

File tree

3 files changed

+100
-66
lines changed

3 files changed

+100
-66
lines changed

src/database/sql/ctxutil.go

+20-34
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,12 @@ func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvda
3535
return nil, err
3636
}
3737

38-
resi, err := execer.Exec(query, dargs)
39-
if err == nil {
40-
select {
41-
default:
42-
case <-ctx.Done():
43-
return resi, ctx.Err()
44-
}
38+
select {
39+
default:
40+
case <-ctx.Done():
41+
return nil, ctx.Err()
4542
}
46-
return resi, err
43+
return execer.Exec(query, dargs)
4744
}
4845

4946
func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
@@ -56,16 +53,12 @@ func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, n
5653
return nil, err
5754
}
5855

59-
rowsi, err := queryer.Query(query, dargs)
60-
if err == nil {
61-
select {
62-
default:
63-
case <-ctx.Done():
64-
rowsi.Close()
65-
return nil, ctx.Err()
66-
}
56+
select {
57+
default:
58+
case <-ctx.Done():
59+
return nil, ctx.Err()
6760
}
68-
return rowsi, err
61+
return queryer.Query(query, dargs)
6962
}
7063

7164
func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) {
@@ -77,15 +70,12 @@ func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.Nam
7770
return nil, err
7871
}
7972

80-
resi, err := si.Exec(dargs)
81-
if err == nil {
82-
select {
83-
default:
84-
case <-ctx.Done():
85-
return resi, ctx.Err()
86-
}
73+
select {
74+
default:
75+
case <-ctx.Done():
76+
return nil, ctx.Err()
8777
}
88-
return resi, err
78+
return si.Exec(dargs)
8979
}
9080

9181
func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
@@ -97,16 +87,12 @@ func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.Na
9787
return nil, err
9888
}
9989

100-
rowsi, err := si.Query(dargs)
101-
if err == nil {
102-
select {
103-
default:
104-
case <-ctx.Done():
105-
rowsi.Close()
106-
return nil, ctx.Err()
107-
}
90+
select {
91+
default:
92+
case <-ctx.Done():
93+
return nil, ctx.Err()
10894
}
109-
return rowsi, err
95+
return si.Query(dargs)
11096
}
11197

11298
var errLevelNotSupported = errors.New("sql: selected isolation level is not supported")

src/database/sql/sql.go

+73-29
Original file line numberDiff line numberDiff line change
@@ -2071,14 +2071,21 @@ type Rows struct {
20712071
dc *driverConn // owned; must call releaseConn when closed to release
20722072
releaseConn func(error)
20732073
rowsi driver.Rows
2074+
cancel func() // called when Rows is closed, may be nil.
2075+
closeStmt *driverStmt // if non-nil, statement to Close on close
20742076

2075-
// closed value is 1 when the Rows is closed.
2076-
// Use atomic operations on value when checking value.
2077-
closed int32
2078-
cancel func() // called when Rows is closed, may be nil.
2079-
lastcols []driver.Value
2080-
lasterr error // non-nil only if closed is true
2081-
closeStmt *driverStmt // if non-nil, statement to Close on close
2077+
// closemu prevents Rows from closing while there
2078+
// is an active streaming result. It is held for read during non-close operations
2079+
// and exclusively during close.
2080+
//
2081+
// closemu guards lasterr and closed.
2082+
closemu sync.RWMutex
2083+
closed bool
2084+
lasterr error // non-nil only if closed is true
2085+
2086+
// lastcols is only used in Scan, Next, and NextResultSet which are expected
2087+
// not not be called concurrently.
2088+
lastcols []driver.Value
20822089
}
20832090

20842091
func (rs *Rows) initContextClose(ctx context.Context) {
@@ -2089,7 +2096,7 @@ func (rs *Rows) initContextClose(ctx context.Context) {
20892096
// awaitDone blocks until the rows are closed or the context canceled.
20902097
func (rs *Rows) awaitDone(ctx context.Context) {
20912098
<-ctx.Done()
2092-
rs.Close()
2099+
rs.close(ctx.Err())
20932100
}
20942101

20952102
// Next prepares the next result row for reading with the Scan method. It
@@ -2099,8 +2106,19 @@ func (rs *Rows) awaitDone(ctx context.Context) {
20992106
//
21002107
// Every call to Scan, even the first one, must be preceded by a call to Next.
21012108
func (rs *Rows) Next() bool {
2102-
if rs.isClosed() {
2103-
return false
2109+
var doClose, ok bool
2110+
withLock(rs.closemu.RLocker(), func() {
2111+
doClose, ok = rs.nextLocked()
2112+
})
2113+
if doClose {
2114+
rs.Close()
2115+
}
2116+
return ok
2117+
}
2118+
2119+
func (rs *Rows) nextLocked() (doClose, ok bool) {
2120+
if rs.closed {
2121+
return false, false
21042122
}
21052123
if rs.lastcols == nil {
21062124
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
@@ -2109,23 +2127,21 @@ func (rs *Rows) Next() bool {
21092127
if rs.lasterr != nil {
21102128
// Close the connection if there is a driver error.
21112129
if rs.lasterr != io.EOF {
2112-
rs.Close()
2113-
return false
2130+
return true, false
21142131
}
21152132
nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
21162133
if !ok {
2117-
rs.Close()
2118-
return false
2134+
return true, false
21192135
}
21202136
// The driver is at the end of the current result set.
21212137
// Test to see if there is another result set after the current one.
21222138
// Only close Rows if there is no further result sets to read.
21232139
if !nextResultSet.HasNextResultSet() {
2124-
rs.Close()
2140+
doClose = true
21252141
}
2126-
return false
2142+
return doClose, false
21272143
}
2128-
return true
2144+
return false, true
21292145
}
21302146

21312147
// NextResultSet prepares the next result set for reading. It returns true if
@@ -2137,18 +2153,28 @@ func (rs *Rows) Next() bool {
21372153
// scanning. If there are further result sets they may not have rows in the result
21382154
// set.
21392155
func (rs *Rows) NextResultSet() bool {
2140-
if rs.isClosed() {
2156+
var doClose bool
2157+
defer func() {
2158+
if doClose {
2159+
rs.Close()
2160+
}
2161+
}()
2162+
rs.closemu.RLock()
2163+
defer rs.closemu.RUnlock()
2164+
2165+
if rs.closed {
21412166
return false
21422167
}
2168+
21432169
rs.lastcols = nil
21442170
nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
21452171
if !ok {
2146-
rs.Close()
2172+
doClose = true
21472173
return false
21482174
}
21492175
rs.lasterr = nextResultSet.NextResultSet()
21502176
if rs.lasterr != nil {
2151-
rs.Close()
2177+
doClose = true
21522178
return false
21532179
}
21542180
return true
@@ -2157,6 +2183,8 @@ func (rs *Rows) NextResultSet() bool {
21572183
// Err returns the error, if any, that was encountered during iteration.
21582184
// Err may be called after an explicit or implicit Close.
21592185
func (rs *Rows) Err() error {
2186+
rs.closemu.RLock()
2187+
defer rs.closemu.RUnlock()
21602188
if rs.lasterr == io.EOF {
21612189
return nil
21622190
}
@@ -2167,7 +2195,9 @@ func (rs *Rows) Err() error {
21672195
// Columns returns an error if the rows are closed, or if the rows
21682196
// are from QueryRow and there was a deferred error.
21692197
func (rs *Rows) Columns() ([]string, error) {
2170-
if rs.isClosed() {
2198+
rs.closemu.RLock()
2199+
defer rs.closemu.RUnlock()
2200+
if rs.closed {
21712201
return nil, errors.New("sql: Rows are closed")
21722202
}
21732203
if rs.rowsi == nil {
@@ -2179,7 +2209,9 @@ func (rs *Rows) Columns() ([]string, error) {
21792209
// ColumnTypes returns column information such as column type, length,
21802210
// and nullable. Some information may not be available from some drivers.
21812211
func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
2182-
if rs.isClosed() {
2212+
rs.closemu.RLock()
2213+
defer rs.closemu.RUnlock()
2214+
if rs.closed {
21832215
return nil, errors.New("sql: Rows are closed")
21842216
}
21852217
if rs.rowsi == nil {
@@ -2329,9 +2361,13 @@ func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType {
23292361
// For scanning into *bool, the source may be true, false, 1, 0, or
23302362
// string inputs parseable by strconv.ParseBool.
23312363
func (rs *Rows) Scan(dest ...interface{}) error {
2332-
if rs.isClosed() {
2364+
rs.closemu.RLock()
2365+
if rs.closed {
2366+
rs.closemu.RUnlock()
23332367
return errors.New("sql: Rows are closed")
23342368
}
2369+
rs.closemu.RUnlock()
2370+
23352371
if rs.lastcols == nil {
23362372
return errors.New("sql: Scan called without calling Next")
23372373
}
@@ -2351,20 +2387,28 @@ func (rs *Rows) Scan(dest ...interface{}) error {
23512387
// hook through a test only mutex.
23522388
var rowsCloseHook = func() func(*Rows, *error) { return nil }
23532389

2354-
func (rs *Rows) isClosed() bool {
2355-
return atomic.LoadInt32(&rs.closed) != 0
2356-
}
2357-
23582390
// Close closes the Rows, preventing further enumeration. If Next is called
23592391
// and returns false and there are no further result sets,
23602392
// the Rows are closed automatically and it will suffice to check the
23612393
// result of Err. Close is idempotent and does not affect the result of Err.
23622394
func (rs *Rows) Close() error {
2363-
if !atomic.CompareAndSwapInt32(&rs.closed, 0, 1) {
2395+
return rs.close(nil)
2396+
}
2397+
2398+
func (rs *Rows) close(err error) error {
2399+
rs.closemu.Lock()
2400+
defer rs.closemu.Unlock()
2401+
2402+
if rs.closed {
23642403
return nil
23652404
}
2405+
rs.closed = true
2406+
2407+
if rs.lasterr == nil {
2408+
rs.lasterr = err
2409+
}
23662410

2367-
err := rs.rowsi.Close()
2411+
err = rs.rowsi.Close()
23682412
if fn := rowsCloseHook(); fn != nil {
23692413
fn(rs, &err)
23702414
}

src/database/sql/sql_test.go

+7-3
Original file line numberDiff line numberDiff line change
@@ -313,9 +313,13 @@ func TestQueryContext(t *testing.T) {
313313
got = append(got, r)
314314
index++
315315
}
316-
err = rows.Err()
317-
if err != nil {
318-
t.Fatalf("Err: %v", err)
316+
select {
317+
case <-ctx.Done():
318+
if err := ctx.Err(); err != context.Canceled {
319+
t.Fatalf("context err = %v; want context.Canceled")
320+
}
321+
default:
322+
t.Fatalf("context err = nil; want context.Canceled")
319323
}
320324
want := []row{
321325
{age: 1, name: "Alice"},

0 commit comments

Comments
 (0)