Skip to content

Commit 5bb944e

Browse files
committed
dial: add DialContext function
In order to replace timeouts with contexts in `Connect` instance creation (go-tarantool), I need a `DialContext` function. It accepts context, and cancels, if context is canceled by user. Part of tarantool/go-tarantool#136
1 parent b452431 commit 5bb944e

File tree

3 files changed

+206
-40
lines changed

3 files changed

+206
-40
lines changed

net.go

Lines changed: 102 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package openssl
1616

1717
import (
18+
"context"
1819
"errors"
1920
"net"
2021
"time"
@@ -49,15 +50,15 @@ func NewListener(inner net.Listener, ctx *Ctx) net.Listener {
4950

5051
// Listen is a wrapper around net.Listen that wraps incoming connections with
5152
// an OpenSSL server connection using the provided context ctx.
52-
func Listen(network, laddr string, ctx *Ctx) (net.Listener, error) {
53-
if ctx == nil {
53+
func Listen(network, laddr string, sslCtx *Ctx) (net.Listener, error) {
54+
if sslCtx == nil {
5455
return nil, errors.New("no ssl context provided")
5556
}
5657
l, err := net.Listen(network, laddr)
5758
if err != nil {
5859
return nil, err
5960
}
60-
return NewListener(l, ctx), nil
61+
return NewListener(l, sslCtx), nil
6162
}
6263

6364
type DialFlags int
@@ -77,8 +78,8 @@ const (
7778
// some certs to the certificate store of the client context you're using.
7879
// This library is not nice enough to use the system certificate store by
7980
// default for you yet.
80-
func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
81-
return DialSession(network, addr, ctx, flags, nil)
81+
func Dial(network, addr string, sslCtx *Ctx, flags DialFlags) (*Conn, error) {
82+
return DialSession(network, addr, sslCtx, flags, nil)
8283
}
8384

8485
// DialTimeout acts like Dial but takes a timeout for network dial.
@@ -87,10 +88,57 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
8788
//
8889
// See func Dial for a description of the network, addr, ctx and flags
8990
// parameters.
90-
func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
91+
func DialTimeout(network, addr string, timeout time.Duration, sslCtx *Ctx,
9192
flags DialFlags) (*Conn, error) {
92-
d := net.Dialer{Timeout: timeout}
93-
return dialSession(d, network, addr, ctx, flags, nil)
93+
host, err := parseHost(addr)
94+
if err != nil {
95+
return nil, err
96+
}
97+
98+
conn, err := net.DialTimeout(network, addr, timeout)
99+
if err != nil {
100+
return nil, err
101+
}
102+
sslCtx, err = prepareCtx(sslCtx)
103+
if err != nil {
104+
conn.Close()
105+
return nil, err
106+
}
107+
client, err := createSession(conn, flags, host, sslCtx, nil)
108+
if err != nil {
109+
conn.Close()
110+
}
111+
return client, err
112+
}
113+
114+
// DialContext acts like Dial but takes a context for network dial.
115+
//
116+
// The context includes only network dial. It does not include OpenSSL calls.
117+
//
118+
// See func Dial for a description of the network, addr, ctx and flags
119+
// parameters.
120+
func DialContext(ctx context.Context, network, addr string,
121+
sslCtx *Ctx, flags DialFlags) (*Conn, error) {
122+
host, err := parseHost(addr)
123+
if err != nil {
124+
return nil, err
125+
}
126+
127+
dialer := net.Dialer{}
128+
conn, err := dialer.DialContext(ctx, network, addr)
129+
if err != nil {
130+
return nil, err
131+
}
132+
sslCtx, err = prepareCtx(sslCtx)
133+
if err != nil {
134+
conn.Close()
135+
return nil, err
136+
}
137+
client, err := createSession(conn, flags, host, sslCtx, nil)
138+
if err != nil {
139+
conn.Close()
140+
}
141+
return client, err
94142
}
95143

96144
// DialSession will connect to network/address and then wrap the corresponding
@@ -106,61 +154,78 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
106154
//
107155
// If session is not nil it will be used to resume the tls state. The session
108156
// can be retrieved from the GetSession method on the Conn.
109-
func DialSession(network, addr string, ctx *Ctx, flags DialFlags,
157+
func DialSession(network, addr string, sslCtx *Ctx, flags DialFlags,
110158
session []byte) (*Conn, error) {
111-
var d net.Dialer
112-
return dialSession(d, network, addr, ctx, flags, session)
113-
}
114-
115-
func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
116-
session []byte) (*Conn, error) {
117-
host, _, err := net.SplitHostPort(addr)
159+
host, err := parseHost(addr)
118160
if err != nil {
119161
return nil, err
120162
}
121-
if ctx == nil {
122-
var err error
123-
ctx, err = NewCtx()
124-
if err != nil {
125-
return nil, err
126-
}
127-
// TODO: use operating system default certificate chain?
128-
}
129163

130-
c, err := d.Dial(network, addr)
164+
conn, err := net.Dial(network, addr)
131165
if err != nil {
132166
return nil, err
133167
}
134-
conn, err := Client(c, ctx)
168+
sslCtx, err = prepareCtx(sslCtx)
135169
if err != nil {
136-
c.Close()
170+
conn.Close()
137171
return nil, err
138172
}
139-
if session != nil {
140-
err := conn.setSession(session)
141-
if err != nil {
142-
c.Close()
143-
return nil, err
144-
}
173+
client, err := createSession(conn, flags, host, sslCtx, session)
174+
if err != nil {
175+
conn.Close()
176+
}
177+
return client, err
178+
}
179+
180+
func prepareCtx(sslCtx *Ctx) (*Ctx, error) {
181+
if sslCtx == nil {
182+
return NewCtx()
145183
}
184+
return sslCtx, nil
185+
}
186+
187+
func parseHost(addr string) (string, error) {
188+
host, _, err := net.SplitHostPort(addr)
189+
return host, err
190+
}
191+
192+
func handshake(conn *Conn, host string, flags DialFlags) error {
193+
var err error
146194
if flags&DisableSNI == 0 {
147195
err = conn.SetTlsExtHostName(host)
148196
if err != nil {
149-
conn.Close()
150-
return nil, err
197+
return err
151198
}
152199
}
153200
err = conn.Handshake()
154201
if err != nil {
155-
conn.Close()
156-
return nil, err
202+
return err
157203
}
158204
if flags&InsecureSkipHostVerification == 0 {
159205
err = conn.VerifyHostname(host)
206+
if err != nil {
207+
return err
208+
}
209+
}
210+
return nil
211+
}
212+
213+
func createSession(c net.Conn, flags DialFlags, host string, sslCtx *Ctx,
214+
session []byte) (*Conn, error) {
215+
conn, err := Client(c, sslCtx)
216+
if err != nil {
217+
return nil, err
218+
}
219+
if session != nil {
220+
err := conn.setSession(session)
160221
if err != nil {
161222
conn.Close()
162223
return nil, err
163224
}
164225
}
226+
if err := handshake(conn, host, flags); err != nil {
227+
conn.Close()
228+
return nil, err
229+
}
165230
return conn, nil
166231
}

net_test.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
package openssl_test
2+
3+
import (
4+
"context"
5+
"crypto/rand"
6+
"io"
7+
"net"
8+
"sync"
9+
"testing"
10+
"time"
11+
12+
"github.com/tarantool/go-openssl"
13+
)
14+
15+
func sslConnect(t *testing.T, ssl_listener net.Listener) {
16+
for {
17+
var err error
18+
conn, err := ssl_listener.Accept()
19+
if err != nil {
20+
t.Errorf("failed accept: %s", err)
21+
continue
22+
}
23+
io.Copy(conn, io.LimitReader(rand.Reader, 1024))
24+
break
25+
}
26+
}
27+
28+
func TestDial(t *testing.T) {
29+
ctx := openssl.GetCtx(t)
30+
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
31+
t.Fatal(err)
32+
}
33+
ssl_listener, err := openssl.Listen("tcp", "localhost:0", ctx)
34+
if err != nil {
35+
t.Fatal(err)
36+
}
37+
38+
wg := sync.WaitGroup{}
39+
wg.Add(1)
40+
go func() {
41+
sslConnect(t, ssl_listener)
42+
wg.Done()
43+
}()
44+
45+
client, err := openssl.Dial(ssl_listener.Addr().Network(),
46+
ssl_listener.Addr().String(), ctx, openssl.InsecureSkipHostVerification)
47+
48+
wg.Wait()
49+
50+
if err != nil {
51+
t.Fatalf("unexpected err: %v", err)
52+
}
53+
n, err := io.Copy(io.Discard, io.LimitReader(client, 1024))
54+
if err != nil {
55+
t.Fatalf("unexpected err: %v", err)
56+
}
57+
if n != 1024 {
58+
if n == 0 {
59+
t.Fatal("client is closed after creation")
60+
}
61+
t.Fatalf("client lost some bytes, expected %d, got %d", 1024, n)
62+
}
63+
}
64+
65+
func TestDialTimeout(t *testing.T) {
66+
ctx := openssl.GetCtx(t)
67+
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
68+
t.Fatal(err)
69+
}
70+
ssl_listener, err := openssl.Listen("tcp", "localhost:0", ctx)
71+
if err != nil {
72+
t.Fatal(err)
73+
}
74+
75+
client, err := openssl.DialTimeout(ssl_listener.Addr().Network(),
76+
ssl_listener.Addr().String(), time.Nanosecond, ctx, 0)
77+
78+
if client != nil || err == nil {
79+
t.Fatalf("expected error")
80+
}
81+
}
82+
83+
func TestDialContext(t *testing.T) {
84+
ctx := openssl.GetCtx(t)
85+
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
86+
t.Fatal(err)
87+
}
88+
ssl_listener, err := openssl.Listen("tcp", "localhost:0", ctx)
89+
if err != nil {
90+
t.Fatal(err)
91+
}
92+
93+
cancelCtx, cancel := context.WithCancel(context.Background())
94+
cancel()
95+
client, err := openssl.DialContext(cancelCtx, ssl_listener.Addr().Network(),
96+
ssl_listener.Addr().String(), ctx, 0)
97+
98+
if client != nil || err == nil {
99+
t.Fatalf("expected error")
100+
}
101+
}

ssl_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,7 @@ func TestStdlibLotsOfConns(t *testing.T) {
738738
})
739739
}
740740

741-
func getCtx(t *testing.T) *Ctx {
741+
func GetCtx(t *testing.T) *Ctx {
742742
ctx, err := NewCtx()
743743
if err != nil {
744744
t.Fatal(err)
@@ -761,7 +761,7 @@ func getCtx(t *testing.T) *Ctx {
761761
}
762762

763763
func TestOpenSSLLotsOfConns(t *testing.T) {
764-
ctx := getCtx(t)
764+
ctx := GetCtx(t)
765765
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
766766
t.Fatal(err)
767767
}
@@ -928,7 +928,7 @@ func TestOpenSSLLotsOfConnsWithFail(t *testing.T) {
928928
t.Run(name, func(t *testing.T) {
929929
LotsOfConns(t, 1024*64, 10, 100, 0*time.Second,
930930
func(l net.Listener) net.Listener {
931-
return NewListener(l, getCtx(t))
931+
return NewListener(l, GetCtx(t))
932932
}, func(c net.Conn) (net.Conn, error) {
933933
return Client(c, getClientCtx(t))
934934
})

0 commit comments

Comments
 (0)