Skip to content

Commit 5ed2f45

Browse files
committed
Refactor client handshake
- To take advantage of the Host header cleanup in the net/http Request.Write method, use a net/http Request to write the handshake to the wire. - Move code from the deprecated NewClientConn function to Dialer.Dial. This change makes it easier to add proxy support to Dialer.Dial. Add comment noting that NewClientConn is deprecated. - Update the code so that parseURL can be replaced with net/url Parse. We need to wait until we can require 1.5 before making the swap.
1 parent 4239127 commit 5ed2f45

File tree

2 files changed

+94
-93
lines changed

2 files changed

+94
-93
lines changed

client.go

Lines changed: 91 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -30,50 +30,17 @@ var ErrBadHandshake = errors.New("websocket: bad handshake")
3030
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
3131
// non-nil *http.Response so that callers can handle redirects, authentication,
3232
// etc.
33+
//
34+
// Deprecated: Use Dialer instead.
3335
func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
34-
challengeKey, err := generateChallengeKey()
35-
if err != nil {
36-
return nil, nil, err
36+
d := Dialer{
37+
ReadBufferSize: readBufSize,
38+
WriteBufferSize: writeBufSize,
39+
NetDial: func(net, addr string) (net.Conn, error) {
40+
return netConn, nil
41+
},
3742
}
38-
acceptKey := computeAcceptKey(challengeKey)
39-
40-
c = newConn(netConn, false, readBufSize, writeBufSize)
41-
p := c.writeBuf[:0]
42-
p = append(p, "GET "...)
43-
p = append(p, u.RequestURI()...)
44-
p = append(p, " HTTP/1.1\r\nHost: "...)
45-
p = append(p, u.Host...)
46-
// "Upgrade" is capitalized for servers that do not use case insensitive
47-
// comparisons on header tokens.
48-
p = append(p, "\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: "...)
49-
p = append(p, challengeKey...)
50-
p = append(p, "\r\n"...)
51-
for k, vs := range requestHeader {
52-
for _, v := range vs {
53-
p = append(p, k...)
54-
p = append(p, ": "...)
55-
p = append(p, v...)
56-
p = append(p, "\r\n"...)
57-
}
58-
}
59-
p = append(p, "\r\n"...)
60-
61-
if _, err := netConn.Write(p); err != nil {
62-
return nil, nil, err
63-
}
64-
65-
resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u})
66-
if err != nil {
67-
return nil, nil, err
68-
}
69-
if resp.StatusCode != 101 ||
70-
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
71-
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
72-
resp.Header.Get("Sec-Websocket-Accept") != acceptKey {
73-
return nil, resp, ErrBadHandshake
74-
}
75-
c.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
76-
return c, resp, nil
43+
return d.Dial(u.String(), requestHeader)
7744
}
7845

7946
// A Dialer contains options for connecting to WebSocket server.
@@ -99,17 +66,15 @@ type Dialer struct {
9966

10067
var errMalformedURL = errors.New("malformed ws or wss URL")
10168

102-
// parseURL parses the URL. The url.Parse function is not used here because
103-
// url.Parse mangles the path.
69+
// parseURL parses the URL.
70+
//
71+
// This function is a replacement for the standard library url.Parse function.
72+
// In Go 1.4 and earlier, url.Parse loses information from the path.
10473
func parseURL(s string) (*url.URL, error) {
10574
// From the RFC:
10675
//
10776
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
10877
// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
109-
//
110-
// We don't use the net/url parser here because the dialer interface does
111-
// not provide a way for applications to work around percent deocding in
112-
// the net/url parser.
11378

11479
var u url.URL
11580
switch {
@@ -131,7 +96,8 @@ func parseURL(s string) (*url.URL, error) {
13196
}
13297

13398
if strings.Contains(u.Host, "@") {
134-
// WebSocket URIs do not contain user information.
99+
// Don't bother parsing user information because user information is
100+
// not allowed in websocket URIs.
135101
return nil, errMalformedURL
136102
}
137103

@@ -166,16 +132,67 @@ var DefaultDialer = &Dialer{}
166132
// etcetera. The response body may not contain the entire response and does not
167133
// need to be closed by the application.
168134
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
135+
136+
if d == nil {
137+
d = &Dialer{}
138+
}
139+
140+
challengeKey, err := generateChallengeKey()
141+
if err != nil {
142+
return nil, nil, err
143+
}
144+
169145
u, err := parseURL(urlStr)
170146
if err != nil {
171147
return nil, nil, err
172148
}
173149

174-
hostPort, hostNoPort := hostPortNoPort(u)
150+
switch u.Scheme {
151+
case "ws":
152+
u.Scheme = "http"
153+
case "wss":
154+
u.Scheme = "https"
155+
default:
156+
return nil, nil, errMalformedURL
157+
}
175158

176-
if d == nil {
177-
d = &Dialer{}
159+
if u.User != nil {
160+
// User name and password are not allowed in websocket URIs.
161+
return nil, nil, errMalformedURL
162+
}
163+
164+
req := &http.Request{
165+
Method: "GET",
166+
URL: u,
167+
Proto: "HTTP/1.1",
168+
ProtoMajor: 1,
169+
ProtoMinor: 1,
170+
Header: make(http.Header),
171+
Host: u.Host,
172+
}
173+
174+
// Set the request headers using the capitalization for names and values in
175+
// RFC examples. Although the capitalization shouldn't matter, there are
176+
// servers that depend on it. The Header.Set method is not used because the
177+
// method canonicalizes the header names.
178+
req.Header["Upgrade"] = []string{"websocket"}
179+
req.Header["Connection"] = []string{"Upgrade"}
180+
req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
181+
req.Header["Sec-WebSocket-Version"] = []string{"13"}
182+
if len(d.Subprotocols) > 0 {
183+
req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
178184
}
185+
for k, vs := range requestHeader {
186+
if k == "Host" {
187+
if len(vs) > 0 {
188+
req.Host = vs[0]
189+
}
190+
} else {
191+
req.Header[k] = vs
192+
}
193+
}
194+
195+
hostPort, hostNoPort := hostPortNoPort(u)
179196

180197
var deadline time.Time
181198
if d.HandshakeTimeout != 0 {
@@ -203,7 +220,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
203220
return nil, nil, err
204221
}
205222

206-
if u.Scheme == "wss" {
223+
if u.Scheme == "https" {
207224
cfg := d.TLSClientConfig
208225
if cfg == nil {
209226
cfg = &tls.Config{ServerName: hostNoPort}
@@ -224,45 +241,33 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
224241
}
225242
}
226243

227-
if len(d.Subprotocols) > 0 {
228-
h := http.Header{}
229-
for k, v := range requestHeader {
230-
h[k] = v
231-
}
232-
h.Set("Sec-Websocket-Protocol", strings.Join(d.Subprotocols, ", "))
233-
requestHeader = h
234-
}
235-
236-
if len(requestHeader["Host"]) > 0 {
237-
// This can be used to supply a Host: header which is different from
238-
// the dial address.
239-
u.Host = requestHeader.Get("Host")
244+
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize)
240245

241-
// Drop "Host" header
242-
h := http.Header{}
243-
for k, v := range requestHeader {
244-
if k == "Host" {
245-
continue
246-
}
247-
h[k] = v
248-
}
249-
requestHeader = h
246+
if err := req.Write(netConn); err != nil {
247+
return nil, nil, err
250248
}
251249

252-
conn, resp, err := NewClient(netConn, u, requestHeader, d.ReadBufferSize, d.WriteBufferSize)
253-
250+
resp, err := http.ReadResponse(conn.br, req)
254251
if err != nil {
255-
if err == ErrBadHandshake {
256-
// Before closing the network connection on return from this
257-
// function, slurp up some of the response to aid application
258-
// debugging.
259-
buf := make([]byte, 1024)
260-
n, _ := io.ReadFull(resp.Body, buf)
261-
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
262-
}
263-
return nil, resp, err
252+
return nil, nil, err
253+
}
254+
if resp.StatusCode != 101 ||
255+
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
256+
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
257+
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
258+
// Before closing the network connection on return from this
259+
// function, slurp up some of the response to aid application
260+
// debugging.
261+
buf := make([]byte, 1024)
262+
n, _ := io.ReadFull(resp.Body, buf)
263+
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
264+
return nil, resp, ErrBadHandshake
265+
} else {
266+
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
264267
}
265268

269+
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
270+
266271
netConn.SetDeadline(time.Time{})
267272
netConn = nil // to avoid close in defer.
268273
return conn, resp, nil

client_server_test.go

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,8 @@ func TestRespOnBadHandshake(t *testing.T) {
289289
}
290290
}
291291

292-
// If the Host header is specified in `Dial()`, the server must receive it as
293-
// the `Host:` header.
292+
// TestHostHeader confirms that the host header provided in the call to Dial is
293+
// sent to the server.
294294
func TestHostHeader(t *testing.T) {
295295
s := newServer(t)
296296
defer s.Close()
@@ -305,16 +305,12 @@ func TestHostHeader(t *testing.T) {
305305
origHandler.ServeHTTP(w, r)
306306
})
307307

308-
ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
308+
ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
309309
if err != nil {
310310
t.Fatalf("Dial: %v", err)
311311
}
312312
defer ws.Close()
313313

314-
if resp.StatusCode != http.StatusSwitchingProtocols {
315-
t.Fatalf("resp.StatusCode = %v, want http.StatusSwitchingProtocols", resp.StatusCode)
316-
}
317-
318314
if gotHost := <-specifiedHost; gotHost != "testhost" {
319315
t.Fatalf("gotHost = %q, want \"testhost\"", gotHost)
320316
}

0 commit comments

Comments
 (0)