Skip to content

Commit 92f9d1e

Browse files
authored
Merge pull request #52 from nhooyr/tests
Improve API and test coverage
2 parents 9a0d241 + 5716445 commit 92f9d1e

14 files changed

+545
-235
lines changed

README.md

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,17 @@ fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
4545
}
4646
defer c.Close(websocket.StatusInternalError, "")
4747

48+
jc := websocket.JSONConn{
49+
Conn: c,
50+
}
51+
4852
ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
4953
defer cancel()
5054

5155
v := map[string]interface{}{
5256
"my_field": "foo",
5357
}
54-
err = websocket.WriteJSON(ctx, c, v)
58+
err = jc.Write(ctx, v)
5559
if err != nil {
5660
log.Printf("failed to write json: %v", err)
5761
return
@@ -73,7 +77,7 @@ For a production quality example that shows off the low level API, see the [echo
7377

7478
```go
7579
ctx := context.Background()
76-
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
80+
ctx, cancel := context.WithTimeout(ctx, time.Minute)
7781
defer cancel()
7882

7983
c, _, err := websocket.Dial(ctx, "ws://localhost:8080",
@@ -84,8 +88,12 @@ if err != nil {
8488
}
8589
defer c.Close(websocket.StatusInternalError, "")
8690

91+
jc := websocket.JSONConn{
92+
Conn: c,
93+
}
94+
8795
var v interface{}
88-
err = websocket.ReadJSON(ctx, c, v)
96+
err = jc.Read(ctx, v)
8997
if err != nil {
9098
log.Fatalf("failed to read json: %v", err)
9199
}

accept.go

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"crypto/sha1"
55
"encoding/base64"
66
"net/http"
7+
"net/textproto"
78
"net/url"
89
"strings"
910

@@ -45,56 +46,65 @@ func AcceptOrigins(origins ...string) AcceptOption {
4546
return acceptOrigins(origins)
4647
}
4748

48-
// Accept accepts a WebSocket handshake from a client and upgrades the
49-
// the connection to WebSocket.
50-
// Accept will reject the handshake if the Origin is not the same as the Host unless
51-
// InsecureAcceptOrigin is passed.
52-
// Accept uses w to write the handshake response so the timeouts on the http.Server apply.
53-
func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn, error) {
54-
var subprotocols []string
55-
origins := []string{r.Host}
56-
for _, opt := range opts {
57-
switch opt := opt.(type) {
58-
case acceptOrigins:
59-
origins = []string(opt)
60-
case acceptSubprotocols:
61-
subprotocols = []string(opt)
62-
}
63-
}
64-
65-
if !httpguts.HeaderValuesContainsToken(r.Header["Connection"], "Upgrade") {
49+
func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
50+
if !headerValuesContainsToken(r.Header, "Connection", "Upgrade") {
6651
err := xerrors.Errorf("websocket: protocol violation: Connection header does not contain Upgrade: %q", r.Header.Get("Connection"))
6752
http.Error(w, err.Error(), http.StatusBadRequest)
68-
return nil, err
53+
return err
6954
}
7055

71-
if !httpguts.HeaderValuesContainsToken(r.Header["Upgrade"], "websocket") {
56+
if !headerValuesContainsToken(r.Header, "Upgrade", "WebSocket") {
7257
err := xerrors.Errorf("websocket: protocol violation: Upgrade header does not contain websocket: %q", r.Header.Get("Upgrade"))
7358
http.Error(w, err.Error(), http.StatusBadRequest)
74-
return nil, err
59+
return err
7560
}
7661

7762
if r.Method != "GET" {
7863
err := xerrors.Errorf("websocket: protocol violation: handshake request method is not GET: %q", r.Method)
7964
http.Error(w, err.Error(), http.StatusBadRequest)
80-
return nil, err
65+
return err
8166
}
8267

8368
if r.Header.Get("Sec-WebSocket-Version") != "13" {
8469
err := xerrors.Errorf("websocket: unsupported protocol version: %q", r.Header.Get("Sec-WebSocket-Version"))
8570
http.Error(w, err.Error(), http.StatusBadRequest)
86-
return nil, err
71+
return err
8772
}
8873

8974
if r.Header.Get("Sec-WebSocket-Key") == "" {
9075
err := xerrors.New("websocket: protocol violation: missing Sec-WebSocket-Key")
9176
http.Error(w, err.Error(), http.StatusBadRequest)
77+
return err
78+
}
79+
80+
return nil
81+
}
82+
83+
// Accept accepts a WebSocket handshake from a client and upgrades the
84+
// the connection to WebSocket.
85+
// Accept will reject the handshake if the Origin is not the same as the Host unless
86+
// InsecureAcceptOrigin is passed.
87+
// Accept uses w to write the handshake response so the timeouts on the http.Server apply.
88+
func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn, error) {
89+
var subprotocols []string
90+
origins := []string{r.Host}
91+
for _, opt := range opts {
92+
switch opt := opt.(type) {
93+
case acceptOrigins:
94+
origins = []string(opt)
95+
case acceptSubprotocols:
96+
subprotocols = []string(opt)
97+
}
98+
}
99+
100+
err := verifyClientRequest(w, r)
101+
if err != nil {
92102
return nil, err
93103
}
94104

95105
origins = append(origins, r.Host)
96106

97-
err := authenticateOrigin(r, origins)
107+
err = authenticateOrigin(r, origins)
98108
if err != nil {
99109
http.Error(w, err.Error(), http.StatusForbidden)
100110
return nil, err
@@ -112,7 +122,10 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn
112122

113123
handleKey(w, r)
114124

115-
selectSubprotocol(w, r, subprotocols)
125+
subproto := selectSubprotocol(r, subprotocols)
126+
if subproto != "" {
127+
w.Header().Set("Sec-WebSocket-Protocol", subproto)
128+
}
116129

117130
w.WriteHeader(http.StatusSwitchingProtocols)
118131

@@ -134,16 +147,18 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn
134147
return c, nil
135148
}
136149

137-
func selectSubprotocol(w http.ResponseWriter, r *http.Request, subprotocols []string) {
138-
clientSubprotocols := strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ",")
150+
func headerValuesContainsToken(h http.Header, key, val string) bool {
151+
key = textproto.CanonicalMIMEHeaderKey(key)
152+
return httpguts.HeaderValuesContainsToken(h[key], val)
153+
}
154+
155+
func selectSubprotocol(r *http.Request, subprotocols []string) string {
139156
for _, sp := range subprotocols {
140-
for _, cp := range clientSubprotocols {
141-
if sp == strings.TrimSpace(cp) {
142-
w.Header().Set("Sec-WebSocket-Protocol", sp)
143-
return
144-
}
157+
if headerValuesContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) {
158+
return sp
145159
}
146160
}
161+
return ""
147162
}
148163

149164
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

accept_test.go

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
package websocket
2+
3+
import (
4+
"net/http/httptest"
5+
"strings"
6+
"testing"
7+
)
8+
9+
func Test_verifyClientHandshake(t *testing.T) {
10+
t.Parallel()
11+
12+
testCases := []struct {
13+
name string
14+
method string
15+
h map[string]string
16+
success bool
17+
}{
18+
{
19+
name: "badConnection",
20+
h: map[string]string{
21+
"Connection": "notUpgrade",
22+
},
23+
},
24+
{
25+
name: "badUpgrade",
26+
h: map[string]string{
27+
"Connection": "Upgrade",
28+
"Upgrade": "notWebSocket",
29+
},
30+
},
31+
{
32+
name: "badMethod",
33+
method: "POST",
34+
h: map[string]string{
35+
"Connection": "Upgrade",
36+
"Upgrade": "websocket",
37+
},
38+
},
39+
{
40+
name: "badWebSocketVersion",
41+
h: map[string]string{
42+
"Connection": "Upgrade",
43+
"Upgrade": "websocket",
44+
"Sec-WebSocket-Version": "14",
45+
},
46+
},
47+
{
48+
name: "badWebSocketKey",
49+
h: map[string]string{
50+
"Connection": "Upgrade",
51+
"Upgrade": "websocket",
52+
"Sec-WebSocket-Version": "13",
53+
"Sec-WebSocket-Key": "",
54+
},
55+
},
56+
{
57+
name: "success",
58+
h: map[string]string{
59+
"Connection": "Upgrade",
60+
"Upgrade": "websocket",
61+
"Sec-WebSocket-Version": "13",
62+
"Sec-WebSocket-Key": "meow123",
63+
},
64+
success: true,
65+
},
66+
}
67+
68+
for _, tc := range testCases {
69+
tc := tc
70+
t.Run(tc.name, func(t *testing.T) {
71+
t.Parallel()
72+
73+
w := httptest.NewRecorder()
74+
r := httptest.NewRequest(tc.method, "/", nil)
75+
76+
for k, v := range tc.h {
77+
r.Header.Set(k, v)
78+
}
79+
80+
err := verifyClientRequest(w, r)
81+
if (err == nil) != tc.success {
82+
t.Fatalf("unexpected error value: %+v", err)
83+
}
84+
})
85+
}
86+
}
87+
88+
func Test_selectSubprotocol(t *testing.T) {
89+
t.Parallel()
90+
91+
testCases := []struct {
92+
name string
93+
clientProtocols []string
94+
serverProtocols []string
95+
negotiated string
96+
}{
97+
{
98+
name: "empty",
99+
clientProtocols: nil,
100+
serverProtocols: nil,
101+
negotiated: "",
102+
},
103+
{
104+
name: "basic",
105+
clientProtocols: []string{"echo", "echo2"},
106+
serverProtocols: []string{"echo2", "echo"},
107+
negotiated: "echo2",
108+
},
109+
{
110+
name: "none",
111+
clientProtocols: []string{"echo", "echo3"},
112+
serverProtocols: []string{"echo2", "echo4"},
113+
negotiated: "",
114+
},
115+
{
116+
name: "fallback",
117+
clientProtocols: []string{"echo", "echo3"},
118+
serverProtocols: []string{"echo2", "echo3"},
119+
negotiated: "echo3",
120+
},
121+
}
122+
123+
for _, tc := range testCases {
124+
tc := tc
125+
t.Run(tc.name, func(t *testing.T) {
126+
t.Parallel()
127+
128+
r := httptest.NewRequest("GET", "/", nil)
129+
r.Header.Set("Sec-WebSocket-Protocol", strings.Join(tc.clientProtocols, ","))
130+
131+
negotiated := selectSubprotocol(r, tc.serverProtocols)
132+
if tc.negotiated != negotiated {
133+
t.Fatalf("expected %q but got %q", tc.negotiated, negotiated)
134+
}
135+
})
136+
}
137+
}
138+
139+
func Test_authenticateOrigin(t *testing.T) {
140+
t.Parallel()
141+
142+
testCases := []struct {
143+
name string
144+
origin string
145+
authorizedOrigins []string
146+
success bool
147+
}{
148+
{
149+
name: "none",
150+
success: true,
151+
},
152+
{
153+
name: "invalid",
154+
origin: "$#)(*)$#@*$(#@*$)#@*%)#(@*%)#(@%#@$#@$#$#@$#@}{}{}",
155+
success: false,
156+
},
157+
{
158+
name: "unauthorized",
159+
origin: "https://example.com",
160+
authorizedOrigins: []string{"example1.com"},
161+
success: false,
162+
},
163+
{
164+
name: "authorized",
165+
origin: "https://example.com",
166+
authorizedOrigins: []string{"example.com"},
167+
success: true,
168+
},
169+
{
170+
name: "authorizedCaseInsensitive",
171+
origin: "https://examplE.com",
172+
authorizedOrigins: []string{"example.com"},
173+
success: true,
174+
},
175+
}
176+
177+
for _, tc := range testCases {
178+
tc := tc
179+
t.Run(tc.name, func(t *testing.T) {
180+
t.Parallel()
181+
182+
r := httptest.NewRequest("GET", "/", nil)
183+
r.Header.Set("Origin", tc.origin)
184+
185+
err := authenticateOrigin(r, tc.authorizedOrigins)
186+
if (err == nil) != tc.success {
187+
t.Fatalf("unexpected error value: %+v", err)
188+
}
189+
})
190+
}
191+
}

datatype.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ type DataType int
77

88
// DataType constants.
99
const (
10-
Text DataType = DataType(opText)
11-
Binary DataType = DataType(opBinary)
10+
DataText DataType = DataType(opText)
11+
DataBinary DataType = DataType(opBinary)
1212
)

datatype_string.go

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)