Skip to content

Commit 9017642

Browse files
neilddmitshur
authored andcommitted
[release-branch.go1.14] net/http: fix cancelation of requests with a readTrackingBody wrapper
Use the original *Request in the reqCanceler map, not the transient wrapper created to handle body rewinding. Change the key of reqCanceler to a struct{*Request}, to make it more difficult to accidentally use the wrong request as the key. Updates #40453. Fixes #41016. Change-Id: I4e61ee9ff2c794fb4c920a3a66c9a0458693d757 Reviewed-on: https://go-review.googlesource.com/c/go/+/245357 Run-TryBot: Damien Neil <[email protected]> TryBot-Result: Gobot Gobot <[email protected]> Reviewed-by: Russ Cox <[email protected]> Reviewed-on: https://go-review.googlesource.com/c/go/+/250299 Run-TryBot: Dmitri Shuralyov <[email protected]> Reviewed-by: Damien Neil <[email protected]>
1 parent fae8e09 commit 9017642

File tree

2 files changed

+85
-30
lines changed

2 files changed

+85
-30
lines changed

src/net/http/transport.go

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ type Transport struct {
100100
idleLRU connLRU
101101

102102
reqMu sync.Mutex
103-
reqCanceler map[*Request]func(error)
103+
reqCanceler map[cancelKey]func(error)
104104

105105
altMu sync.Mutex // guards changing altProto only
106106
altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme
@@ -273,6 +273,13 @@ type Transport struct {
273273
ForceAttemptHTTP2 bool
274274
}
275275

276+
// A cancelKey is the key of the reqCanceler map.
277+
// We wrap the *Request in this type since we want to use the original request,
278+
// not any transient one created by roundTrip.
279+
type cancelKey struct {
280+
req *Request
281+
}
282+
276283
func (t *Transport) writeBufferSize() int {
277284
if t.WriteBufferSize > 0 {
278285
return t.WriteBufferSize
@@ -433,9 +440,10 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) {
433440
// optional extra headers to write and stores any error to return
434441
// from roundTrip.
435442
type transportRequest struct {
436-
*Request // original request, not to be mutated
437-
extra Header // extra headers to write, or nil
438-
trace *httptrace.ClientTrace // optional
443+
*Request // original request, not to be mutated
444+
extra Header // extra headers to write, or nil
445+
trace *httptrace.ClientTrace // optional
446+
cancelKey cancelKey
439447

440448
mu sync.Mutex // guards err
441449
err error // first setError value for mapRoundTripError to consider
@@ -512,6 +520,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
512520
}
513521

514522
origReq := req
523+
cancelKey := cancelKey{origReq}
515524
req = setupRewindBody(req)
516525

517526
if altRT := t.alternateRoundTripper(req); altRT != nil {
@@ -546,7 +555,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
546555
}
547556

548557
// treq gets modified by roundTrip, so we need to recreate for each retry.
549-
treq := &transportRequest{Request: req, trace: trace}
558+
treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey}
550559
cm, err := t.connectMethodForRequest(treq)
551560
if err != nil {
552561
req.closeBody()
@@ -559,15 +568,15 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
559568
// to send it requests.
560569
pconn, err := t.getConn(treq, cm)
561570
if err != nil {
562-
t.setReqCanceler(req, nil)
571+
t.setReqCanceler(cancelKey, nil)
563572
req.closeBody()
564573
return nil, err
565574
}
566575

567576
var resp *Response
568577
if pconn.alt != nil {
569578
// HTTP/2 path.
570-
t.setReqCanceler(req, nil) // not cancelable with CancelRequest
579+
t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest
571580
resp, err = pconn.alt.RoundTrip(req)
572581
} else {
573582
resp, err = pconn.roundTrip(treq)
@@ -756,14 +765,14 @@ func (t *Transport) CloseIdleConnections() {
756765
// cancelable context instead. CancelRequest cannot cancel HTTP/2
757766
// requests.
758767
func (t *Transport) CancelRequest(req *Request) {
759-
t.cancelRequest(req, errRequestCanceled)
768+
t.cancelRequest(cancelKey{req}, errRequestCanceled)
760769
}
761770

762771
// Cancel an in-flight request, recording the error value.
763-
func (t *Transport) cancelRequest(req *Request, err error) {
772+
func (t *Transport) cancelRequest(key cancelKey, err error) {
764773
t.reqMu.Lock()
765-
cancel := t.reqCanceler[req]
766-
delete(t.reqCanceler, req)
774+
cancel := t.reqCanceler[key]
775+
delete(t.reqCanceler, key)
767776
t.reqMu.Unlock()
768777
if cancel != nil {
769778
cancel(err)
@@ -1096,34 +1105,34 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool {
10961105
return removed
10971106
}
10981107

1099-
func (t *Transport) setReqCanceler(r *Request, fn func(error)) {
1108+
func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) {
11001109
t.reqMu.Lock()
11011110
defer t.reqMu.Unlock()
11021111
if t.reqCanceler == nil {
1103-
t.reqCanceler = make(map[*Request]func(error))
1112+
t.reqCanceler = make(map[cancelKey]func(error))
11041113
}
11051114
if fn != nil {
1106-
t.reqCanceler[r] = fn
1115+
t.reqCanceler[key] = fn
11071116
} else {
1108-
delete(t.reqCanceler, r)
1117+
delete(t.reqCanceler, key)
11091118
}
11101119
}
11111120

11121121
// replaceReqCanceler replaces an existing cancel function. If there is no cancel function
11131122
// for the request, we don't set the function and return false.
11141123
// Since CancelRequest will clear the canceler, we can use the return value to detect if
11151124
// the request was canceled since the last setReqCancel call.
1116-
func (t *Transport) replaceReqCanceler(r *Request, fn func(error)) bool {
1125+
func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool {
11171126
t.reqMu.Lock()
11181127
defer t.reqMu.Unlock()
1119-
_, ok := t.reqCanceler[r]
1128+
_, ok := t.reqCanceler[key]
11201129
if !ok {
11211130
return false
11221131
}
11231132
if fn != nil {
1124-
t.reqCanceler[r] = fn
1133+
t.reqCanceler[key] = fn
11251134
} else {
1126-
delete(t.reqCanceler, r)
1135+
delete(t.reqCanceler, key)
11271136
}
11281137
return true
11291138
}
@@ -1327,12 +1336,12 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi
13271336
// set request canceler to some non-nil function so we
13281337
// can detect whether it was cleared between now and when
13291338
// we enter roundTrip
1330-
t.setReqCanceler(req, func(error) {})
1339+
t.setReqCanceler(treq.cancelKey, func(error) {})
13311340
return pc, nil
13321341
}
13331342

13341343
cancelc := make(chan error, 1)
1335-
t.setReqCanceler(req, func(err error) { cancelc <- err })
1344+
t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err })
13361345

13371346
// Queue for permission to dial.
13381347
t.queueForDial(w)
@@ -2075,7 +2084,7 @@ func (pc *persistConn) readLoop() {
20752084
}
20762085

20772086
if !hasBody || bodyWritable {
2078-
pc.t.setReqCanceler(rc.req, nil)
2087+
pc.t.setReqCanceler(rc.cancelKey, nil)
20792088

20802089
// Put the idle conn back into the pool before we send the response
20812090
// so if they process it quickly and make another request, they'll
@@ -2148,7 +2157,7 @@ func (pc *persistConn) readLoop() {
21482157
// reading the response body. (or for cancellation or death)
21492158
select {
21502159
case bodyEOF := <-waitForBodyRead:
2151-
pc.t.setReqCanceler(rc.req, nil) // before pc might return to idle pool
2160+
pc.t.setReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool
21522161
alive = alive &&
21532162
bodyEOF &&
21542163
!pc.sawEOF &&
@@ -2162,7 +2171,7 @@ func (pc *persistConn) readLoop() {
21622171
pc.t.CancelRequest(rc.req)
21632172
case <-rc.req.Context().Done():
21642173
alive = false
2165-
pc.t.cancelRequest(rc.req, rc.req.Context().Err())
2174+
pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err())
21662175
case <-pc.closech:
21672176
alive = false
21682177
}
@@ -2403,8 +2412,9 @@ type responseAndError struct {
24032412
}
24042413

24052414
type requestAndChan struct {
2406-
req *Request
2407-
ch chan responseAndError // unbuffered; always send in select on callerGone
2415+
req *Request
2416+
cancelKey cancelKey
2417+
ch chan responseAndError // unbuffered; always send in select on callerGone
24082418

24092419
// whether the Transport (as opposed to the user client code)
24102420
// added the Accept-Encoding gzip header. If the Transport
@@ -2466,7 +2476,7 @@ var (
24662476

24672477
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
24682478
testHookEnterRoundTrip()
2469-
if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) {
2479+
if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) {
24702480
pc.t.putOrCloseIdleConn(pc)
24712481
return nil, errRequestCanceled
24722482
}
@@ -2518,7 +2528,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
25182528

25192529
defer func() {
25202530
if err != nil {
2521-
pc.t.setReqCanceler(req.Request, nil)
2531+
pc.t.setReqCanceler(req.cancelKey, nil)
25222532
}
25232533
}()
25242534

@@ -2534,6 +2544,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
25342544
resc := make(chan responseAndError)
25352545
pc.reqch <- requestAndChan{
25362546
req: req.Request,
2547+
cancelKey: req.cancelKey,
25372548
ch: resc,
25382549
addedGzip: requestedGzip,
25392550
continueCh: continueCh,
@@ -2585,10 +2596,10 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
25852596
}
25862597
return re.res, nil
25872598
case <-cancelChan:
2588-
pc.t.CancelRequest(req.Request)
2599+
pc.t.cancelRequest(req.cancelKey, errRequestCanceled)
25892600
cancelChan = nil
25902601
case <-ctxDoneChan:
2591-
pc.t.cancelRequest(req.Request, req.Context().Err())
2602+
pc.t.cancelRequest(req.cancelKey, req.Context().Err())
25922603
cancelChan = nil
25932604
ctxDoneChan = nil
25942605
}

src/net/http/transport_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2346,6 +2346,50 @@ func TestTransportCancelRequest(t *testing.T) {
23462346
}
23472347
}
23482348

2349+
func testTransportCancelRequestInDo(t *testing.T, body io.Reader) {
2350+
setParallel(t)
2351+
defer afterTest(t)
2352+
if testing.Short() {
2353+
t.Skip("skipping test in -short mode")
2354+
}
2355+
unblockc := make(chan bool)
2356+
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
2357+
<-unblockc
2358+
}))
2359+
defer ts.Close()
2360+
defer close(unblockc)
2361+
2362+
c := ts.Client()
2363+
tr := c.Transport.(*Transport)
2364+
2365+
donec := make(chan bool)
2366+
req, _ := NewRequest("GET", ts.URL, body)
2367+
go func() {
2368+
defer close(donec)
2369+
c.Do(req)
2370+
}()
2371+
start := time.Now()
2372+
timeout := 10 * time.Second
2373+
for time.Since(start) < timeout {
2374+
time.Sleep(100 * time.Millisecond)
2375+
tr.CancelRequest(req)
2376+
select {
2377+
case <-donec:
2378+
return
2379+
default:
2380+
}
2381+
}
2382+
t.Errorf("Do of canceled request has not returned after %v", timeout)
2383+
}
2384+
2385+
func TestTransportCancelRequestInDo(t *testing.T) {
2386+
testTransportCancelRequestInDo(t, nil)
2387+
}
2388+
2389+
func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
2390+
testTransportCancelRequestInDo(t, bytes.NewBuffer([]byte{0}))
2391+
}
2392+
23492393
func TestTransportCancelRequestInDial(t *testing.T) {
23502394
defer afterTest(t)
23512395
if testing.Short() {

0 commit comments

Comments
 (0)