15
15
package openssl
16
16
17
17
import (
18
+ "context"
18
19
"errors"
19
20
"net"
20
21
"time"
@@ -49,15 +50,15 @@ func NewListener(inner net.Listener, ctx *Ctx) net.Listener {
49
50
50
51
// Listen is a wrapper around net.Listen that wraps incoming connections with
51
52
// an OpenSSL server connection using the provided context ctx.
52
- func Listen (network , laddr string , ctx * Ctx ) (net.Listener , error ) {
53
- if ctx == nil {
53
+ func Listen (network , laddr string , sslCtx * Ctx ) (net.Listener , error ) {
54
+ if sslCtx == nil {
54
55
return nil , errors .New ("no ssl context provided" )
55
56
}
56
57
l , err := net .Listen (network , laddr )
57
58
if err != nil {
58
59
return nil , err
59
60
}
60
- return NewListener (l , ctx ), nil
61
+ return NewListener (l , sslCtx ), nil
61
62
}
62
63
63
64
type DialFlags int
@@ -77,8 +78,8 @@ const (
77
78
// some certs to the certificate store of the client context you're using.
78
79
// This library is not nice enough to use the system certificate store by
79
80
// default for you yet.
80
- func Dial (network , addr string , ctx * Ctx , flags DialFlags ) (* Conn , error ) {
81
- return DialSession (network , addr , ctx , flags , nil )
81
+ func Dial (network , addr string , sslCtx * Ctx , flags DialFlags ) (* Conn , error ) {
82
+ return DialSession (network , addr , sslCtx , flags , nil )
82
83
}
83
84
84
85
// DialTimeout acts like Dial but takes a timeout for network dial.
@@ -87,10 +88,57 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
87
88
//
88
89
// See func Dial for a description of the network, addr, ctx and flags
89
90
// parameters.
90
- func DialTimeout (network , addr string , timeout time.Duration , ctx * Ctx ,
91
+ func DialTimeout (network , addr string , timeout time.Duration , sslCtx * Ctx ,
91
92
flags DialFlags ) (* Conn , error ) {
92
- d := net.Dialer {Timeout : timeout }
93
- return dialSession (d , network , addr , ctx , flags , nil )
93
+ host , err := parseHost (addr )
94
+ if err != nil {
95
+ return nil , err
96
+ }
97
+
98
+ conn , err := net .DialTimeout (network , addr , timeout )
99
+ if err != nil {
100
+ return nil , err
101
+ }
102
+ sslCtx , err = prepareCtx (sslCtx )
103
+ if err != nil {
104
+ conn .Close ()
105
+ return nil , err
106
+ }
107
+ client , err := createSession (conn , flags , host , sslCtx , nil )
108
+ if err != nil {
109
+ conn .Close ()
110
+ }
111
+ return client , err
112
+ }
113
+
114
+ // DialContext acts like Dial but takes a context for network dial.
115
+ //
116
+ // The context includes only network dial. It does not include OpenSSL calls.
117
+ //
118
+ // See func Dial for a description of the network, addr, ctx and flags
119
+ // parameters.
120
+ func DialContext (ctx context.Context , network , addr string ,
121
+ sslCtx * Ctx , flags DialFlags ) (* Conn , error ) {
122
+ host , err := parseHost (addr )
123
+ if err != nil {
124
+ return nil , err
125
+ }
126
+
127
+ dialer := net.Dialer {}
128
+ conn , err := dialer .DialContext (ctx , network , addr )
129
+ if err != nil {
130
+ return nil , err
131
+ }
132
+ sslCtx , err = prepareCtx (sslCtx )
133
+ if err != nil {
134
+ conn .Close ()
135
+ return nil , err
136
+ }
137
+ client , err := createSession (conn , flags , host , sslCtx , nil )
138
+ if err != nil {
139
+ conn .Close ()
140
+ }
141
+ return client , err
94
142
}
95
143
96
144
// DialSession will connect to network/address and then wrap the corresponding
@@ -106,61 +154,78 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
106
154
//
107
155
// If session is not nil it will be used to resume the tls state. The session
108
156
// can be retrieved from the GetSession method on the Conn.
109
- func DialSession (network , addr string , ctx * Ctx , flags DialFlags ,
157
+ func DialSession (network , addr string , sslCtx * Ctx , flags DialFlags ,
110
158
session []byte ) (* Conn , error ) {
111
- var d net.Dialer
112
- return dialSession (d , network , addr , ctx , flags , session )
113
- }
114
-
115
- func dialSession (d net.Dialer , network , addr string , ctx * Ctx , flags DialFlags ,
116
- session []byte ) (* Conn , error ) {
117
- host , _ , err := net .SplitHostPort (addr )
159
+ host , err := parseHost (addr )
118
160
if err != nil {
119
161
return nil , err
120
162
}
121
- if ctx == nil {
122
- var err error
123
- ctx , err = NewCtx ()
124
- if err != nil {
125
- return nil , err
126
- }
127
- // TODO: use operating system default certificate chain?
128
- }
129
163
130
- c , err := d .Dial (network , addr )
164
+ conn , err := net .Dial (network , addr )
131
165
if err != nil {
132
166
return nil , err
133
167
}
134
- conn , err := Client ( c , ctx )
168
+ sslCtx , err = prepareCtx ( sslCtx )
135
169
if err != nil {
136
- c .Close ()
170
+ conn .Close ()
137
171
return nil , err
138
172
}
139
- if session != nil {
140
- err := conn .setSession (session )
141
- if err != nil {
142
- c .Close ()
143
- return nil , err
144
- }
173
+ client , err := createSession (conn , flags , host , sslCtx , session )
174
+ if err != nil {
175
+ conn .Close ()
176
+ }
177
+ return client , err
178
+ }
179
+
180
+ func prepareCtx (sslCtx * Ctx ) (* Ctx , error ) {
181
+ if sslCtx == nil {
182
+ return NewCtx ()
145
183
}
184
+ return sslCtx , nil
185
+ }
186
+
187
+ func parseHost (addr string ) (string , error ) {
188
+ host , _ , err := net .SplitHostPort (addr )
189
+ return host , err
190
+ }
191
+
192
+ func handshake (conn * Conn , host string , flags DialFlags ) error {
193
+ var err error
146
194
if flags & DisableSNI == 0 {
147
195
err = conn .SetTlsExtHostName (host )
148
196
if err != nil {
149
- conn .Close ()
150
- return nil , err
197
+ return err
151
198
}
152
199
}
153
200
err = conn .Handshake ()
154
201
if err != nil {
155
- conn .Close ()
156
- return nil , err
202
+ return err
157
203
}
158
204
if flags & InsecureSkipHostVerification == 0 {
159
205
err = conn .VerifyHostname (host )
206
+ if err != nil {
207
+ return err
208
+ }
209
+ }
210
+ return nil
211
+ }
212
+
213
+ func createSession (c net.Conn , flags DialFlags , host string , sslCtx * Ctx ,
214
+ session []byte ) (* Conn , error ) {
215
+ conn , err := Client (c , sslCtx )
216
+ if err != nil {
217
+ return nil , err
218
+ }
219
+ if session != nil {
220
+ err := conn .setSession (session )
160
221
if err != nil {
161
222
conn .Close ()
162
223
return nil , err
163
224
}
164
225
}
226
+ if err := handshake (conn , host , flags ); err != nil {
227
+ conn .Close ()
228
+ return nil , err
229
+ }
165
230
return conn , nil
166
231
}
0 commit comments