Skip to content

Commit f4a0879

Browse files
committed
proxy: add Dial, DialTimeout, and DialContext
The existing API does not allow client code to take advantage of Dialer implementations that implement DialTimeout and DialContext receivers. These functions provide a familiar API, see Dial and DialTimeout in the net package. Signed-off-by: Jacob Blain Christen <[email protected]>
1 parent addf6b3 commit f4a0879

File tree

2 files changed

+363
-0
lines changed

2 files changed

+363
-0
lines changed

proxy/dial.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package proxy
2+
3+
import (
4+
"context"
5+
"net"
6+
"time"
7+
)
8+
9+
// ContextDialer implements DialContext akin to net.Dialer#DialContext.
10+
type ContextDialer interface {
11+
DialContext(ctx context.Context, network, address string) (net.Conn, error)
12+
}
13+
14+
// TimeoutDialer implements DialTimeout much like net.DialTimeout.
15+
type TimeoutDialer interface {
16+
DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
17+
}
18+
19+
// Dial works like net.Dial but using a Dialer derived from the configured proxy environment.
20+
func Dial(network, address string) (net.Conn, error) {
21+
var d = FromEnvironment()
22+
return d.Dial(network, address)
23+
}
24+
25+
// DialTimeout works like net.DialTimeout but using a Dialer derived from the configured proxy environment.
26+
// Custom dialers (registered via RegisterDialerType) that do not implement TimeoutDialer (or ContextDialer)
27+
// can leak a goroutine for as long as it takes the underlying Dialer implementation to timeout.
28+
func DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
29+
var d = FromEnvironment()
30+
if td, ok := d.(TimeoutDialer); ok {
31+
return td.DialTimeout(network, address, timeout)
32+
}
33+
34+
ctx, cancel := context.WithTimeout(context.Background(), timeout)
35+
defer cancel()
36+
37+
if cd, ok := d.(ContextDialer); ok {
38+
return cd.DialContext(ctx, network, address)
39+
}
40+
return dialContext(ctx, network, address, d)
41+
}
42+
43+
// DialContext works like DialContext on net.Dialer but using a dialer derived from the configured proxy environment.
44+
// Custom dialers (registered via RegisterDialerType) that do not implement ContextDialer
45+
// can leak a goroutine for as long as it takes the underlying Dialer implementation to timeout.
46+
func DialContext(ctx context.Context, network, address string) (net.Conn, error) {
47+
var d = FromEnvironment()
48+
if td, ok := d.(ContextDialer); ok {
49+
return td.DialContext(ctx, network, address)
50+
}
51+
return dialContext(ctx, network, address, d)
52+
}
53+
54+
// WARNING: for custom dialers that only implement proxy.Dialer this will leak a goroutine
55+
// that will live as long as it takes the underlying Dialer implementation to timeout
56+
func dialContext(ctx context.Context, network, address string, d Dialer) (net.Conn, error) {
57+
var (
58+
conn net.Conn
59+
done = make(chan struct{}, 1)
60+
err error
61+
)
62+
go func() {
63+
conn, err = d.Dial(network, address)
64+
close(done)
65+
if conn != nil && ctx.Err() != nil {
66+
conn.Close()
67+
}
68+
}()
69+
select {
70+
case <-ctx.Done():
71+
err = ctx.Err()
72+
case <-done:
73+
}
74+
return conn, err
75+
}

proxy/dial_test.go

Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
package proxy
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net"
7+
"os"
8+
"testing"
9+
"time"
10+
11+
"golang.org/x/net/internal/sockstest"
12+
)
13+
14+
func TestDial(t *testing.T) {
15+
t.Run("Direct", func(t *testing.T) {
16+
defer ResetProxyEnv()
17+
l, err := net.Listen("tcp", "127.0.0.1:0")
18+
if err != nil {
19+
t.Fatal(err)
20+
}
21+
defer l.Close()
22+
_, port, err := net.SplitHostPort(l.Addr().String())
23+
if err != nil {
24+
t.Fatal(err)
25+
}
26+
if err = os.Unsetenv("ALL_PROXY"); err != nil {
27+
t.Fatal(err)
28+
}
29+
if err = os.Unsetenv("all_proxy"); err != nil {
30+
t.Fatal(err)
31+
}
32+
c, err := Dial(l.Addr().Network(), net.JoinHostPort("", port))
33+
if err != nil {
34+
t.Fatal(err)
35+
}
36+
c.Close()
37+
})
38+
t.Run("SOCKS5", func(t *testing.T) {
39+
defer ResetProxyEnv()
40+
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
41+
if err != nil {
42+
t.Fatal(err)
43+
}
44+
defer s.Close()
45+
p := fmt.Sprintf("socks5://%s", s.Addr().String())
46+
if err = os.Setenv("ALL_PROXY", p); err != nil {
47+
t.Fatal(err)
48+
}
49+
if err = os.Unsetenv("all_proxy"); err != nil {
50+
t.Fatal(err)
51+
}
52+
c, err := Dial(s.TargetAddr().Network(), s.TargetAddr().String())
53+
if err != nil {
54+
t.Fatal(err)
55+
}
56+
c.Close()
57+
})
58+
}
59+
60+
func TestDialContext(t *testing.T) {
61+
t.Run("DirectWithCancel", func(t *testing.T) {
62+
defer ResetProxyEnv()
63+
l, err := net.Listen("tcp", "127.0.0.1:0")
64+
if err != nil {
65+
t.Fatal(err)
66+
}
67+
defer l.Close()
68+
_, port, err := net.SplitHostPort(l.Addr().String())
69+
if err != nil {
70+
t.Fatal(err)
71+
}
72+
if err = os.Unsetenv("ALL_PROXY"); err != nil {
73+
t.Fatal(err)
74+
}
75+
if err = os.Unsetenv("all_proxy"); err != nil {
76+
t.Fatal(err)
77+
}
78+
ctx, cancel := context.WithCancel(context.Background())
79+
defer cancel()
80+
c, err := DialContext(ctx, l.Addr().Network(), net.JoinHostPort("", port))
81+
if err != nil {
82+
t.Fatal(err)
83+
}
84+
c.Close()
85+
})
86+
t.Run("DirectWithTimeout", func(t *testing.T) {
87+
defer ResetProxyEnv()
88+
l, err := net.Listen("tcp", "127.0.0.1:0")
89+
if err != nil {
90+
t.Fatal(err)
91+
}
92+
defer l.Close()
93+
_, port, err := net.SplitHostPort(l.Addr().String())
94+
if err != nil {
95+
t.Fatal(err)
96+
}
97+
if err = os.Unsetenv("ALL_PROXY"); err != nil {
98+
t.Fatal(err)
99+
}
100+
if err = os.Unsetenv("all_proxy"); err != nil {
101+
t.Fatal(err)
102+
}
103+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
104+
defer cancel()
105+
c, err := DialContext(ctx, l.Addr().Network(), net.JoinHostPort("", port))
106+
if err != nil {
107+
t.Fatal(err)
108+
}
109+
c.Close()
110+
})
111+
t.Run("DirectWithTimeoutExceeded", func(t *testing.T) {
112+
defer ResetProxyEnv()
113+
l, err := net.Listen("tcp", "127.0.0.1:0")
114+
if err != nil {
115+
t.Fatal(err)
116+
}
117+
defer l.Close()
118+
_, port, err := net.SplitHostPort(l.Addr().String())
119+
if err != nil {
120+
t.Fatal(err)
121+
}
122+
if err = os.Unsetenv("ALL_PROXY"); err != nil {
123+
t.Fatal(err)
124+
}
125+
if err = os.Unsetenv("all_proxy"); err != nil {
126+
t.Fatal(err)
127+
}
128+
ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
129+
time.Sleep(time.Millisecond)
130+
defer cancel()
131+
c, err := DialContext(ctx, l.Addr().Network(), net.JoinHostPort("", port))
132+
if err == nil {
133+
defer c.Close()
134+
t.Fatal("failed to timeout")
135+
}
136+
})
137+
t.Run("SOCKS5", func(t *testing.T) {
138+
defer ResetProxyEnv()
139+
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
140+
if err != nil {
141+
t.Fatal(err)
142+
}
143+
defer s.Close()
144+
if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil {
145+
t.Fatal(err)
146+
}
147+
if err = os.Unsetenv("all_proxy"); err != nil {
148+
t.Fatal(err)
149+
}
150+
c, err := DialContext(context.Background(), s.TargetAddr().Network(), s.TargetAddr().String())
151+
if err != nil {
152+
t.Fatal(err)
153+
}
154+
c.Close()
155+
})
156+
t.Run("SOCKS5WithTimeout", func(t *testing.T) {
157+
defer ResetProxyEnv()
158+
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
159+
if err != nil {
160+
t.Fatal(err)
161+
}
162+
defer s.Close()
163+
if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil {
164+
t.Fatal(err)
165+
}
166+
if err = os.Unsetenv("all_proxy"); err != nil {
167+
t.Fatal(err)
168+
}
169+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
170+
defer cancel()
171+
c, err := DialContext(ctx, s.TargetAddr().Network(), s.TargetAddr().String())
172+
if err != nil {
173+
t.Fatal(err)
174+
}
175+
c.Close()
176+
})
177+
t.Run("SOCKS5WithTimeoutExceeded", func(t *testing.T) {
178+
defer ResetProxyEnv()
179+
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
180+
if err != nil {
181+
t.Fatal(err)
182+
}
183+
defer s.Close()
184+
if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil {
185+
t.Fatal(err)
186+
}
187+
if err = os.Unsetenv("all_proxy"); err != nil {
188+
t.Fatal(err)
189+
}
190+
ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
191+
time.Sleep(time.Millisecond)
192+
defer cancel()
193+
c, err := DialContext(ctx, s.TargetAddr().Network(), s.TargetAddr().String())
194+
if err == nil {
195+
defer c.Close()
196+
t.Fatal("failed to timeout")
197+
}
198+
})
199+
}
200+
201+
func TestDialTimeout(t *testing.T) {
202+
t.Run("Direct", func(t *testing.T) {
203+
defer ResetProxyEnv()
204+
l, err := net.Listen("tcp", "127.0.0.1:0")
205+
if err != nil {
206+
t.Fatal(err)
207+
}
208+
defer l.Close()
209+
_, port, err := net.SplitHostPort(l.Addr().String())
210+
if err != nil {
211+
t.Fatal(err)
212+
}
213+
if err = os.Unsetenv("ALL_PROXY"); err != nil {
214+
t.Fatal(err)
215+
}
216+
if err = os.Unsetenv("all_proxy"); err != nil {
217+
t.Fatal(err)
218+
}
219+
c, err := DialTimeout(l.Addr().Network(), net.JoinHostPort("", port), 5*time.Second)
220+
if err != nil {
221+
t.Fatal(err)
222+
}
223+
c.Close()
224+
})
225+
t.Run("DirectTooSlow", func(t *testing.T) {
226+
defer ResetProxyEnv()
227+
l, err := net.Listen("tcp", "127.0.0.1:0")
228+
if err != nil {
229+
t.Fatal(err)
230+
}
231+
defer l.Close()
232+
_, port, err := net.SplitHostPort(l.Addr().String())
233+
if err != nil {
234+
t.Fatal(err)
235+
}
236+
if err = os.Unsetenv("ALL_PROXY"); err != nil {
237+
t.Fatal(err)
238+
}
239+
if err = os.Unsetenv("all_proxy"); err != nil {
240+
t.Fatal(err)
241+
}
242+
c, err := DialTimeout(l.Addr().Network(), net.JoinHostPort("", port), time.Nanosecond)
243+
if err == nil {
244+
defer c.Close()
245+
t.Fatal("failed to timeout")
246+
}
247+
})
248+
t.Run("SOCKS5", func(t *testing.T) {
249+
defer ResetProxyEnv()
250+
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
251+
if err != nil {
252+
t.Fatal(err)
253+
}
254+
defer s.Close()
255+
p := fmt.Sprintf("socks5://%s", s.Addr().String())
256+
if err = os.Setenv("ALL_PROXY", p); err != nil {
257+
t.Fatal(err)
258+
}
259+
if err = os.Unsetenv("all_proxy"); err != nil {
260+
t.Fatal(err)
261+
}
262+
c, err := DialTimeout(s.TargetAddr().Network(), s.TargetAddr().String(), 5*time.Second)
263+
if err != nil {
264+
t.Fatal(err)
265+
}
266+
c.Close()
267+
})
268+
t.Run("SOCKS5TooSlow", func(t *testing.T) {
269+
defer ResetProxyEnv()
270+
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
271+
if err != nil {
272+
t.Fatal(err)
273+
}
274+
defer s.Close()
275+
p := fmt.Sprintf("socks5://%s", s.Addr().String())
276+
if err = os.Setenv("ALL_PROXY", p); err != nil {
277+
t.Fatal(err)
278+
}
279+
if err = os.Unsetenv("all_proxy"); err != nil {
280+
t.Fatal(err)
281+
}
282+
c, err := DialTimeout(s.TargetAddr().Network(), s.TargetAddr().String(), time.Nanosecond)
283+
if err == nil {
284+
defer c.Close()
285+
t.Fatal("failed to timeout")
286+
}
287+
})
288+
}

0 commit comments

Comments
 (0)