Skip to content

Commit eb55a0c

Browse files
rosenhousebradfitz
authored andcommitted
net/http: add DialTLSContext hook to Transport
Fixes #21526 Change-Id: I2f8215cd671641cddfa8499f8a8c0130db93dbc6 Reviewed-on: https://go-review.googlesource.com/c/go/+/61291 Reviewed-by: Brad Fitzpatrick <[email protected]> Run-TryBot: Brad Fitzpatrick <[email protected]>
1 parent c9a4b01 commit eb55a0c

File tree

2 files changed

+117
-9
lines changed

2 files changed

+117
-9
lines changed

src/net/http/transport.go

+32-9
Original file line numberDiff line numberDiff line change
@@ -142,15 +142,24 @@ type Transport struct {
142142
// If both are set, DialContext takes priority.
143143
Dial func(network, addr string) (net.Conn, error)
144144

145-
// DialTLS specifies an optional dial function for creating
145+
// DialTLSContext specifies an optional dial function for creating
146146
// TLS connections for non-proxied HTTPS requests.
147147
//
148-
// If DialTLS is nil, Dial and TLSClientConfig are used.
148+
// If DialTLSContext is nil (and the deprecated DialTLS below is also nil),
149+
// DialContext and TLSClientConfig are used.
149150
//
150-
// If DialTLS is set, the Dial hook is not used for HTTPS
151+
// If DialTLSContext is set, the Dial and DialContext hooks are not used for HTTPS
151152
// requests and the TLSClientConfig and TLSHandshakeTimeout
152153
// are ignored. The returned net.Conn is assumed to already be
153154
// past the TLS handshake.
155+
DialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
156+
157+
// DialTLS specifies an optional dial function for creating
158+
// TLS connections for non-proxied HTTPS requests.
159+
//
160+
// Deprecated: Use DialTLSContext instead, which allows the transport
161+
// to cancel dials as soon as they are no longer needed.
162+
// If both are set, DialTLSContext takes priority.
154163
DialTLS func(network, addr string) (net.Conn, error)
155164

156165
// TLSClientConfig specifies the TLS configuration to use with
@@ -286,6 +295,7 @@ func (t *Transport) Clone() *Transport {
286295
DialContext: t.DialContext,
287296
Dial: t.Dial,
288297
DialTLS: t.DialTLS,
298+
DialTLSContext: t.DialTLSContext,
289299
TLSHandshakeTimeout: t.TLSHandshakeTimeout,
290300
DisableKeepAlives: t.DisableKeepAlives,
291301
DisableCompression: t.DisableCompression,
@@ -324,6 +334,10 @@ type h2Transport interface {
324334
CloseIdleConnections()
325335
}
326336

337+
func (t *Transport) hasCustomTLSDialer() bool {
338+
return t.DialTLS != nil || t.DialTLSContext != nil
339+
}
340+
327341
// onceSetNextProtoDefaults initializes TLSNextProto.
328342
// It must be called via t.nextProtoOnce.Do.
329343
func (t *Transport) onceSetNextProtoDefaults() {
@@ -352,7 +366,7 @@ func (t *Transport) onceSetNextProtoDefaults() {
352366
// Transport.
353367
return
354368
}
355-
if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialTLS != nil || t.DialContext != nil) {
369+
if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialContext != nil || t.hasCustomTLSDialer()) {
356370
// Be conservative and don't automatically enable
357371
// http2 if they've specified a custom TLS config or
358372
// custom dialers. Let them opt-in themselves via
@@ -1185,6 +1199,18 @@ func (q *wantConnQueue) cleanFront() (cleaned bool) {
11851199
}
11861200
}
11871201

1202+
func (t *Transport) customDialTLS(ctx context.Context, network, addr string) (conn net.Conn, err error) {
1203+
if t.DialTLSContext != nil {
1204+
conn, err = t.DialTLSContext(ctx, network, addr)
1205+
} else {
1206+
conn, err = t.DialTLS(network, addr)
1207+
}
1208+
if conn == nil && err == nil {
1209+
err = errors.New("net/http: Transport.DialTLS or DialTLSContext returned (nil, nil)")
1210+
}
1211+
return
1212+
}
1213+
11881214
// getConn dials and creates a new persistConn to the target as
11891215
// specified in the connectMethod. This includes doing a proxy CONNECT
11901216
// and/or setting up TLS. If this doesn't return an error, the persistConn
@@ -1435,15 +1461,12 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
14351461
}
14361462
return err
14371463
}
1438-
if cm.scheme() == "https" && t.DialTLS != nil {
1464+
if cm.scheme() == "https" && t.hasCustomTLSDialer() {
14391465
var err error
1440-
pconn.conn, err = t.DialTLS("tcp", cm.addr())
1466+
pconn.conn, err = t.customDialTLS(ctx, "tcp", cm.addr())
14411467
if err != nil {
14421468
return nil, wrapErr(err)
14431469
}
1444-
if pconn.conn == nil {
1445-
return nil, wrapErr(errors.New("net/http: Transport.DialTLS returned (nil, nil)"))
1446-
}
14471470
if tc, ok := pconn.conn.(*tls.Conn); ok {
14481471
// Handshake here, in case DialTLS didn't. TLSNextProto below
14491472
// depends on it for knowing the connection state.

src/net/http/transport_test.go

+85
Original file line numberDiff line numberDiff line change
@@ -3506,6 +3506,90 @@ func TestTransportDialTLS(t *testing.T) {
35063506
}
35073507
}
35083508

3509+
func TestTransportDialContext(t *testing.T) {
3510+
setParallel(t)
3511+
defer afterTest(t)
3512+
var mu sync.Mutex // guards following
3513+
var gotReq bool
3514+
var receivedContext context.Context
3515+
3516+
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3517+
mu.Lock()
3518+
gotReq = true
3519+
mu.Unlock()
3520+
}))
3521+
defer ts.Close()
3522+
c := ts.Client()
3523+
c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
3524+
mu.Lock()
3525+
receivedContext = ctx
3526+
mu.Unlock()
3527+
return net.Dial(netw, addr)
3528+
}
3529+
3530+
req, err := NewRequest("GET", ts.URL, nil)
3531+
if err != nil {
3532+
t.Fatal(err)
3533+
}
3534+
ctx := context.WithValue(context.Background(), "some-key", "some-value")
3535+
res, err := c.Do(req.WithContext(ctx))
3536+
if err != nil {
3537+
t.Fatal(err)
3538+
}
3539+
res.Body.Close()
3540+
mu.Lock()
3541+
if !gotReq {
3542+
t.Error("didn't get request")
3543+
}
3544+
if receivedContext != ctx {
3545+
t.Error("didn't receive correct context")
3546+
}
3547+
}
3548+
3549+
func TestTransportDialTLSContext(t *testing.T) {
3550+
setParallel(t)
3551+
defer afterTest(t)
3552+
var mu sync.Mutex // guards following
3553+
var gotReq bool
3554+
var receivedContext context.Context
3555+
3556+
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3557+
mu.Lock()
3558+
gotReq = true
3559+
mu.Unlock()
3560+
}))
3561+
defer ts.Close()
3562+
c := ts.Client()
3563+
c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
3564+
mu.Lock()
3565+
receivedContext = ctx
3566+
mu.Unlock()
3567+
c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
3568+
if err != nil {
3569+
return nil, err
3570+
}
3571+
return c, c.Handshake()
3572+
}
3573+
3574+
req, err := NewRequest("GET", ts.URL, nil)
3575+
if err != nil {
3576+
t.Fatal(err)
3577+
}
3578+
ctx := context.WithValue(context.Background(), "some-key", "some-value")
3579+
res, err := c.Do(req.WithContext(ctx))
3580+
if err != nil {
3581+
t.Fatal(err)
3582+
}
3583+
res.Body.Close()
3584+
mu.Lock()
3585+
if !gotReq {
3586+
t.Error("didn't get request")
3587+
}
3588+
if receivedContext != ctx {
3589+
t.Error("didn't receive correct context")
3590+
}
3591+
}
3592+
35093593
// Test for issue 8755
35103594
// Ensure that if a proxy returns an error, it is exposed by RoundTrip
35113595
func TestRoundTripReturnsProxyError(t *testing.T) {
@@ -5577,6 +5661,7 @@ func TestTransportClone(t *testing.T) {
55775661
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
55785662
Dial: func(network, addr string) (net.Conn, error) { panic("") },
55795663
DialTLS: func(network, addr string) (net.Conn, error) { panic("") },
5664+
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
55805665
TLSClientConfig: new(tls.Config),
55815666
TLSHandshakeTimeout: time.Second,
55825667
DisableKeepAlives: true,

0 commit comments

Comments
 (0)