Skip to content

Commit ee55f08

Browse files
committed
net/http/httputil: make ReverseProxy automatically proxy WebSocket requests
Fixes #26937 Change-Id: I6cdc1bad4cf476cd2ea1462b53444eccd8841e14 Reviewed-on: https://go-review.googlesource.com/c/146437 Run-TryBot: Brad Fitzpatrick <[email protected]> TryBot-Result: Gobot Gobot <[email protected]> Reviewed-by: Dmitri Shuralyov <[email protected]>
1 parent de578dc commit ee55f08

File tree

4 files changed

+161
-12
lines changed

4 files changed

+161
-12
lines changed

src/go/build/deps_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ var pkgDeps = map[string][]string{
436436
"L4", "NET", "OS", "crypto/tls", "flag", "net/http", "net/http/internal", "crypto/x509",
437437
"golang_org/x/net/http/httpguts",
438438
},
439-
"net/http/httputil": {"L4", "NET", "OS", "context", "net/http", "net/http/internal"},
439+
"net/http/httputil": {"L4", "NET", "OS", "context", "net/http", "net/http/internal", "golang_org/x/net/http/httpguts"},
440440
"net/http/pprof": {"L4", "OS", "html/template", "net/http", "runtime/pprof", "runtime/trace"},
441441
"net/rpc": {"L4", "NET", "encoding/gob", "html/template", "net/http"},
442442
"net/rpc/jsonrpc": {"L4", "NET", "encoding/json", "net/rpc"},

src/net/http/httputil/reverseproxy.go

+81
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package httputil
88

99
import (
1010
"context"
11+
"fmt"
1112
"io"
1213
"log"
1314
"net"
@@ -16,6 +17,8 @@ import (
1617
"strings"
1718
"sync"
1819
"time"
20+
21+
"golang_org/x/net/http/httpguts"
1922
)
2023

2124
// ReverseProxy is an HTTP Handler that takes an incoming request and
@@ -199,6 +202,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
199202
p.Director(outreq)
200203
outreq.Close = false
201204

205+
reqUpType := upgradeType(outreq.Header)
202206
removeConnectionHeaders(outreq.Header)
203207

204208
// Remove hop-by-hop headers to the backend. Especially
@@ -221,6 +225,13 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
221225
outreq.Header.Del(h)
222226
}
223227

228+
// After stripping all the hop-by-hop connection headers above, add back any
229+
// necessary for protocol upgrades, such as for websockets.
230+
if reqUpType != "" {
231+
outreq.Header.Set("Connection", "Upgrade")
232+
outreq.Header.Set("Upgrade", reqUpType)
233+
}
234+
224235
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
225236
// If we aren't the first proxy retain prior
226237
// X-Forwarded-For information as a comma+space
@@ -237,6 +248,12 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
237248
return
238249
}
239250

251+
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
252+
if res.StatusCode == http.StatusSwitchingProtocols {
253+
p.handleUpgradeResponse(rw, outreq, res)
254+
return
255+
}
256+
240257
removeConnectionHeaders(res.Header)
241258

242259
for _, h := range hopHeaders {
@@ -463,3 +480,67 @@ func (m *maxLatencyWriter) stop() {
463480
m.t.Stop()
464481
}
465482
}
483+
484+
func upgradeType(h http.Header) string {
485+
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
486+
return ""
487+
}
488+
return strings.ToLower(h.Get("Upgrade"))
489+
}
490+
491+
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
492+
reqUpType := upgradeType(req.Header)
493+
resUpType := upgradeType(res.Header)
494+
if reqUpType != resUpType {
495+
p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
496+
return
497+
}
498+
hj, ok := rw.(http.Hijacker)
499+
if !ok {
500+
p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
501+
return
502+
}
503+
backConn, ok := res.Body.(io.ReadWriteCloser)
504+
if !ok {
505+
p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
506+
return
507+
}
508+
defer backConn.Close()
509+
conn, brw, err := hj.Hijack()
510+
if err != nil {
511+
p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
512+
return
513+
}
514+
defer conn.Close()
515+
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
516+
if err := res.Write(brw); err != nil {
517+
p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
518+
return
519+
}
520+
if err := brw.Flush(); err != nil {
521+
p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
522+
return
523+
}
524+
errc := make(chan error, 1)
525+
spc := switchProtocolCopier{user: conn, backend: backConn}
526+
go spc.copyToBackend(errc)
527+
go spc.copyFromBackend(errc)
528+
<-errc
529+
return
530+
}
531+
532+
// switchProtocolCopier exists so goroutines proxying data back and
533+
// forth have nice names in stacks.
534+
type switchProtocolCopier struct {
535+
user, backend io.ReadWriter
536+
}
537+
538+
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
539+
_, err := io.Copy(c.user, c.backend)
540+
errc <- err
541+
}
542+
543+
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
544+
_, err := io.Copy(c.backend, c.user)
545+
errc <- err
546+
}

src/net/http/httputil/reverseproxy_test.go

+78-10
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,20 @@ func TestReverseProxy(t *testing.T) {
153153
func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
154154
const fakeConnectionToken = "X-Fake-Connection-Token"
155155
const backendResponse = "I am the backend"
156+
157+
// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
158+
// in the Request's Connection header.
159+
const someConnHeader = "X-Some-Conn-Header"
160+
156161
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
157162
if c := r.Header.Get(fakeConnectionToken); c != "" {
158163
t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
159164
}
160-
if c := r.Header.Get("Upgrade"); c != "" {
161-
t.Errorf("handler got header %q = %q; want empty", "Upgrade", c)
165+
if c := r.Header.Get(someConnHeader); c != "" {
166+
t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
162167
}
163-
w.Header().Set("Connection", "Upgrade, "+fakeConnectionToken)
164-
w.Header().Set("Upgrade", "should be deleted")
168+
w.Header().Set("Connection", someConnHeader+", "+fakeConnectionToken)
169+
w.Header().Set(someConnHeader, "should be deleted")
165170
w.Header().Set(fakeConnectionToken, "should be deleted")
166171
io.WriteString(w, backendResponse)
167172
}))
@@ -173,15 +178,15 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
173178
proxyHandler := NewSingleHostReverseProxy(backendURL)
174179
frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
175180
proxyHandler.ServeHTTP(w, r)
176-
if c := r.Header.Get("Upgrade"); c != "original value" {
177-
t.Errorf("handler modified header %q = %q; want %q", "Upgrade", c, "original value")
181+
if c := r.Header.Get(someConnHeader); c != "original value" {
182+
t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "original value")
178183
}
179184
}))
180185
defer frontend.Close()
181186

182187
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
183-
getReq.Header.Set("Connection", "Upgrade, "+fakeConnectionToken)
184-
getReq.Header.Set("Upgrade", "original value")
188+
getReq.Header.Set("Connection", someConnHeader+", "+fakeConnectionToken)
189+
getReq.Header.Set(someConnHeader, "original value")
185190
getReq.Header.Set(fakeConnectionToken, "should be deleted")
186191
res, err := frontend.Client().Do(getReq)
187192
if err != nil {
@@ -195,8 +200,8 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
195200
if got, want := string(bodyBytes), backendResponse; got != want {
196201
t.Errorf("got body %q; want %q", got, want)
197202
}
198-
if c := res.Header.Get("Upgrade"); c != "" {
199-
t.Errorf("handler got header %q = %q; want empty", "Upgrade", c)
203+
if c := res.Header.Get(someConnHeader); c != "" {
204+
t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
200205
}
201206
if c := res.Header.Get(fakeConnectionToken); c != "" {
202207
t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
@@ -980,3 +985,66 @@ func TestSelectFlushInterval(t *testing.T) {
980985
})
981986
}
982987
}
988+
989+
func TestReverseProxyWebSocket(t *testing.T) {
990+
backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
991+
if upgradeType(r.Header) != "websocket" {
992+
t.Error("unexpected backend request")
993+
http.Error(w, "unexpected request", 400)
994+
return
995+
}
996+
c, _, err := w.(http.Hijacker).Hijack()
997+
if err != nil {
998+
t.Error(err)
999+
return
1000+
}
1001+
defer c.Close()
1002+
io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
1003+
bs := bufio.NewScanner(c)
1004+
if !bs.Scan() {
1005+
t.Errorf("backend failed to read line from client: %v", bs.Err())
1006+
return
1007+
}
1008+
fmt.Fprintf(c, "backend got %q\n", bs.Text())
1009+
}))
1010+
defer backendServer.Close()
1011+
1012+
backURL, _ := url.Parse(backendServer.URL)
1013+
rproxy := NewSingleHostReverseProxy(backURL)
1014+
rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
1015+
1016+
frontendProxy := httptest.NewServer(rproxy)
1017+
defer frontendProxy.Close()
1018+
1019+
req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
1020+
req.Header.Set("Connection", "Upgrade")
1021+
req.Header.Set("Upgrade", "websocket")
1022+
1023+
c := frontendProxy.Client()
1024+
res, err := c.Do(req)
1025+
if err != nil {
1026+
t.Fatal(err)
1027+
}
1028+
if res.StatusCode != 101 {
1029+
t.Fatalf("status = %v; want 101", res.Status)
1030+
}
1031+
if upgradeType(res.Header) != "websocket" {
1032+
t.Fatalf("not websocket upgrade; got %#v", res.Header)
1033+
}
1034+
rwc, ok := res.Body.(io.ReadWriteCloser)
1035+
if !ok {
1036+
t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
1037+
}
1038+
defer rwc.Close()
1039+
1040+
io.WriteString(rwc, "Hello\n")
1041+
bs := bufio.NewScanner(rwc)
1042+
if !bs.Scan() {
1043+
t.Fatalf("Scan: %v", bs.Err())
1044+
}
1045+
got := bs.Text()
1046+
want := `backend got "Hello"`
1047+
if got != want {
1048+
t.Errorf("got %#q, want %#q", got, want)
1049+
}
1050+
}

src/net/http/transport.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1714,7 +1714,7 @@ func (pc *persistConn) readLoop() {
17141714
alive = false
17151715
}
17161716

1717-
if !hasBody {
1717+
if !hasBody || bodyWritable {
17181718
pc.t.setReqCanceler(rc.req, nil)
17191719

17201720
// Put the idle conn back into the pool before we send the response

0 commit comments

Comments
 (0)