Skip to content

Commit f7fcdcf

Browse files
committed
net/http: allow upgrading non keepalive connections
Fixes #36381
1 parent 3a6cd4c commit f7fcdcf

File tree

4 files changed

+52
-8
lines changed

4 files changed

+52
-8
lines changed

src/net/http/response.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -352,10 +352,10 @@ func (r *Response) bodyIsWritable() bool {
352352
return ok
353353
}
354354

355-
// isProtocolSwitch reports whether r is a response to a successful
356-
// protocol upgrade.
357-
func (r *Response) isProtocolSwitch() bool {
358-
return r.StatusCode == StatusSwitchingProtocols &&
359-
r.Header.Get("Upgrade") != "" &&
360-
httpguts.HeaderValuesContainsToken(r.Header["Connection"], "Upgrade")
355+
// isProtocolSwitch reports whether the response code and header
356+
// indicate a successful protocol upgrade response.
357+
func isProtocolSwitchResponse(code int, h Header) bool {
358+
return code == StatusSwitchingProtocols &&
359+
h.Get("Upgrade") != "" &&
360+
httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade")
361361
}

src/net/http/serve_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6443,3 +6443,44 @@ func BenchmarkResponseStatusLine(b *testing.B) {
64436443
}
64446444
})
64456445
}
6446+
func TestDisableKeepAliveUpgrade(t *testing.T) {
6447+
if testing.Short() {
6448+
t.Skip("skipping in short mode")
6449+
}
6450+
6451+
setParallel(t)
6452+
defer afterTest(t)
6453+
6454+
s := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
6455+
c, _, err := w.(Hijacker).Hijack()
6456+
if err != nil {
6457+
return
6458+
}
6459+
defer c.Close()
6460+
6461+
c.Write([]byte("hello"))
6462+
}))
6463+
defer s.Close()
6464+
6465+
s.Client().Transport.(*Transport).DisableKeepAlives = true
6466+
6467+
resp, err := s.Client().Get("/")
6468+
if err != nil {
6469+
t.Fatal("failed to perform request: %v", err)
6470+
}
6471+
defer resp.Body.Close()
6472+
6473+
rwc, ok := resp.Body.(io.ReadWriteCloser)
6474+
if !ok {
6475+
t.Fatal("body is not a rwc: %T", resp.Body)
6476+
}
6477+
6478+
b, err := ioutil.ReadAll(rwc)
6479+
if err != nil {
6480+
t.Fatal("failed to read rwc: %v", err)
6481+
}
6482+
6483+
if string(b) != "hello" {
6484+
t.Fatal("unexpected value read from rwc: %s", b)
6485+
}
6486+
}

src/net/http/server.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1290,7 +1290,10 @@ func (cw *chunkWriter) writeHeader(p []byte) {
12901290
if !connectionHeaderSet {
12911291
setHeader.connection = "keep-alive"
12921292
}
1293-
} else if !w.req.ProtoAtLeast(1, 1) || w.wantsClose {
1293+
} else if !w.req.ProtoAtLeast(1, 1) ||
1294+
// Only close if the request indicates the client wants to close
1295+
// and we are not upgrading the connection.
1296+
(w.wantsClose && !isProtocolSwitchResponse(w.status, header)) {
12941297
w.closeAfterReply = true
12951298
}
12961299

src/net/http/transport.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2195,7 +2195,7 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr
21952195
}
21962196
break
21972197
}
2198-
if resp.isProtocolSwitch() {
2198+
if isProtocolSwitchResponse(resp.StatusCode, resp.Header) {
21992199
resp.Body = newReadWriteCloserBody(pc.br, pc.conn)
22002200
}
22012201

0 commit comments

Comments
 (0)