@@ -637,6 +637,93 @@ func TestReverseProxyModifyResponse(t *testing.T) {
637
637
}
638
638
}
639
639
640
+ type failingRoundTripper struct {}
641
+
642
+ func (failingRoundTripper ) RoundTrip (* http.Request ) (* http.Response , error ) {
643
+ return nil , errors .New ("some error" )
644
+ }
645
+
646
+ type staticResponseRoundTripper struct { res * http.Response }
647
+
648
+ func (rt staticResponseRoundTripper ) RoundTrip (* http.Request ) (* http.Response , error ) {
649
+ return rt .res , nil
650
+ }
651
+
652
+ func TestReverseProxyErrorHandler (t * testing.T ) {
653
+ tests := []struct {
654
+ name string
655
+ wantCode int
656
+ errorHandler func (http.ResponseWriter , * http.Request , error )
657
+ transport http.RoundTripper // defaults to failingRoundTripper
658
+ modifyResponse func (* http.Response ) error
659
+ }{
660
+ {
661
+ name : "default" ,
662
+ wantCode : http .StatusBadGateway ,
663
+ },
664
+ {
665
+ name : "errorhandler" ,
666
+ wantCode : http .StatusTeapot ,
667
+ errorHandler : func (rw http.ResponseWriter , req * http.Request , err error ) { rw .WriteHeader (http .StatusTeapot ) },
668
+ },
669
+ {
670
+ name : "modifyresponse_noerr" ,
671
+ transport : staticResponseRoundTripper {
672
+ & http.Response {StatusCode : 345 , Body : http .NoBody },
673
+ },
674
+ modifyResponse : func (res * http.Response ) error {
675
+ res .StatusCode ++
676
+ return nil
677
+ },
678
+ errorHandler : func (rw http.ResponseWriter , req * http.Request , err error ) { rw .WriteHeader (http .StatusTeapot ) },
679
+ wantCode : 346 ,
680
+ },
681
+ {
682
+ name : "modifyresponse_err" ,
683
+ transport : staticResponseRoundTripper {
684
+ & http.Response {StatusCode : 345 , Body : http .NoBody },
685
+ },
686
+ modifyResponse : func (res * http.Response ) error {
687
+ res .StatusCode ++
688
+ return errors .New ("some error to trigger errorHandler" )
689
+ },
690
+ errorHandler : func (rw http.ResponseWriter , req * http.Request , err error ) { rw .WriteHeader (http .StatusTeapot ) },
691
+ wantCode : http .StatusTeapot ,
692
+ },
693
+ }
694
+
695
+ for _ , tt := range tests {
696
+ t .Run (tt .name , func (t * testing.T ) {
697
+ target := & url.URL {
698
+ Scheme : "http" ,
699
+ Host : "dummy.tld" ,
700
+ Path : "/" ,
701
+ }
702
+ rproxy := NewSingleHostReverseProxy (target )
703
+ rproxy .Transport = tt .transport
704
+ rproxy .ModifyResponse = tt .modifyResponse
705
+ if rproxy .Transport == nil {
706
+ rproxy .Transport = failingRoundTripper {}
707
+ }
708
+ rproxy .ErrorLog = log .New (ioutil .Discard , "" , 0 ) // quiet for tests
709
+ if tt .errorHandler != nil {
710
+ rproxy .ErrorHandler = tt .errorHandler
711
+ }
712
+ frontendProxy := httptest .NewServer (rproxy )
713
+ defer frontendProxy .Close ()
714
+
715
+ resp , err := http .Get (frontendProxy .URL + "/test" )
716
+ if err != nil {
717
+ t .Fatalf ("failed to reach proxy: %v" , err )
718
+ }
719
+ if g , e := resp .StatusCode , tt .wantCode ; g != e {
720
+ t .Errorf ("got res.StatusCode %d; expected %d" , g , e )
721
+ }
722
+ resp .Body .Close ()
723
+ })
724
+ }
725
+ }
726
+
640
727
// Issue 16659: log errors from short read
641
728
func TestReverseProxy_CopyBuffer (t * testing.T ) {
642
729
backendServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
0 commit comments