Skip to content

Commit 6f75cf0

Browse files
moredureMikhail Faraponov
authored and
Mikhail Faraponov
committed
net: add context-aware DialTCPContext, DialUDPContext, etc
Fixes #49097
1 parent 6df0957 commit 6f75cf0

File tree

5 files changed

+163
-4
lines changed

5 files changed

+163
-4
lines changed

src/net/dial.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package net
77
import (
88
"context"
99
"internal/nettrace"
10+
"net"
1011
"syscall"
1112
"time"
1213
)
@@ -443,6 +444,97 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
443444
return c, nil
444445
}
445446

447+
func (d *Dialer) dialAddr(ctx context.Context, network string, address Addr) (Conn, error) {
448+
if ctx == nil {
449+
panic("nil context")
450+
}
451+
if address == nil {
452+
return nil, &OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: nil, Err: errMissingAddress}
453+
}
454+
deadline := d.deadline(ctx, time.Now())
455+
if !deadline.IsZero() {
456+
if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
457+
subCtx, cancel := context.WithDeadline(ctx, deadline)
458+
defer cancel()
459+
ctx = subCtx
460+
}
461+
}
462+
if oldCancel := d.Cancel; oldCancel != nil {
463+
subCtx, cancel := context.WithCancel(ctx)
464+
defer cancel()
465+
go func() {
466+
select {
467+
case <-oldCancel:
468+
cancel()
469+
case <-subCtx.Done():
470+
}
471+
}()
472+
ctx = subCtx
473+
}
474+
sd := &sysDialer{
475+
Dialer: *d,
476+
network: network,
477+
address: address.String(),
478+
}
479+
return sd.dialSingle(ctx, address)
480+
}
481+
482+
func (d *Dialer) DialTCP(ctx context.Context, network string, address *TCPAddr) (Conn, error) {
483+
if _, ok := d.LocalAddr.(*TCPAddr); !ok && d.LocalAddr != nil {
484+
return nil, &AddrError{Err: "mismatched local address type", Addr: d.LocalAddr.String()}
485+
}
486+
switch network {
487+
case "tcp", "tcp4", "tcp6":
488+
default:
489+
return nil, &OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: address, Err: UnknownNetworkError(network)}
490+
}
491+
c, err := d.dialAddr(ctx, network, address)
492+
if err != nil {
493+
return nil, err
494+
}
495+
if tc := c.(*TCPConn); d.KeepAlive >= 0 {
496+
setKeepAlive(tc.fd, true)
497+
ka := d.KeepAlive
498+
if d.KeepAlive == 0 {
499+
ka = defaultTCPKeepAlive
500+
}
501+
setKeepAlivePeriod(tc.fd, ka)
502+
testHookSetKeepAlive(ka)
503+
}
504+
return c, nil
505+
}
506+
507+
func (d *Dialer) DialIP(ctx context.Context, network string, address *IPAddr) (Conn, error) {
508+
if _, ok := d.LocalAddr.(*IPAddr); !ok && d.LocalAddr != nil {
509+
return nil, &AddrError{Err: "mismatched local address type", Addr: d.LocalAddr.String()}
510+
}
511+
return d.dialAddr(ctx, network, address)
512+
}
513+
514+
func (d *Dialer) DialUnix(ctx context.Context, network string, address *UnixAddr) (Conn, error) {
515+
if _, ok := d.LocalAddr.(*UnixAddr); !ok && d.LocalAddr != nil {
516+
return nil, &AddrError{Err: "mismatched local address type", Addr: d.LocalAddr.String()}
517+
}
518+
switch network {
519+
case "unix", "unixgram", "unixpacket":
520+
default:
521+
return nil, &OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: address, Err: UnknownNetworkError(network)}
522+
}
523+
return d.dialAddr(ctx, network, address)
524+
}
525+
526+
func (d *Dialer) DialUDP(ctx context.Context, network string, address *UDPAddr) (Conn, error) {
527+
if _, ok := d.LocalAddr.(*UDPAddr); !ok && d.LocalAddr != nil {
528+
return nil, &AddrError{Err: "mismatched local address type", Addr: d.LocalAddr.String()}
529+
}
530+
switch network {
531+
case "udp", "udp4", "udp6":
532+
default:
533+
return nil, &OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: address, Err: UnknownNetworkError(network)}
534+
}
535+
return d.dialAddr(ctx, network, address)
536+
}
537+
446538
// dialParallel races two copies of dialSerial, giving the first a
447539
// head start. It returns the first established connection and
448540
// closes the others. Otherwise it returns an error from the first

src/net/iprawsock.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,28 @@ func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} }
209209
// If the IP field of raddr is nil or an unspecified IP address, the
210210
// local system is assumed.
211211
func DialIP(network string, laddr, raddr *IPAddr) (*IPConn, error) {
212+
return DialIPContext(context.Background(), network, laddr, raddr)
213+
}
214+
215+
// DialIPContext acts like DialIP but connects using
216+
// the provided context.
217+
//
218+
// The provided Context must be non-nil.
219+
//
220+
// The network must be an IP network name; see func Dial for details.
221+
//
222+
// If laddr is nil, a local address is automatically chosen.
223+
// If the IP field of raddr is nil or an unspecified IP address, the
224+
// local system is assumed.
225+
func DialIPContext(ctx context.Context, network string, laddr, raddr *IPAddr) (*IPConn, error) {
226+
if ctx == nil {
227+
panic("nil context")
228+
}
212229
if raddr == nil {
213230
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
214231
}
215232
sd := &sysDialer{network: network, address: raddr.String()}
216-
c, err := sd.dialIP(context.Background(), laddr, raddr)
233+
c, err := sd.dialIP(ctx, laddr, raddr)
217234
if err != nil {
218235
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
219236
}

src/net/tcpsock.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,23 @@ func newTCPConn(fd *netFD) *TCPConn {
231231
// If the IP field of raddr is nil or an unspecified IP address, the
232232
// local system is assumed.
233233
func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
234+
return DialTCPContext(context.Background(), network, laddr, raddr)
235+
}
236+
237+
// DialTCPContext acts like DialTCP but connects using
238+
// the provided context.
239+
//
240+
// The provided Context must be non-nil.
241+
//
242+
// The network must be a TCP network name; see func Dial for details.
243+
//
244+
// If laddr is nil, a local address is automatically chosen.
245+
// If the IP field of raddr is nil or an unspecified IP address, the
246+
// local system is assumed.
247+
func DialTCPContext(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
248+
if ctx == nil {
249+
panic("nil context")
250+
}
234251
switch network {
235252
case "tcp", "tcp4", "tcp6":
236253
default:
@@ -240,7 +257,7 @@ func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
240257
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
241258
}
242259
sd := &sysDialer{network: network, address: raddr.String()}
243-
c, err := sd.dialTCP(context.Background(), laddr, raddr)
260+
c, err := sd.dialTCP(ctx, laddr, raddr)
244261
if err != nil {
245262
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
246263
}

src/net/udpsock.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,23 @@ func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{conn{fd}} }
284284
// If the IP field of raddr is nil or an unspecified IP address, the
285285
// local system is assumed.
286286
func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
287+
return DialUDPContext(context.Background(), network, laddr, raddr)
288+
}
289+
290+
// DialUDPContext acts like DialUDP but connects using
291+
// the provided context.
292+
//
293+
// The provided Context must be non-nil.
294+
//
295+
// The network must be a UDP network name; see func Dial for details.
296+
//
297+
// If laddr is nil, a local address is automatically chosen.
298+
// If the IP field of raddr is nil or an unspecified IP address, the
299+
// local system is assumed.
300+
func DialUDPContext(ctx context.Context, network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
301+
if ctx == nil {
302+
panic("nil context")
303+
}
287304
switch network {
288305
case "udp", "udp4", "udp6":
289306
default:
@@ -293,7 +310,7 @@ func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
293310
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
294311
}
295312
sd := &sysDialer{network: network, address: raddr.String()}
296-
c, err := sd.dialUDP(context.Background(), laddr, raddr)
313+
c, err := sd.dialUDP(ctx, laddr, raddr)
297314
if err != nil {
298315
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
299316
}

src/net/unixsock.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,13 +201,29 @@ func newUnixConn(fd *netFD) *UnixConn { return &UnixConn{conn{fd}} }
201201
// If laddr is non-nil, it is used as the local address for the
202202
// connection.
203203
func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
204+
return DialUnixContext(context.Background(), network, laddr, raddr)
205+
}
206+
207+
// DialUnixContext acts like DialUnix but connects using
208+
// the provided context.
209+
//
210+
// The provided Context must be non-nil.
211+
//
212+
// The network must be a Unix network name; see func Dial for details.
213+
//
214+
// If laddr is non-nil, it is used as the local address for the
215+
// connection.
216+
func DialUnixContext(ctx context.Context, network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
217+
if ctx == nil {
218+
panic("nil context")
219+
}
204220
switch network {
205221
case "unix", "unixgram", "unixpacket":
206222
default:
207223
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(network)}
208224
}
209225
sd := &sysDialer{network: network, address: raddr.String()}
210-
c, err := sd.dialUnix(context.Background(), laddr, raddr)
226+
c, err := sd.dialUnix(ctx, laddr, raddr)
211227
if err != nil {
212228
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
213229
}

0 commit comments

Comments
 (0)