diff --git a/proxy/dial.go b/proxy/dial.go new file mode 100644 index 000000000..811c2e4e9 --- /dev/null +++ b/proxy/dial.go @@ -0,0 +1,54 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proxy + +import ( + "context" + "net" +) + +// A ContextDialer dials using a context. +type ContextDialer interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// Dial works like DialContext on net.Dialer but using a dialer returned by FromEnvironment. +// +// The passed ctx is only used for returning the Conn, not the lifetime of the Conn. +// +// Custom dialers (registered via RegisterDialerType) that do not implement ContextDialer +// can leak a goroutine for as long as it takes the underlying Dialer implementation to timeout. +// +// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed. +func Dial(ctx context.Context, network, address string) (net.Conn, error) { + d := FromEnvironment() + if xd, ok := d.(ContextDialer); ok { + return xd.DialContext(ctx, network, address) + } + return dialContext(ctx, d, network, address) +} + +// WARNING: this can leak a goroutine for as long as the underlying Dialer implementation takes to timeout +// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed. +func dialContext(ctx context.Context, d Dialer, network, address string) (net.Conn, error) { + var ( + conn net.Conn + done = make(chan struct{}, 1) + err error + ) + go func() { + conn, err = d.Dial(network, address) + close(done) + if conn != nil && ctx.Err() != nil { + conn.Close() + } + }() + select { + case <-ctx.Done(): + err = ctx.Err() + case <-done: + } + return conn, err +} diff --git a/proxy/dial_test.go b/proxy/dial_test.go new file mode 100644 index 000000000..3edab49d1 --- /dev/null +++ b/proxy/dial_test.go @@ -0,0 +1,131 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proxy + +import ( + "context" + "fmt" + "net" + "os" + "testing" + "time" + + "golang.org/x/net/internal/sockstest" +) + +func TestDial(t *testing.T) { + ResetProxyEnv() + t.Run("DirectWithCancel", func(t *testing.T) { + defer ResetProxyEnv() + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer l.Close() + _, port, err := net.SplitHostPort(l.Addr().String()) + if err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + c, err := Dial(ctx, l.Addr().Network(), net.JoinHostPort("", port)) + if err != nil { + t.Fatal(err) + } + c.Close() + }) + t.Run("DirectWithTimeout", func(t *testing.T) { + defer ResetProxyEnv() + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer l.Close() + _, port, err := net.SplitHostPort(l.Addr().String()) + if err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + c, err := Dial(ctx, l.Addr().Network(), net.JoinHostPort("", port)) + if err != nil { + t.Fatal(err) + } + c.Close() + }) + t.Run("DirectWithTimeoutExceeded", func(t *testing.T) { + defer ResetProxyEnv() + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer l.Close() + _, port, err := net.SplitHostPort(l.Addr().String()) + if err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) + time.Sleep(time.Millisecond) + defer cancel() + c, err := Dial(ctx, l.Addr().Network(), net.JoinHostPort("", port)) + if err == nil { + defer c.Close() + t.Fatal("failed to timeout") + } + }) + t.Run("SOCKS5", func(t *testing.T) { + defer ResetProxyEnv() + s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired) + if err != nil { + t.Fatal(err) + } + defer s.Close() + if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil { + t.Fatal(err) + } + c, err := Dial(context.Background(), s.TargetAddr().Network(), s.TargetAddr().String()) + if err != nil { + t.Fatal(err) + } + c.Close() + }) + t.Run("SOCKS5WithTimeout", func(t *testing.T) { + defer ResetProxyEnv() + s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired) + if err != nil { + t.Fatal(err) + } + defer s.Close() + if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + c, err := Dial(ctx, s.TargetAddr().Network(), s.TargetAddr().String()) + if err != nil { + t.Fatal(err) + } + c.Close() + }) + t.Run("SOCKS5WithTimeoutExceeded", func(t *testing.T) { + defer ResetProxyEnv() + s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired) + if err != nil { + t.Fatal(err) + } + defer s.Close() + if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) + time.Sleep(time.Millisecond) + defer cancel() + c, err := Dial(ctx, s.TargetAddr().Network(), s.TargetAddr().String()) + if err == nil { + defer c.Close() + t.Fatal("failed to timeout") + } + }) +} diff --git a/proxy/direct.go b/proxy/direct.go index 4c5ad88b1..26b51c346 100644 --- a/proxy/direct.go +++ b/proxy/direct.go @@ -5,6 +5,7 @@ package proxy import ( + "context" "net" ) @@ -13,6 +14,13 @@ type direct struct{} // Direct is a direct proxy: one that makes network connections directly. var Direct = direct{} +// Dial directly invokes net.Dial with the supplied parameters. func (direct) Dial(network, addr string) (net.Conn, error) { return net.Dial(network, addr) } + +// DialContext instantiates a net.Dialer and invokes its DialContext receiver with the supplied parameters. +func (direct) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, addr) +} diff --git a/proxy/per_host.go b/proxy/per_host.go index 0689bb6a7..573fe79e8 100644 --- a/proxy/per_host.go +++ b/proxy/per_host.go @@ -5,6 +5,7 @@ package proxy import ( + "context" "net" "strings" ) @@ -41,6 +42,20 @@ func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) { return p.dialerForRequest(host).Dial(network, addr) } +// DialContext connects to the address addr on the given network through either +// defaultDialer or bypass. +func (p *PerHost) DialContext(ctx context.Context, network, addr string) (c net.Conn, err error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + d := p.dialerForRequest(host) + if x, ok := d.(ContextDialer); ok { + return x.DialContext(ctx, network, addr) + } + return dialContext(ctx, d, network, addr) +} + func (p *PerHost) dialerForRequest(host string) Dialer { if ip := net.ParseIP(host); ip != nil { for _, net := range p.bypassNetworks { diff --git a/proxy/per_host_test.go b/proxy/per_host_test.go index a7d809571..0447eb427 100644 --- a/proxy/per_host_test.go +++ b/proxy/per_host_test.go @@ -5,6 +5,7 @@ package proxy import ( + "context" "errors" "net" "reflect" @@ -21,10 +22,6 @@ func (r *recordingProxy) Dial(network, addr string) (net.Conn, error) { } func TestPerHost(t *testing.T) { - var def, bypass recordingProxy - perHost := NewPerHost(&def, &bypass) - perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16") - expectedDef := []string{ "example.com:123", "1.2.3.4:123", @@ -39,17 +36,41 @@ func TestPerHost(t *testing.T) { "[1000::]:123", } - for _, addr := range expectedDef { - perHost.Dial("tcp", addr) - } - for _, addr := range expectedBypass { - perHost.Dial("tcp", addr) - } + t.Run("Dial", func(t *testing.T) { + var def, bypass recordingProxy + perHost := NewPerHost(&def, &bypass) + perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16") + for _, addr := range expectedDef { + perHost.Dial("tcp", addr) + } + for _, addr := range expectedBypass { + perHost.Dial("tcp", addr) + } - if !reflect.DeepEqual(expectedDef, def.addrs) { - t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef) - } - if !reflect.DeepEqual(expectedBypass, bypass.addrs) { - t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass) - } + if !reflect.DeepEqual(expectedDef, def.addrs) { + t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef) + } + if !reflect.DeepEqual(expectedBypass, bypass.addrs) { + t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass) + } + }) + + t.Run("DialContext", func(t *testing.T) { + var def, bypass recordingProxy + perHost := NewPerHost(&def, &bypass) + perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16") + for _, addr := range expectedDef { + perHost.DialContext(context.Background(), "tcp", addr) + } + for _, addr := range expectedBypass { + perHost.DialContext(context.Background(), "tcp", addr) + } + + if !reflect.DeepEqual(expectedDef, def.addrs) { + t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef) + } + if !reflect.DeepEqual(expectedBypass, bypass.addrs) { + t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass) + } + }) } diff --git a/proxy/proxy.go b/proxy/proxy.go index f6026b902..37d3cabd7 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -15,6 +15,7 @@ import ( ) // A Dialer is a means to establish a connection. +// Custom dialers should also implement ContextDialer. type Dialer interface { // Dial connects to the given address via the proxy. Dial(network, addr string) (c net.Conn, err error) diff --git a/proxy/socks5.go b/proxy/socks5.go index 56345ec8b..c91651f96 100644 --- a/proxy/socks5.go +++ b/proxy/socks5.go @@ -17,8 +17,14 @@ import ( func SOCKS5(network, address string, auth *Auth, forward Dialer) (Dialer, error) { d := socks.NewDialer(network, address) if forward != nil { - d.ProxyDial = func(_ context.Context, network string, address string) (net.Conn, error) { - return forward.Dial(network, address) + if f, ok := forward.(ContextDialer); ok { + d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) { + return f.DialContext(ctx, network, address) + } + } else { + d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) { + return dialContext(ctx, forward, network, address) + } } } if auth != nil {