Skip to content

Commit 40a144b

Browse files
committed
crypto/tls: add Dialer
Fixes #18482 Change-Id: I99d65dc5d824c00093ea61e7445fc121314af87f Reviewed-on: https://go-review.googlesource.com/c/go/+/214977 Run-TryBot: Brad Fitzpatrick <[email protected]> Reviewed-by: Ian Lance Taylor <[email protected]>
1 parent 7004be9 commit 40a144b

File tree

5 files changed

+139
-14
lines changed

5 files changed

+139
-14
lines changed

doc/go1.15.html

+12
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,18 @@ <h3 id="minor_library_changes">Minor changes to the library</h3>
133133
TODO
134134
</p>
135135

136+
<dl id="crypto/tls"><dt><a href="/crypto/tls/">crypto/tls</a></dt>
137+
<dd>
138+
<p><!-- CL 214977 -->
139+
The new
140+
<a href="/pkg/crypto/tls/#Dialer"><code>Dialer</code></a>
141+
type and its
142+
<a href="/pkg/crypto/tls/#Dialer.DialContext"><code>DialContext</code></a>
143+
method permits using a context to both connect and handshake with a TLS server.
144+
</p>
145+
</dd>
146+
</dl>
147+
136148
<dl id="flag"><dt><a href="/pkg/flag/">flag</a></dt>
137149
<dd>
138150
<p><!-- CL 221427 -->

src/crypto/tls/conn.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -1334,8 +1334,12 @@ func (c *Conn) closeNotify() error {
13341334

13351335
// Handshake runs the client or server handshake
13361336
// protocol if it has not yet been run.
1337-
// Most uses of this package need not call Handshake
1338-
// explicitly: the first Read or Write will call it automatically.
1337+
//
1338+
// Most uses of this package need not call Handshake explicitly: the
1339+
// first Read or Write will call it automatically.
1340+
//
1341+
// For control over canceling or setting a timeout on a handshake, use
1342+
// the Dialer's DialContext method.
13391343
func (c *Conn) Handshake() error {
13401344
c.handshakeMutex.Lock()
13411345
defer c.handshakeMutex.Unlock()

src/crypto/tls/tls.go

+78-11
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ package tls
1313

1414
import (
1515
"bytes"
16+
"context"
1617
"crypto"
1718
"crypto/ecdsa"
1819
"crypto/ed25519"
@@ -111,29 +112,35 @@ func (timeoutError) Temporary() bool { return true }
111112
// DialWithDialer interprets a nil configuration as equivalent to the zero
112113
// configuration; see the documentation of Config for the defaults.
113114
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
115+
return dial(context.Background(), dialer, network, addr, config)
116+
}
117+
118+
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
114119
// We want the Timeout and Deadline values from dialer to cover the
115120
// whole process: TCP connection and TLS handshake. This means that we
116121
// also need to start our own timers now.
117-
timeout := dialer.Timeout
122+
timeout := netDialer.Timeout
118123

119-
if !dialer.Deadline.IsZero() {
120-
deadlineTimeout := time.Until(dialer.Deadline)
124+
if !netDialer.Deadline.IsZero() {
125+
deadlineTimeout := time.Until(netDialer.Deadline)
121126
if timeout == 0 || deadlineTimeout < timeout {
122127
timeout = deadlineTimeout
123128
}
124129
}
125130

126-
var errChannel chan error
127-
131+
// hsErrCh is non-nil if we might not wait for Handshake to complete.
132+
var hsErrCh chan error
133+
if timeout != 0 || ctx.Done() != nil {
134+
hsErrCh = make(chan error, 2)
135+
}
128136
if timeout != 0 {
129-
errChannel = make(chan error, 2)
130137
timer := time.AfterFunc(timeout, func() {
131-
errChannel <- timeoutError{}
138+
hsErrCh <- timeoutError{}
132139
})
133140
defer timer.Stop()
134141
}
135142

136-
rawConn, err := dialer.Dial(network, addr)
143+
rawConn, err := netDialer.DialContext(ctx, network, addr)
137144
if err != nil {
138145
return nil, err
139146
}
@@ -158,14 +165,26 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*
158165

159166
conn := Client(rawConn, config)
160167

161-
if timeout == 0 {
168+
if hsErrCh == nil {
162169
err = conn.Handshake()
163170
} else {
164171
go func() {
165-
errChannel <- conn.Handshake()
172+
hsErrCh <- conn.Handshake()
166173
}()
167174

168-
err = <-errChannel
175+
select {
176+
case <-ctx.Done():
177+
err = ctx.Err()
178+
case err = <-hsErrCh:
179+
if err != nil {
180+
// If the error was due to the context
181+
// closing, prefer the context's error, rather
182+
// than some random network teardown error.
183+
if e := ctx.Err(); e != nil {
184+
err = e
185+
}
186+
}
187+
}
169188
}
170189

171190
if err != nil {
@@ -186,6 +205,54 @@ func Dial(network, addr string, config *Config) (*Conn, error) {
186205
return DialWithDialer(new(net.Dialer), network, addr, config)
187206
}
188207

208+
// Dialer dials TLS connections given a configuration and a Dialer for the
209+
// underlying connection.
210+
type Dialer struct {
211+
// NetDialer is the optional dialer to use for the TLS connections'
212+
// underlying TCP connections.
213+
// A nil NetDialer is equivalent to the net.Dialer zero value.
214+
NetDialer *net.Dialer
215+
216+
// Config is the TLS configuration to use for new connections.
217+
// A nil configuration is equivalent to the zero
218+
// configuration; see the documentation of Config for the
219+
// defaults.
220+
Config *Config
221+
}
222+
223+
// Dial connects to the given network address and initiates a TLS
224+
// handshake, returning the resulting TLS connection.
225+
//
226+
// The returned Conn, if any, will always be of type *Conn.
227+
func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
228+
return d.DialContext(context.Background(), network, addr)
229+
}
230+
231+
func (d *Dialer) netDialer() *net.Dialer {
232+
if d.NetDialer != nil {
233+
return d.NetDialer
234+
}
235+
return new(net.Dialer)
236+
}
237+
238+
// Dial connects to the given network address and initiates a TLS
239+
// handshake, returning the resulting TLS connection.
240+
//
241+
// The provided Context must be non-nil. If the context expires before
242+
// the connection is complete, an error is returned. Once successfully
243+
// connected, any expiration of the context will not affect the
244+
// connection.
245+
//
246+
// The returned Conn, if any, will always be of type *Conn.
247+
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
248+
c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
249+
if err != nil {
250+
// Don't return c (a typed nil) in an interface.
251+
return nil, err
252+
}
253+
return c, nil
254+
}
255+
189256
// LoadX509KeyPair reads and parses a public/private key pair from a pair
190257
// of files. The files must contain PEM encoded data. The certificate file
191258
// may contain intermediate certificates following the leaf certificate to

src/crypto/tls/tls_test.go

+42
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package tls
66

77
import (
88
"bytes"
9+
"context"
910
"crypto"
1011
"crypto/x509"
1112
"encoding/json"
@@ -272,6 +273,47 @@ func TestDeadlineOnWrite(t *testing.T) {
272273
}
273274
}
274275

276+
type readerFunc func([]byte) (int, error)
277+
278+
func (f readerFunc) Read(b []byte) (int, error) { return f(b) }
279+
280+
// TestDialer tests that tls.Dialer.DialContext can abort in the middle of a handshake.
281+
// (The other cases are all handled by the existing dial tests in this package, which
282+
// all also flow through the same code shared code paths)
283+
func TestDialer(t *testing.T) {
284+
ln := newLocalListener(t)
285+
defer ln.Close()
286+
287+
unblockServer := make(chan struct{}) // close-only
288+
defer close(unblockServer)
289+
go func() {
290+
conn, err := ln.Accept()
291+
if err != nil {
292+
return
293+
}
294+
defer conn.Close()
295+
<-unblockServer
296+
}()
297+
298+
ctx, cancel := context.WithCancel(context.Background())
299+
d := Dialer{Config: &Config{
300+
Rand: readerFunc(func(b []byte) (n int, err error) {
301+
// By the time crypto/tls wants randomness, that means it has a TCP
302+
// connection, so we're past the Dialer's dial and now blocked
303+
// in a handshake. Cancel our context and see if we get unstuck.
304+
// (Our TCP listener above never reads or writes, so the Handshake
305+
// would otherwise be stuck forever)
306+
cancel()
307+
return len(b), nil
308+
}),
309+
ServerName: "foo",
310+
}}
311+
_, err := d.DialContext(ctx, "tcp", ln.Addr().String())
312+
if err != context.Canceled {
313+
t.Errorf("err = %v; want context.Canceled", err)
314+
}
315+
}
316+
275317
func isTimeoutError(err error) bool {
276318
if ne, ok := err.(net.Error); ok {
277319
return ne.Timeout()

src/go/build/deps_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ var pkgDeps = map[string][]string{
408408
// SSL/TLS.
409409
"crypto/tls": {
410410
"L4", "CRYPTO-MATH", "OS", "golang.org/x/crypto/cryptobyte", "golang.org/x/crypto/hkdf",
411-
"container/list", "crypto/x509", "encoding/pem", "net", "syscall", "crypto/ed25519",
411+
"container/list", "context", "crypto/x509", "encoding/pem", "net", "syscall", "crypto/ed25519",
412412
},
413413
"crypto/x509": {
414414
"L4", "CRYPTO-MATH", "OS", "CGO", "crypto/ed25519",

0 commit comments

Comments
 (0)