Skip to content

Commit 8c6a855

Browse files
moredureMikhail Faraponov
authored and
Mikhail Faraponov
committed
net: add context-aware DialTCPContext, DialUDPContext, etc
Fixes #49097
1 parent 6df0957 commit 8c6a855

File tree

5 files changed

+180
-4
lines changed

5 files changed

+180
-4
lines changed

src/net/dial.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package net
77
import (
88
"context"
99
"internal/nettrace"
10+
"net"
1011
"syscall"
1112
"time"
1213
)
@@ -443,6 +444,114 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
443444
return c, nil
444445
}
445446

447+
func (d *Dialer) dialAddr(ctx context.Context, network string, address Addr) (Conn, error) {
448+
if ctx == nil {
449+
panic("nil context")
450+
}
451+
deadline := d.deadline(ctx, time.Now())
452+
if !deadline.IsZero() {
453+
if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
454+
subCtx, cancel := context.WithDeadline(ctx, deadline)
455+
defer cancel()
456+
ctx = subCtx
457+
}
458+
}
459+
if oldCancel := d.Cancel; oldCancel != nil {
460+
subCtx, cancel := context.WithCancel(ctx)
461+
defer cancel()
462+
go func() {
463+
select {
464+
case <-oldCancel:
465+
cancel()
466+
case <-subCtx.Done():
467+
}
468+
}()
469+
ctx = subCtx
470+
}
471+
sd := &sysDialer{
472+
Dialer: *d,
473+
network: network,
474+
address: address.String(),
475+
}
476+
return sd.dialSingle(ctx, address)
477+
}
478+
479+
func (d *Dialer) DialTCP(ctx context.Context, network string, address *TCPAddr) (Conn, error) {
480+
switch d.LocalAddr.(type) {
481+
case *TCPAddr, nil:
482+
switch network {
483+
case "tcp", "tcp4", "tcp6":
484+
default:
485+
return nil, &OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: address, Err: UnknownNetworkError(network)}
486+
}
487+
if address == nil {
488+
return nil, &OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: nil, Err: errMissingAddress}
489+
}
490+
c, err := d.dialAddr(ctx, network, address)
491+
if err != nil {
492+
return nil, err
493+
}
494+
if tc := c.(*TCPConn); d.KeepAlive >= 0 {
495+
setKeepAlive(tc.fd, true)
496+
ka := d.KeepAlive
497+
if d.KeepAlive == 0 {
498+
ka = defaultTCPKeepAlive
499+
}
500+
setKeepAlivePeriod(tc.fd, ka)
501+
testHookSetKeepAlive(ka)
502+
}
503+
return c, nil
504+
default:
505+
return nil, &AddrError{Err: "mismatched local address type", Addr: d.LocalAddr.String()}
506+
}
507+
}
508+
509+
func (d *Dialer) DialIP(ctx context.Context, network string, address *IPAddr) (Conn, error) {
510+
switch d.LocalAddr.(type) {
511+
case *IPAddr, nil:
512+
if address == nil {
513+
return nil, &OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: nil, Err: errMissingAddress}
514+
}
515+
return d.dialAddr(ctx, network, address)
516+
default:
517+
return nil, &AddrError{Err: "mismatched local address type", Addr: d.LocalAddr.String()}
518+
}
519+
}
520+
521+
func (d *Dialer) DialUnix(ctx context.Context, network string, address *UnixAddr) (Conn, error) {
522+
switch d.LocalAddr.(type) {
523+
case *UnixAddr, nil:
524+
switch network {
525+
case "unix", "unixgram", "unixpacket":
526+
default:
527+
return nil, &OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: address, Err: UnknownNetworkError(network)}
528+
}
529+
if address == nil {
530+
return nil, &OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: nil, Err: errMissingAddress}
531+
}
532+
return d.dialAddr(ctx, network, address)
533+
default:
534+
return nil, &AddrError{Err: "mismatched local address type", Addr: d.LocalAddr.String()}
535+
}
536+
}
537+
538+
func (d *Dialer) DialUDP(ctx context.Context, network string, address *UDPAddr) (Conn, error) {
539+
switch d.LocalAddr.(type) {
540+
case *UDPAddr, nil:
541+
switch network {
542+
case "udp", "udp4", "udp6":
543+
default:
544+
return nil, &OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: address, Err: UnknownNetworkError(network)}
545+
}
546+
if address == nil {
547+
return nil, &OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: nil, Err: errMissingAddress}
548+
}
549+
return d.dialAddr(ctx, network, address)
550+
default:
551+
return nil, &AddrError{Err: "mismatched local address type", Addr: d.LocalAddr.String()}
552+
}
553+
}
554+
446555
// dialParallel races two copies of dialSerial, giving the first a
447556
// head start. It returns the first established connection and
448557
// closes the others. Otherwise it returns an error from the first

src/net/iprawsock.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,28 @@ func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} }
209209
// If the IP field of raddr is nil or an unspecified IP address, the
210210
// local system is assumed.
211211
func DialIP(network string, laddr, raddr *IPAddr) (*IPConn, error) {
212+
return DialIPContext(context.Background(), network, laddr, raddr)
213+
}
214+
215+
// DialIPContext acts like DialIP but connects using
216+
// the provided context.
217+
//
218+
// The provided Context must be non-nil.
219+
//
220+
// The network must be an IP network name; see func Dial for details.
221+
//
222+
// If laddr is nil, a local address is automatically chosen.
223+
// If the IP field of raddr is nil or an unspecified IP address, the
224+
// local system is assumed.
225+
func DialIPContext(ctx context.Context, network string, laddr, raddr *IPAddr) (*IPConn, error) {
226+
if ctx == nil {
227+
panic("nil context")
228+
}
212229
if raddr == nil {
213230
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
214231
}
215232
sd := &sysDialer{network: network, address: raddr.String()}
216-
c, err := sd.dialIP(context.Background(), laddr, raddr)
233+
c, err := sd.dialIP(ctx, laddr, raddr)
217234
if err != nil {
218235
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
219236
}

src/net/tcpsock.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,23 @@ func newTCPConn(fd *netFD) *TCPConn {
231231
// If the IP field of raddr is nil or an unspecified IP address, the
232232
// local system is assumed.
233233
func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
234+
return DialTCPContext(context.Background(), network, laddr, raddr)
235+
}
236+
237+
// DialTCPContext acts like DialTCP but connects using
238+
// the provided context.
239+
//
240+
// The provided Context must be non-nil.
241+
//
242+
// The network must be a TCP network name; see func Dial for details.
243+
//
244+
// If laddr is nil, a local address is automatically chosen.
245+
// If the IP field of raddr is nil or an unspecified IP address, the
246+
// local system is assumed.
247+
func DialTCPContext(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
248+
if ctx == nil {
249+
panic("nil context")
250+
}
234251
switch network {
235252
case "tcp", "tcp4", "tcp6":
236253
default:
@@ -240,7 +257,7 @@ func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
240257
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
241258
}
242259
sd := &sysDialer{network: network, address: raddr.String()}
243-
c, err := sd.dialTCP(context.Background(), laddr, raddr)
260+
c, err := sd.dialTCP(ctx, laddr, raddr)
244261
if err != nil {
245262
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
246263
}

src/net/udpsock.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,23 @@ func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{conn{fd}} }
284284
// If the IP field of raddr is nil or an unspecified IP address, the
285285
// local system is assumed.
286286
func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
287+
return DialUDPContext(context.Background(), network, laddr, raddr)
288+
}
289+
290+
// DialUDPContext acts like DialUDP but connects using
291+
// the provided context.
292+
//
293+
// The provided Context must be non-nil.
294+
//
295+
// The network must be a UDP network name; see func Dial for details.
296+
//
297+
// If laddr is nil, a local address is automatically chosen.
298+
// If the IP field of raddr is nil or an unspecified IP address, the
299+
// local system is assumed.
300+
func DialUDPContext(ctx context.Context, network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
301+
if ctx == nil {
302+
panic("nil context")
303+
}
287304
switch network {
288305
case "udp", "udp4", "udp6":
289306
default:
@@ -293,7 +310,7 @@ func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
293310
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
294311
}
295312
sd := &sysDialer{network: network, address: raddr.String()}
296-
c, err := sd.dialUDP(context.Background(), laddr, raddr)
313+
c, err := sd.dialUDP(ctx, laddr, raddr)
297314
if err != nil {
298315
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
299316
}

src/net/unixsock.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,13 +201,29 @@ func newUnixConn(fd *netFD) *UnixConn { return &UnixConn{conn{fd}} }
201201
// If laddr is non-nil, it is used as the local address for the
202202
// connection.
203203
func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
204+
return DialUnixContext(context.Background(), network, laddr, raddr)
205+
}
206+
207+
// DialUnixContext acts like DialUnix but connects using
208+
// the provided context.
209+
//
210+
// The provided Context must be non-nil.
211+
//
212+
// The network must be a Unix network name; see func Dial for details.
213+
//
214+
// If laddr is non-nil, it is used as the local address for the
215+
// connection.
216+
func DialUnixContext(ctx context.Context, network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
217+
if ctx == nil {
218+
panic("nil context")
219+
}
204220
switch network {
205221
case "unix", "unixgram", "unixpacket":
206222
default:
207223
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(network)}
208224
}
209225
sd := &sysDialer{network: network, address: raddr.String()}
210-
c, err := sd.dialUnix(context.Background(), laddr, raddr)
226+
c, err := sd.dialUnix(ctx, laddr, raddr)
211227
if err != nil {
212228
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
213229
}

0 commit comments

Comments
 (0)