Skip to content

Commit 2f4d5c3

Browse files
cuiweixiebradfitz
authored andcommitted
net/http: add Transport.OnProxyConnectResponse
Fixes #54299 Change-Id: I3a29527bde7ac71f3824e771982db4257234e9ef Reviewed-on: https://go-review.googlesource.com/c/go/+/447216 Reviewed-by: Brad Fitzpatrick <[email protected]> Run-TryBot: xie cui <[email protected]> TryBot-Result: Gopher Robot <[email protected]> Reviewed-by: Bryan Mills <[email protected]> Reviewed-by: Damien Neil <[email protected]>
1 parent cd8d1bc commit 2f4d5c3

File tree

3 files changed

+111
-1
lines changed

3 files changed

+111
-1
lines changed

api/next/54299.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pkg net/http, type Transport struct, OnProxyConnectResponse func(context.Context, *url.URL, *Request, *Response) error #54299

src/net/http/transport.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ type Transport struct {
120120
// If Proxy is nil or returns a nil *URL, no proxy is used.
121121
Proxy func(*Request) (*url.URL, error)
122122

123+
// OnProxyConnectResponse is called when the Transport gets an HTTP response from
124+
// a proxy for a CONNECT request. It's called before the check for a 200 OK response.
125+
// If it returns an error, the request fails with that error.
126+
OnProxyConnectResponse func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error
127+
123128
// DialContext specifies the dial function for creating unencrypted TCP connections.
124129
// If DialContext is nil (and the deprecated Dial below is also nil),
125130
// then the transport dials using package net.
@@ -309,6 +314,7 @@ func (t *Transport) Clone() *Transport {
309314
t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
310315
t2 := &Transport{
311316
Proxy: t.Proxy,
317+
OnProxyConnectResponse: t.OnProxyConnectResponse,
312318
DialContext: t.DialContext,
313319
Dial: t.Dial,
314320
DialTLS: t.DialTLS,
@@ -1716,6 +1722,14 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
17161722
conn.Close()
17171723
return nil, err
17181724
}
1725+
1726+
if t.OnProxyConnectResponse != nil {
1727+
err = t.OnProxyConnectResponse(ctx, cm.proxyURL, connectReq, resp)
1728+
if err != nil {
1729+
return nil, err
1730+
}
1731+
}
1732+
17191733
if resp.StatusCode != 200 {
17201734
_, text, ok := strings.Cut(resp.Status, " ")
17211735
conn.Close()

src/net/http/transport_test.go

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1465,6 +1465,98 @@ func TestTransportProxy(t *testing.T) {
14651465
}
14661466
}
14671467

1468+
func TestOnProxyConnectResponse(t *testing.T) {
1469+
1470+
var tcases = []struct {
1471+
proxyStatusCode int
1472+
err error
1473+
}{
1474+
{
1475+
StatusOK,
1476+
nil,
1477+
},
1478+
{
1479+
StatusForbidden,
1480+
errors.New("403"),
1481+
},
1482+
}
1483+
for _, tcase := range tcases {
1484+
h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
1485+
1486+
})
1487+
1488+
h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
1489+
// Implement an entire CONNECT proxy
1490+
if r.Method == "CONNECT" {
1491+
if tcase.proxyStatusCode != StatusOK {
1492+
w.WriteHeader(tcase.proxyStatusCode)
1493+
return
1494+
}
1495+
hijacker, ok := w.(Hijacker)
1496+
if !ok {
1497+
t.Errorf("hijack not allowed")
1498+
return
1499+
}
1500+
clientConn, _, err := hijacker.Hijack()
1501+
if err != nil {
1502+
t.Errorf("hijacking failed")
1503+
return
1504+
}
1505+
res := &Response{
1506+
StatusCode: StatusOK,
1507+
Proto: "HTTP/1.1",
1508+
ProtoMajor: 1,
1509+
ProtoMinor: 1,
1510+
Header: make(Header),
1511+
}
1512+
1513+
targetConn, err := net.Dial("tcp", r.URL.Host)
1514+
if err != nil {
1515+
t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
1516+
return
1517+
}
1518+
1519+
if err := res.Write(clientConn); err != nil {
1520+
t.Errorf("Writing 200 OK failed: %v", err)
1521+
return
1522+
}
1523+
1524+
go io.Copy(targetConn, clientConn)
1525+
go func() {
1526+
io.Copy(clientConn, targetConn)
1527+
targetConn.Close()
1528+
}()
1529+
}
1530+
})
1531+
ts := newClientServerTest(t, https1Mode, h1).ts
1532+
proxy := newClientServerTest(t, https1Mode, h2).ts
1533+
1534+
pu, err := url.Parse(proxy.URL)
1535+
if err != nil {
1536+
t.Fatal(err)
1537+
}
1538+
1539+
c := proxy.Client()
1540+
1541+
c.Transport.(*Transport).Proxy = ProxyURL(pu)
1542+
c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
1543+
if proxyURL.String() != pu.String() {
1544+
t.Errorf("proxy url got %s, want %s", proxyURL, pu)
1545+
}
1546+
1547+
if "https://"+connectReq.URL.String() != ts.URL {
1548+
t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
1549+
}
1550+
return tcase.err
1551+
}
1552+
if _, err := c.Head(ts.URL); err != nil {
1553+
if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
1554+
t.Errorf("got %v, want %v", err, tcase.err)
1555+
}
1556+
}
1557+
}
1558+
}
1559+
14681560
// Issue 28012: verify that the Transport closes its TCP connection to http proxies
14691561
// when they're slow to reply to HTTPS CONNECT responses.
14701562
func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
@@ -5906,7 +5998,10 @@ func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
59065998

59075999
func TestTransportClone(t *testing.T) {
59086000
tr := &Transport{
5909-
Proxy: func(*Request) (*url.URL, error) { panic("") },
6001+
Proxy: func(*Request) (*url.URL, error) { panic("") },
6002+
OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
6003+
return nil
6004+
},
59106005
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
59116006
Dial: func(network, addr string) (net.Conn, error) { panic("") },
59126007
DialTLS: func(network, addr string) (net.Conn, error) { panic("") },

0 commit comments

Comments
 (0)