Skip to content

Commit 0d2c43c

Browse files
fraenkeldmitshur
authored andcommitted
[internal-branch.go1.17-vendor] http2: close the request body if needed
As per client.Do and Request.Body, the transport is responsible to close the request Body. If there was an error or non 1xx/2xx status code, the transport will wait for the body writer to complete. If there is no data available to read, the body writer will block indefinitely. To prevent this, the body will be closed if it hasn't already. If there was a 1xx/2xx status code, the body will be closed eventually. Updates golang/go#49077 Change-Id: I9a4a5f13658122c562baf915e2c0c8992a023278 Reviewed-on: https://go-review.googlesource.com/c/net/+/323689 Reviewed-by: Damien Neil <[email protected]> Trust: Damien Neil <[email protected]> Trust: Alexander Rakoczy <[email protected]> Run-TryBot: Damien Neil <[email protected]> TryBot-Result: Go Bot <[email protected]> Reviewed-on: https://go-review.googlesource.com/c/net/+/357671 Reviewed-by: Dmitri Shuralyov <[email protected]>
1 parent 5627bb0 commit 0d2c43c

File tree

2 files changed

+74
-30
lines changed

2 files changed

+74
-30
lines changed

http2/transport.go

+29-30
Original file line numberDiff line numberDiff line change
@@ -385,8 +385,13 @@ func (cs *clientStream) abortRequestBodyWrite(err error) {
385385
}
386386
cc := cs.cc
387387
cc.mu.Lock()
388-
cs.stopReqBody = err
389-
cc.cond.Broadcast()
388+
if cs.stopReqBody == nil {
389+
cs.stopReqBody = err
390+
if cs.req.Body != nil {
391+
cs.req.Body.Close()
392+
}
393+
cc.cond.Broadcast()
394+
}
390395
cc.mu.Unlock()
391396
}
392397

@@ -1110,40 +1115,28 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf
11101115
return res, false, nil
11111116
}
11121117

1118+
handleError := func(err error) (*http.Response, bool, error) {
1119+
if !hasBody || bodyWritten {
1120+
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
1121+
} else {
1122+
bodyWriter.cancel()
1123+
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
1124+
<-bodyWriter.resc
1125+
}
1126+
cc.forgetStreamID(cs.ID)
1127+
return nil, cs.getStartedWrite(), err
1128+
}
1129+
11131130
for {
11141131
select {
11151132
case re := <-readLoopResCh:
11161133
return handleReadLoopResponse(re)
11171134
case <-respHeaderTimer:
1118-
if !hasBody || bodyWritten {
1119-
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
1120-
} else {
1121-
bodyWriter.cancel()
1122-
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
1123-
<-bodyWriter.resc
1124-
}
1125-
cc.forgetStreamID(cs.ID)
1126-
return nil, cs.getStartedWrite(), errTimeout
1135+
return handleError(errTimeout)
11271136
case <-ctx.Done():
1128-
if !hasBody || bodyWritten {
1129-
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
1130-
} else {
1131-
bodyWriter.cancel()
1132-
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
1133-
<-bodyWriter.resc
1134-
}
1135-
cc.forgetStreamID(cs.ID)
1136-
return nil, cs.getStartedWrite(), ctx.Err()
1137+
return handleError(ctx.Err())
11371138
case <-req.Cancel:
1138-
if !hasBody || bodyWritten {
1139-
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
1140-
} else {
1141-
bodyWriter.cancel()
1142-
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
1143-
<-bodyWriter.resc
1144-
}
1145-
cc.forgetStreamID(cs.ID)
1146-
return nil, cs.getStartedWrite(), errRequestCanceled
1139+
return handleError(errRequestCanceled)
11471140
case <-cs.peerReset:
11481141
// processResetStream already removed the
11491142
// stream from the streams map; no need for
@@ -1290,7 +1283,13 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (
12901283
// Request.Body is closed by the Transport,
12911284
// and in multiple cases: server replies <=299 and >299
12921285
// while still writing request body
1293-
cerr := bodyCloser.Close()
1286+
var cerr error
1287+
cc.mu.Lock()
1288+
if cs.stopReqBody == nil {
1289+
cs.stopReqBody = errStopReqBodyWrite
1290+
cerr = bodyCloser.Close()
1291+
}
1292+
cc.mu.Unlock()
12941293
if err == nil {
12951294
err = cerr
12961295
}

http2/transport_test.go

+45
Original file line numberDiff line numberDiff line change
@@ -4899,3 +4899,48 @@ func TestTransportServerResetStreamAtHeaders(t *testing.T) {
48994899
}
49004900
res.Body.Close()
49014901
}
4902+
4903+
type closeChecker struct {
4904+
io.ReadCloser
4905+
closed chan struct{}
4906+
}
4907+
4908+
func (rc *closeChecker) Close() error {
4909+
close(rc.closed)
4910+
return rc.ReadCloser.Close()
4911+
}
4912+
4913+
func TestTransportCloseRequestBody(t *testing.T) {
4914+
var statusCode int
4915+
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4916+
w.WriteHeader(statusCode)
4917+
}, optOnlyServer)
4918+
defer st.Close()
4919+
4920+
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
4921+
defer tr.CloseIdleConnections()
4922+
ctx := context.Background()
4923+
cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
4924+
if err != nil {
4925+
t.Fatal(err)
4926+
}
4927+
4928+
for _, status := range []int{200, 401} {
4929+
t.Run(fmt.Sprintf("status=%d", status), func(t *testing.T) {
4930+
statusCode = status
4931+
pr, pw := io.Pipe()
4932+
pipeClosed := make(chan struct{})
4933+
req, err := http.NewRequest("PUT", "https://dummy.tld/", &closeChecker{pr, pipeClosed})
4934+
if err != nil {
4935+
t.Fatal(err)
4936+
}
4937+
res, err := cc.RoundTrip(req)
4938+
if err != nil {
4939+
t.Fatal(err)
4940+
}
4941+
res.Body.Close()
4942+
pw.Close()
4943+
<-pipeClosed
4944+
})
4945+
}
4946+
}

0 commit comments

Comments
 (0)