Skip to content

Commit 902f8a7

Browse files
committed
net/http: reverseproxy: forward 1xx responses
Support for 1xx responses has recently been merged in net/http (#42597). As discussed in this CL (https://go-review.googlesource.com/c/go/+/269997/comments/1ff70bef_c25a829a), support for forwarding 1xx responses in ReverseProxy has been extracted in this separate patch. According to RFC 7231, "a proxy MUST forward 1xx responses unless the proxy itself requested the generation of the 1xx response". Consequently, all received 1xx responses are automatically forwarded as long as the underlying transport supports ClientTrace.Got1xxResponse. Fixes #26088 Fixes #36734
1 parent cfd202c commit 902f8a7

File tree

2 files changed

+105
-1
lines changed

2 files changed

+105
-1
lines changed

src/net/http/httputil/reverseproxy.go

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"mime"
1515
"net"
1616
"net/http"
17+
"net/http/httptrace"
1718
"net/http/internal/ascii"
1819
"net/textproto"
1920
"net/url"
@@ -40,6 +41,9 @@ import (
4041
// To prevent IP spoofing, be sure to delete any pre-existing
4142
// X-Forwarded-For header coming from the client or
4243
// an untrusted proxy.
44+
//
45+
// 1xx responses are forwarded to the client if the underlying
46+
// transport supports ClientTrace.Got1xxResponse.
4347
type ReverseProxy struct {
4448
// Director must be a function which modifies
4549
// the request into a new request to be sent
@@ -307,6 +311,23 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
307311
}
308312
}
309313

314+
var headerSet bool
315+
trace := &httptrace.ClientTrace{
316+
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
317+
h := rw.Header()
318+
copyHeader(h, http.Header(header))
319+
rw.WriteHeader(code)
320+
321+
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
322+
for k, _ := range h {
323+
h.Del(k)
324+
}
325+
326+
return nil
327+
},
328+
}
329+
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
330+
310331
res, err := transport.RoundTrip(outreq)
311332
if err != nil {
312333
p.getErrorHandler()(rw, outreq, err)
@@ -332,7 +353,14 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
332353
return
333354
}
334355

335-
copyHeader(rw.Header(), res.Header)
356+
h := rw.Header()
357+
if headerSet {
358+
for k, _ := range h {
359+
h.Del(k)
360+
}
361+
}
362+
363+
copyHeader(h, res.Header)
336364

337365
// The "Trailer" header isn't included in the Transport's response,
338366
// at least for *http.Transport. Build it up from Trailer.

src/net/http/httputil/reverseproxy_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ import (
1616
"log"
1717
"net/http"
1818
"net/http/httptest"
19+
"net/http/httptrace"
1920
"net/http/internal/ascii"
21+
"net/textproto"
2022
"net/url"
2123
"os"
2224
"reflect"
@@ -1537,3 +1539,77 @@ func TestJoinURLPath(t *testing.T) {
15371539
}
15381540
}
15391541
}
1542+
1543+
func Test1xxResponses(t *testing.T) {
1544+
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1545+
h := w.Header()
1546+
h.Add("Link", "</style.css>; rel=preload; as=style")
1547+
h.Add("Link", "</script.js>; rel=preload; as=script")
1548+
w.WriteHeader(http.StatusEarlyHints)
1549+
1550+
h.Add("Link", "</foo.js>; rel=preload; as=script")
1551+
w.WriteHeader(http.StatusProcessing)
1552+
1553+
w.Write([]byte("Hello"))
1554+
}))
1555+
defer backend.Close()
1556+
backendURL, err := url.Parse(backend.URL)
1557+
if err != nil {
1558+
t.Fatal(err)
1559+
}
1560+
proxyHandler := NewSingleHostReverseProxy(backendURL)
1561+
proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
1562+
frontend := httptest.NewServer(proxyHandler)
1563+
defer frontend.Close()
1564+
frontendClient := frontend.Client()
1565+
1566+
checkLinkHeaders := func(t *testing.T, expected, got []string) {
1567+
t.Helper()
1568+
1569+
if len(expected) != len(got) {
1570+
t.Errorf("Expected %d link headers; got %d", len(expected), len(got))
1571+
}
1572+
1573+
for i := range expected {
1574+
if expected[i] != got[i] {
1575+
t.Errorf("Expected %q link header; got %q", expected[i], got[i])
1576+
}
1577+
}
1578+
}
1579+
1580+
var respCounter uint8
1581+
trace := &httptrace.ClientTrace{
1582+
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
1583+
switch code {
1584+
case http.StatusEarlyHints:
1585+
checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
1586+
case http.StatusProcessing:
1587+
checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
1588+
default:
1589+
t.Error("Unexpected 1xx response")
1590+
}
1591+
1592+
respCounter++
1593+
1594+
return nil
1595+
},
1596+
}
1597+
req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil)
1598+
1599+
res, err := frontendClient.Do(req)
1600+
if err != nil {
1601+
t.Fatalf("Get: %v", err)
1602+
}
1603+
1604+
defer res.Body.Close()
1605+
1606+
if respCounter != 2 {
1607+
t.Errorf("Excpected 2 1xx responses; got %d", respCounter)
1608+
}
1609+
checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
1610+
1611+
body, _ := io.ReadAll(res.Body)
1612+
if string(body) != "Hello" {
1613+
t.Errorf("Read body %q; want Hello", body)
1614+
}
1615+
}

0 commit comments

Comments
 (0)