Skip to content

Commit b843c7d

Browse files
bradfitzdmitshur
authored andcommitted
[internal-branch.go1.16-vendor] http2: fix Transport connection pool TOCTOU max concurrent stream bug
Updates golang/go#49076 Change-Id: I3e02072403f2f40ade4ef931058bbb5892776754 Reviewed-on: https://go-review.googlesource.com/c/net/+/352469 Run-TryBot: Brad Fitzpatrick <[email protected]> TryBot-Result: Go Bot <[email protected]> Reviewed-by: Damien Neil <[email protected]> Trust: Brad Fitzpatrick <[email protected]> Reviewed-on: https://go-review.googlesource.com/c/net/+/356986 Trust: Damien Neil <[email protected]> Run-TryBot: Damien Neil <[email protected]> Reviewed-by: Dmitri Shuralyov <[email protected]>
1 parent ab1d67c commit b843c7d

File tree

3 files changed

+102
-6
lines changed

3 files changed

+102
-6
lines changed

http2/client_conn_pool.go

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ import (
1616

1717
// ClientConnPool manages a pool of HTTP/2 client connections.
1818
type ClientConnPool interface {
19+
// GetClientConn returns a specific HTTP/2 connection (usually
20+
// a TLS-TCP connection) to an HTTP/2 server. On success, the
21+
// returned ClientConn accounts for the upcoming RoundTrip
22+
// call, so the caller should not omit it. If the caller needs
23+
// to, ClientConn.RoundTrip can be called with a bogus
24+
// new(http.Request) to release the stream reservation.
1925
GetClientConn(req *http.Request, addr string) (*ClientConn, error)
2026
MarkDead(*ClientConn)
2127
}
@@ -61,7 +67,7 @@ const (
6167
// during the back-and-forth between net/http and x/net/http2 (when the
6268
// net/http.Transport is upgraded to also speak http2), as well as support
6369
// the case where x/net/http2 is being used directly.
64-
func (p *clientConnPool) shouldTraceGetConn(st clientConnIdleState) bool {
70+
func (p *clientConnPool) shouldTraceGetConn(cc *ClientConn) bool {
6571
// If our Transport wasn't made via ConfigureTransport, always
6672
// trace the GetConn hook if provided, because that means the
6773
// http2 package is being used directly and it's the one
@@ -72,7 +78,9 @@ func (p *clientConnPool) shouldTraceGetConn(st clientConnIdleState) bool {
7278
// Otherwise, only use the GetConn hook if this connection has
7379
// been used previously for other requests. For fresh
7480
// connections, the net/http package does the dialing.
75-
return !st.freshConn
81+
cc.mu.Lock()
82+
defer cc.mu.Unlock()
83+
return cc.nextStreamID == 1
7684
}
7785

7886
func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) {
@@ -89,8 +97,8 @@ func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMis
8997
for {
9098
p.mu.Lock()
9199
for _, cc := range p.conns[addr] {
92-
if st := cc.idleState(); st.canTakeNewRequest {
93-
if p.shouldTraceGetConn(st) {
100+
if cc.ReserveNewRequest() {
101+
if p.shouldTraceGetConn(cc) {
94102
traceGetConn(req, addr)
95103
}
96104
p.mu.Unlock()
@@ -108,7 +116,13 @@ func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMis
108116
if shouldRetryDial(call, req) {
109117
continue
110118
}
111-
return call.res, call.err
119+
cc, err := call.res, call.err
120+
if err != nil {
121+
return nil, err
122+
}
123+
if cc.ReserveNewRequest() {
124+
return cc, nil
125+
}
112126
}
113127
}
114128

http2/transport.go

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ type ClientConn struct {
261261
goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received
262262
goAwayDebug string // goAway frame's debug data, retained as a string
263263
streams map[uint32]*clientStream // client-initiated
264+
streamsReserved int // incr by ReserveNewRequest; decr on RoundTrip
264265
nextStreamID uint32
265266
pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams
266267
pings map[[8]byte]chan struct{} // in flight ping data to notification channel
@@ -782,12 +783,28 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
782783

783784
// CanTakeNewRequest reports whether the connection can take a new request,
784785
// meaning it has not been closed or received or sent a GOAWAY.
786+
//
787+
// If the caller is going to immediately make a new request on this
788+
// connection, use ReserveNewRequest instead.
785789
func (cc *ClientConn) CanTakeNewRequest() bool {
786790
cc.mu.Lock()
787791
defer cc.mu.Unlock()
788792
return cc.canTakeNewRequestLocked()
789793
}
790794

795+
// ReserveNewRequest is like CanTakeNewRequest but also reserves a
796+
// concurrent stream in cc. The reservation is decremented on the
797+
// next call to RoundTrip.
798+
func (cc *ClientConn) ReserveNewRequest() bool {
799+
cc.mu.Lock()
800+
defer cc.mu.Unlock()
801+
if st := cc.idleStateLocked(); !st.canTakeNewRequest {
802+
return false
803+
}
804+
cc.streamsReserved++
805+
return true
806+
}
807+
791808
// clientConnIdleState describes the suitability of a client
792809
// connection to initiate a new RoundTrip request.
793810
type clientConnIdleState struct {
@@ -813,7 +830,7 @@ func (cc *ClientConn) idleStateLocked() (st clientConnIdleState) {
813830
// writing it.
814831
maxConcurrentOkay = true
815832
} else {
816-
maxConcurrentOkay = int64(len(cc.streams)+1) <= int64(cc.maxConcurrentStreams)
833+
maxConcurrentOkay = int64(len(cc.streams)+cc.streamsReserved+1) <= int64(cc.maxConcurrentStreams)
817834
}
818835

819836
st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay &&
@@ -1033,6 +1050,18 @@ func actualContentLength(req *http.Request) int64 {
10331050
return -1
10341051
}
10351052

1053+
func (cc *ClientConn) decrStreamReservations() {
1054+
cc.mu.Lock()
1055+
defer cc.mu.Unlock()
1056+
cc.decrStreamReservationsLocked()
1057+
}
1058+
1059+
func (cc *ClientConn) decrStreamReservationsLocked() {
1060+
if cc.streamsReserved > 0 {
1061+
cc.streamsReserved--
1062+
}
1063+
}
1064+
10361065
func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
10371066
resp, _, err := cc.roundTrip(req)
10381067
return resp, err
@@ -1041,6 +1070,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
10411070
func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAfterReqBodyWrite bool, err error) {
10421071
ctx := req.Context()
10431072
if err := checkConnHeaders(req); err != nil {
1073+
cc.decrStreamReservations()
10441074
return nil, false, err
10451075
}
10461076
if cc.idleTimer != nil {
@@ -1049,6 +1079,7 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf
10491079

10501080
trailers, err := commaSeparatedTrailers(req)
10511081
if err != nil {
1082+
cc.decrStreamReservations()
10521083
return nil, false, err
10531084
}
10541085
hasTrailers := trailers != ""
@@ -1062,8 +1093,10 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf
10621093
select {
10631094
case cc.reqHeaderMu <- struct{}{}:
10641095
case <-req.Cancel:
1096+
cc.decrStreamReservations()
10651097
return nil, false, errRequestCanceled
10661098
case <-ctx.Done():
1099+
cc.decrStreamReservations()
10671100
return nil, false, ctx.Err()
10681101
}
10691102
reqHeaderMuNeedsUnlock := true
@@ -1074,6 +1107,11 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf
10741107
}()
10751108

10761109
cc.mu.Lock()
1110+
cc.decrStreamReservationsLocked()
1111+
if req.URL == nil {
1112+
cc.mu.Unlock()
1113+
return nil, false, errNilRequestURL
1114+
}
10771115
if err := cc.awaitOpenSlotForRequest(req); err != nil {
10781116
cc.mu.Unlock()
10791117
return nil, false, err
@@ -1526,9 +1564,14 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error)
15261564
}
15271565
}
15281566

1567+
var errNilRequestURL = errors.New("http2: Request.URI is nil")
1568+
15291569
// requires cc.wmu be held.
15301570
func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
15311571
cc.hbuf.Reset()
1572+
if req.URL == nil {
1573+
return nil, errNilRequestURL
1574+
}
15321575

15331576
host := req.Host
15341577
if host == "" {

http2/transport_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5296,3 +5296,42 @@ func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) (
52965296
func (p *collectClientsConnPool) MarkDead(cc *ClientConn) {
52975297
p.lower.MarkDead(cc)
52985298
}
5299+
5300+
func TestClientConnReservations(t *testing.T) {
5301+
cc := &ClientConn{
5302+
reqHeaderMu: make(chan struct{}, 1),
5303+
streams: make(map[uint32]*clientStream),
5304+
maxConcurrentStreams: initialMaxConcurrentStreams,
5305+
t: &Transport{},
5306+
}
5307+
n := 0
5308+
for n <= initialMaxConcurrentStreams && cc.ReserveNewRequest() {
5309+
n++
5310+
}
5311+
if n != initialMaxConcurrentStreams {
5312+
t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams)
5313+
}
5314+
if _, err := cc.RoundTrip(new(http.Request)); !errors.Is(err, errNilRequestURL) {
5315+
t.Fatalf("RoundTrip error = %v; want errNilRequestURL", err)
5316+
}
5317+
n2 := 0
5318+
for n2 <= 5 && cc.ReserveNewRequest() {
5319+
n2++
5320+
}
5321+
if n2 != 1 {
5322+
t.Fatalf("after one RoundTrip, did %v reservations; want 1", n2)
5323+
}
5324+
5325+
// Use up all the reservations
5326+
for i := 0; i < n; i++ {
5327+
cc.RoundTrip(new(http.Request))
5328+
}
5329+
5330+
n2 = 0
5331+
for n2 <= initialMaxConcurrentStreams && cc.ReserveNewRequest() {
5332+
n2++
5333+
}
5334+
if n2 != n {
5335+
t.Errorf("after reset, reservations = %v; want %v", n2, n)
5336+
}
5337+
}

0 commit comments

Comments
 (0)