Skip to content

Commit 14fa2a6

Browse files
committed
net: add context-aware DialTCPContext, DialUDPContext, etc
Fixes #49097
1 parent 6df0957 commit 14fa2a6

File tree

6 files changed

+307
-1
lines changed

6 files changed

+307
-1
lines changed

src/net/dial.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package net
77
import (
88
"context"
99
"internal/nettrace"
10+
"net/netip"
1011
"syscall"
1112
"time"
1213
)
@@ -443,6 +444,114 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
443444
return c, nil
444445
}
445446

447+
func (d *Dialer) dialAddr(ctx context.Context, network string, address Addr) (Conn, error) {
448+
if ctx == nil {
449+
panic("nil context")
450+
}
451+
if address == nil {
452+
return nil, &OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: nil, Err: errMissingAddress}
453+
}
454+
deadline := d.deadline(ctx, time.Now())
455+
if !deadline.IsZero() {
456+
if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
457+
subCtx, cancel := context.WithDeadline(ctx, deadline)
458+
defer cancel()
459+
ctx = subCtx
460+
}
461+
}
462+
if oldCancel := d.Cancel; oldCancel != nil {
463+
subCtx, cancel := context.WithCancel(ctx)
464+
defer cancel()
465+
go func() {
466+
select {
467+
case <-oldCancel:
468+
cancel()
469+
case <-subCtx.Done():
470+
}
471+
}()
472+
ctx = subCtx
473+
}
474+
sd := &sysDialer{
475+
Dialer: *d,
476+
network: network,
477+
address: address.String(),
478+
}
479+
return sd.dialSingle(ctx, address)
480+
}
481+
482+
func (d *Dialer) DialTCP(ctx context.Context, network string, address netip.AddrPort) (Conn, error) {
483+
if _, ok := d.LocalAddr.(*TCPAddr); !ok && d.LocalAddr != nil {
484+
return nil, &AddrError{Err: "mismatched local address type", Addr: d.LocalAddr.String()}
485+
}
486+
var addr *TCPAddr
487+
if address.IsValid() {
488+
addr = &TCPAddr{
489+
IP: address.Addr().AsSlice(),
490+
Port: int(address.Port()),
491+
Zone: address.Addr().Zone(),
492+
}
493+
}
494+
switch network {
495+
case "tcp", "tcp4", "tcp6":
496+
default:
497+
return nil, &OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: addr, Err: UnknownNetworkError(network)}
498+
}
499+
c, err := d.dialAddr(ctx, network, addr)
500+
if err != nil {
501+
return nil, err
502+
}
503+
if tc := c.(*TCPConn); d.KeepAlive >= 0 {
504+
setKeepAlive(tc.fd, true)
505+
ka := d.KeepAlive
506+
if d.KeepAlive == 0 {
507+
ka = defaultTCPKeepAlive
508+
}
509+
setKeepAlivePeriod(tc.fd, ka)
510+
testHookSetKeepAlive(ka)
511+
}
512+
return c, nil
513+
}
514+
515+
func (d *Dialer) DialIP(ctx context.Context, network string, address netip.Addr) (Conn, error) {
516+
if _, ok := d.LocalAddr.(*IPAddr); !ok && d.LocalAddr != nil {
517+
return nil, &AddrError{Err: "mismatched local address type", Addr: d.LocalAddr.String()}
518+
}
519+
addr := &IPAddr{
520+
IP: address.AsSlice(),
521+
Zone: address.Zone(),
522+
}
523+
return d.dialAddr(ctx, network, addr)
524+
}
525+
526+
func (d *Dialer) DialUnix(ctx context.Context, network string, address *UnixAddr) (Conn, error) {
527+
if _, ok := d.LocalAddr.(*UnixAddr); !ok && d.LocalAddr != nil {
528+
return nil, &AddrError{Err: "mismatched local address type", Addr: d.LocalAddr.String()}
529+
}
530+
switch network {
531+
case "unix", "unixgram", "unixpacket":
532+
default:
533+
return nil, &OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: address, Err: UnknownNetworkError(network)}
534+
}
535+
return d.dialAddr(ctx, network, address)
536+
}
537+
538+
func (d *Dialer) DialUDP(ctx context.Context, network string, address netip.AddrPort) (Conn, error) {
539+
if _, ok := d.LocalAddr.(*UDPAddr); !ok && d.LocalAddr != nil {
540+
return nil, &AddrError{Err: "mismatched local address type", Addr: d.LocalAddr.String()}
541+
}
542+
addr := &UDPAddr{
543+
IP: address.Addr().AsSlice(),
544+
Zone: address.Addr().Zone(),
545+
Port: int(address.Port()),
546+
}
547+
switch network {
548+
case "udp", "udp4", "udp6":
549+
default:
550+
return nil, &OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: addr, Err: UnknownNetworkError(network)}
551+
}
552+
return d.dialAddr(ctx, network, addr)
553+
}
554+
446555
// dialParallel races two copies of dialSerial, giving the first a
447556
// head start. It returns the first established connection and
448557
// closes the others. Otherwise it returns an error from the first

src/net/iprawsock.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package net
66

77
import (
88
"context"
9+
"net/netip"
910
"syscall"
1011
)
1112

@@ -220,6 +221,42 @@ func DialIP(network string, laddr, raddr *IPAddr) (*IPConn, error) {
220221
return c, nil
221222
}
222223

224+
// DialIPContext acts like DialIP but connects using
225+
// the provided context.
226+
//
227+
// The provided Context must be non-nil.
228+
//
229+
// The network must be an IP network name; see func Dial for details.
230+
//
231+
// If laddr is nil, a local address is automatically chosen.
232+
// If the IP field of raddr is nil or an unspecified IP address, the
233+
// local system is assumed.
234+
func DialIPContext(ctx context.Context, network string, laddr, raddr netip.Addr) (*IPConn, error) {
235+
if ctx == nil {
236+
panic("nil context")
237+
}
238+
if !raddr.IsValid() {
239+
return nil, &OpErrorNetIP{Op: "dial", Net: network, Source: netip.AddrPortFrom(laddr, 0), Err: errMissingAddress}
240+
}
241+
sd := &sysDialer{network: network, address: raddr.String()}
242+
var src *IPAddr
243+
if laddr.IsValid() {
244+
src = &IPAddr{
245+
IP: laddr.AsSlice(),
246+
Zone: laddr.Zone(),
247+
}
248+
}
249+
dst := &IPAddr{
250+
IP: raddr.AsSlice(),
251+
Zone: raddr.Zone(),
252+
}
253+
c, err := sd.dialIP(ctx, src, dst)
254+
if err != nil {
255+
return nil, &OpErrorNetIP{Op: "dial", Net: network, Source: netip.AddrPortFrom(laddr, 0), Addr: netip.AddrPortFrom(raddr, 0), Err: err}
256+
}
257+
return c, nil
258+
}
259+
223260
// ListenIP acts like ListenPacket for IP networks.
224261
//
225262
// The network must be an IP network name; see func Dial for details.

src/net/net.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ import (
8383
"errors"
8484
"internal/poll"
8585
"io"
86+
"net/netip"
8687
"os"
8788
"sync"
8889
"syscall"
@@ -488,6 +489,61 @@ func (e *OpError) Error() string {
488489
return s
489490
}
490491

492+
// OpErrorNetIP is the error type usually returned by functions in the net
493+
// package. It describes the operation, network type, and address of
494+
// an error.
495+
type OpErrorNetIP struct {
496+
// Op is the operation which caused the error, such as
497+
// "read" or "write".
498+
Op string
499+
500+
// Net is the network type on which this error occurred,
501+
// such as "tcp" or "udp6".
502+
Net string
503+
504+
// For operations involving a remote network connection, like
505+
// Dial, Read, or Write, Source is the corresponding local
506+
// network address.
507+
Source netip.AddrPort
508+
509+
// Addr is the network address for which this error occurred.
510+
// For local operations, like Listen or SetDeadline, Addr is
511+
// the address of the local endpoint being manipulated.
512+
// For operations involving a remote network connection, like
513+
// Dial, Read, or Write, Addr is the remote address of that
514+
// connection.
515+
Addr netip.AddrPort
516+
517+
// Err is the error that occurred during the operation.
518+
// The Error method panics if the error is nil.
519+
Err error
520+
}
521+
522+
func (e *OpErrorNetIP) Unwrap() error { return e.Err }
523+
524+
func (e *OpErrorNetIP) Error() string {
525+
if e == nil {
526+
return "<nil>"
527+
}
528+
s := e.Op
529+
if e.Net != "" {
530+
s += " " + e.Net
531+
}
532+
if e.Source.IsValid() {
533+
s += " " + e.Source.String()
534+
}
535+
if e.Addr.IsValid() {
536+
if e.Source.IsValid() {
537+
s += "->"
538+
} else {
539+
s += " "
540+
}
541+
s += e.Addr.String()
542+
}
543+
s += ": " + e.Err.Error()
544+
return s
545+
}
546+
491547
var (
492548
// aLongTimeAgo is a non-zero time, far in the past, used for
493549
// immediate cancellation of dials.

src/net/tcpsock.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,50 @@ func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
247247
return c, nil
248248
}
249249

250+
// DialTCPContext acts like DialTCP but connects using
251+
// the provided context.
252+
//
253+
// The provided Context must be non-nil.
254+
//
255+
// The network must be a TCP network name; see func Dial for details.
256+
//
257+
// If laddr is nil, a local address is automatically chosen.
258+
// If the IP field of raddr is nil or an unspecified IP address, the
259+
// local system is assumed.
260+
func DialTCPContext(ctx context.Context, network string, laddr, raddr netip.AddrPort) (*TCPConn, error) {
261+
if ctx == nil {
262+
panic("nil context")
263+
}
264+
switch network {
265+
case "tcp", "tcp4", "tcp6":
266+
default:
267+
return nil, &OpErrorNetIP{Op: "dial", Net: network, Source: laddr, Addr: raddr, Err: UnknownNetworkError(network)}
268+
}
269+
if !raddr.IsValid() {
270+
return nil, &OpErrorNetIP{Op: "dial", Net: network, Source: laddr, Err: errMissingAddress}
271+
}
272+
273+
var src *TCPAddr
274+
if laddr.IsValid() {
275+
src = &TCPAddr{
276+
IP: laddr.Addr().AsSlice(),
277+
Zone: laddr.Addr().Zone(),
278+
Port: int(laddr.Port()),
279+
}
280+
}
281+
dst := &TCPAddr{
282+
IP: raddr.Addr().AsSlice(),
283+
Zone: raddr.Addr().Zone(),
284+
Port: int(raddr.Port()),
285+
}
286+
sd := &sysDialer{network: network, address: raddr.String()}
287+
c, err := sd.dialTCP(ctx, src, dst)
288+
if err != nil {
289+
return nil, &OpErrorNetIP{Op: "dial", Net: network, Source: laddr, Addr: raddr, Err: err}
290+
}
291+
return c, nil
292+
}
293+
250294
// TCPListener is a TCP network listener. Clients should typically
251295
// use variables of type Listener instead of assuming TCP.
252296
type TCPListener struct {

src/net/udpsock.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,50 @@ func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
300300
return c, nil
301301
}
302302

303+
// DialUDPContext acts like DialUDP but connects using
304+
// the provided context.
305+
//
306+
// The provided Context must be non-nil.
307+
//
308+
// The network must be a UDP network name; see func Dial for details.
309+
//
310+
// If laddr is nil, a local address is automatically chosen.
311+
// If the IP field of raddr is nil or an unspecified IP address, the
312+
// local system is assumed.
313+
func DialUDPContext(ctx context.Context, network string, laddr, raddr netip.AddrPort) (*UDPConn, error) {
314+
if ctx == nil {
315+
panic("nil context")
316+
}
317+
switch network {
318+
case "udp", "udp4", "udp6":
319+
default:
320+
return nil, &OpErrorNetIP{Op: "dial", Net: network, Source: laddr, Addr: raddr, Err: UnknownNetworkError(network)}
321+
}
322+
if !raddr.IsValid() {
323+
return nil, &OpErrorNetIP{Op: "dial", Net: network, Source: laddr, Err: errMissingAddress}
324+
}
325+
sd := &sysDialer{network: network, address: raddr.String()}
326+
327+
var src *UDPAddr
328+
if laddr.IsValid() {
329+
src = &UDPAddr{
330+
IP: laddr.Addr().AsSlice(),
331+
Zone: laddr.Addr().Zone(),
332+
Port: int(laddr.Port()),
333+
}
334+
}
335+
dst := &UDPAddr{
336+
IP: raddr.Addr().AsSlice(),
337+
Zone: raddr.Addr().Zone(),
338+
Port: int(raddr.Port()),
339+
}
340+
c, err := sd.dialUDP(ctx, src, dst)
341+
if err != nil {
342+
return nil, &OpErrorNetIP{Op: "dial", Net: network, Source: laddr, Addr: raddr, Err: err}
343+
}
344+
return c, nil
345+
}
346+
303347
// ListenUDP acts like ListenPacket for UDP networks.
304348
//
305349
// The network must be a UDP network name; see func Dial for details.

src/net/unixsock.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,13 +201,29 @@ func newUnixConn(fd *netFD) *UnixConn { return &UnixConn{conn{fd}} }
201201
// If laddr is non-nil, it is used as the local address for the
202202
// connection.
203203
func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
204+
return DialUnixContext(context.Background(), network, laddr, raddr)
205+
}
206+
207+
// DialUnixContext acts like DialUnix but connects using
208+
// the provided context.
209+
//
210+
// The provided Context must be non-nil.
211+
//
212+
// The network must be a Unix network name; see func Dial for details.
213+
//
214+
// If laddr is non-nil, it is used as the local address for the
215+
// connection.
216+
func DialUnixContext(ctx context.Context, network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
217+
if ctx == nil {
218+
panic("nil context")
219+
}
204220
switch network {
205221
case "unix", "unixgram", "unixpacket":
206222
default:
207223
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(network)}
208224
}
209225
sd := &sysDialer{network: network, address: raddr.String()}
210-
c, err := sd.dialUnix(context.Background(), laddr, raddr)
226+
c, err := sd.dialUnix(ctx, laddr, raddr)
211227
if err != nil {
212228
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
213229
}

0 commit comments

Comments
 (0)