Skip to content

Commit 5201b1a

Browse files
juliensbradfitz
authored andcommitted
http/http/httputil: add ReverseProxy.ErrorHandler
This permits specifying an ErrorHandler to customize the RoundTrip error handling if the backend fails to return a response. Fixes #22700 Fixes #21255 Change-Id: I8879f0956e2472a07f584660afa10105ef23bf11 Reviewed-on: https://go-review.googlesource.com/77410 Reviewed-by: Brad Fitzpatrick <[email protected]>
1 parent 86a0e67 commit 5201b1a

File tree

2 files changed

+117
-7
lines changed

2 files changed

+117
-7
lines changed

src/net/http/httputil/reverseproxy.go

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,23 @@ type ReverseProxy struct {
5555
// copying HTTP response bodies.
5656
BufferPool BufferPool
5757

58-
// ModifyResponse is an optional function that
59-
// modifies the Response from the backend.
60-
// If it returns an error, the proxy returns a StatusBadGateway error.
58+
// ModifyResponse is an optional function that modifies the
59+
// Response from the backend. It is called if the backend
60+
// returns a response at all, with any HTTP status code.
61+
// If the backend is unreachable, the optional ErrorHandler is
62+
// called without any call to ModifyResponse.
63+
//
64+
// If ModifyResponse returns an error, ErrorHandler is called
65+
// with its error value. If ErrorHandler is nil, its default
66+
// implementation is used.
6167
ModifyResponse func(*http.Response) error
68+
69+
// ErrorHandler is an optional function that handles errors
70+
// reaching the backend or errors from ModifyResponse.
71+
//
72+
// If nil, the default is to log the provided error and return
73+
// a 502 Status Bad Gateway response.
74+
ErrorHandler func(http.ResponseWriter, *http.Request, error)
6275
}
6376

6477
// A BufferPool is an interface for getting and returning temporary
@@ -141,6 +154,18 @@ var hopHeaders = []string{
141154
"Upgrade",
142155
}
143156

157+
func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
158+
p.logf("http: proxy error: %v", err)
159+
rw.WriteHeader(http.StatusBadGateway)
160+
}
161+
162+
func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
163+
if p.ErrorHandler != nil {
164+
return p.ErrorHandler
165+
}
166+
return p.defaultErrorHandler
167+
}
168+
144169
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
145170
transport := p.Transport
146171
if transport == nil {
@@ -206,8 +231,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
206231

207232
res, err := transport.RoundTrip(outreq)
208233
if err != nil {
209-
p.logf("http: proxy error: %v", err)
210-
rw.WriteHeader(http.StatusBadGateway)
234+
p.getErrorHandler()(rw, outreq, err)
211235
return
212236
}
213237

@@ -219,9 +243,8 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
219243

220244
if p.ModifyResponse != nil {
221245
if err := p.ModifyResponse(res); err != nil {
222-
p.logf("http: proxy error: %v", err)
223-
rw.WriteHeader(http.StatusBadGateway)
224246
res.Body.Close()
247+
p.getErrorHandler()(rw, outreq, err)
225248
return
226249
}
227250
}

src/net/http/httputil/reverseproxy_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,93 @@ func TestReverseProxyModifyResponse(t *testing.T) {
637637
}
638638
}
639639

640+
type failingRoundTripper struct{}
641+
642+
func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
643+
return nil, errors.New("some error")
644+
}
645+
646+
type staticResponseRoundTripper struct{ res *http.Response }
647+
648+
func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
649+
return rt.res, nil
650+
}
651+
652+
func TestReverseProxyErrorHandler(t *testing.T) {
653+
tests := []struct {
654+
name string
655+
wantCode int
656+
errorHandler func(http.ResponseWriter, *http.Request, error)
657+
transport http.RoundTripper // defaults to failingRoundTripper
658+
modifyResponse func(*http.Response) error
659+
}{
660+
{
661+
name: "default",
662+
wantCode: http.StatusBadGateway,
663+
},
664+
{
665+
name: "errorhandler",
666+
wantCode: http.StatusTeapot,
667+
errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
668+
},
669+
{
670+
name: "modifyresponse_noerr",
671+
transport: staticResponseRoundTripper{
672+
&http.Response{StatusCode: 345, Body: http.NoBody},
673+
},
674+
modifyResponse: func(res *http.Response) error {
675+
res.StatusCode++
676+
return nil
677+
},
678+
errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
679+
wantCode: 346,
680+
},
681+
{
682+
name: "modifyresponse_err",
683+
transport: staticResponseRoundTripper{
684+
&http.Response{StatusCode: 345, Body: http.NoBody},
685+
},
686+
modifyResponse: func(res *http.Response) error {
687+
res.StatusCode++
688+
return errors.New("some error to trigger errorHandler")
689+
},
690+
errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
691+
wantCode: http.StatusTeapot,
692+
},
693+
}
694+
695+
for _, tt := range tests {
696+
t.Run(tt.name, func(t *testing.T) {
697+
target := &url.URL{
698+
Scheme: "http",
699+
Host: "dummy.tld",
700+
Path: "/",
701+
}
702+
rproxy := NewSingleHostReverseProxy(target)
703+
rproxy.Transport = tt.transport
704+
rproxy.ModifyResponse = tt.modifyResponse
705+
if rproxy.Transport == nil {
706+
rproxy.Transport = failingRoundTripper{}
707+
}
708+
rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
709+
if tt.errorHandler != nil {
710+
rproxy.ErrorHandler = tt.errorHandler
711+
}
712+
frontendProxy := httptest.NewServer(rproxy)
713+
defer frontendProxy.Close()
714+
715+
resp, err := http.Get(frontendProxy.URL + "/test")
716+
if err != nil {
717+
t.Fatalf("failed to reach proxy: %v", err)
718+
}
719+
if g, e := resp.StatusCode, tt.wantCode; g != e {
720+
t.Errorf("got res.StatusCode %d; expected %d", g, e)
721+
}
722+
resp.Body.Close()
723+
})
724+
}
725+
}
726+
640727
// Issue 16659: log errors from short read
641728
func TestReverseProxy_CopyBuffer(t *testing.T) {
642729
backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

0 commit comments

Comments
 (0)