From e16f830927fe3b7e87868d2d817560025190e44d Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 30 Mar 2019 23:04:10 -0500 Subject: [PATCH 1/6] Improve edge case testing --- accept.go | 79 +++++++++++-------- accept_test.go | 191 +++++++++++++++++++++++++++++++++++++++++++++ datatype.go | 4 +- datatype_string.go | 4 +- dial.go | 37 +++++---- dial_test.go | 65 +++++++++++++++ header_test.go | 4 + statuscode.go | 12 +-- websocket.go | 102 +++++++++++++----------- websocket_test.go | 2 - 10 files changed, 396 insertions(+), 104 deletions(-) create mode 100644 accept_test.go create mode 100644 dial_test.go diff --git a/accept.go b/accept.go index 70ad2f06..63bccc40 100644 --- a/accept.go +++ b/accept.go @@ -4,6 +4,7 @@ import ( "crypto/sha1" "encoding/base64" "net/http" + "net/textproto" "net/url" "strings" @@ -45,56 +46,65 @@ func AcceptOrigins(origins ...string) AcceptOption { return acceptOrigins(origins) } -// Accept accepts a WebSocket handshake from a client and upgrades the -// the connection to WebSocket. -// Accept will reject the handshake if the Origin is not the same as the Host unless -// InsecureAcceptOrigin is passed. -// Accept uses w to write the handshake response so the timeouts on the http.Server apply. -func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn, error) { - var subprotocols []string - origins := []string{r.Host} - for _, opt := range opts { - switch opt := opt.(type) { - case acceptOrigins: - origins = []string(opt) - case acceptSubprotocols: - subprotocols = []string(opt) - } - } - - if !httpguts.HeaderValuesContainsToken(r.Header["Connection"], "Upgrade") { +func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { + if !headerValuesContainsToken(r.Header, "Connection", "Upgrade") { err := xerrors.Errorf("websocket: protocol violation: Connection header does not contain Upgrade: %q", r.Header.Get("Connection")) http.Error(w, err.Error(), http.StatusBadRequest) - return nil, err + return err } - if !httpguts.HeaderValuesContainsToken(r.Header["Upgrade"], "websocket") { + if !headerValuesContainsToken(r.Header, "Upgrade", "WebSocket") { err := xerrors.Errorf("websocket: protocol violation: Upgrade header does not contain websocket: %q", r.Header.Get("Upgrade")) http.Error(w, err.Error(), http.StatusBadRequest) - return nil, err + return err } if r.Method != "GET" { err := xerrors.Errorf("websocket: protocol violation: handshake request method is not GET: %q", r.Method) http.Error(w, err.Error(), http.StatusBadRequest) - return nil, err + return err } if r.Header.Get("Sec-WebSocket-Version") != "13" { err := xerrors.Errorf("websocket: unsupported protocol version: %q", r.Header.Get("Sec-WebSocket-Version")) http.Error(w, err.Error(), http.StatusBadRequest) - return nil, err + return err } if r.Header.Get("Sec-WebSocket-Key") == "" { err := xerrors.New("websocket: protocol violation: missing Sec-WebSocket-Key") http.Error(w, err.Error(), http.StatusBadRequest) + return err + } + + return nil +} + +// Accept accepts a WebSocket handshake from a client and upgrades the +// the connection to WebSocket. +// Accept will reject the handshake if the Origin is not the same as the Host unless +// InsecureAcceptOrigin is passed. +// Accept uses w to write the handshake response so the timeouts on the http.Server apply. +func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn, error) { + var subprotocols []string + origins := []string{r.Host} + for _, opt := range opts { + switch opt := opt.(type) { + case acceptOrigins: + origins = []string(opt) + case acceptSubprotocols: + subprotocols = []string(opt) + } + } + + err := verifyClientRequest(w, r) + if err != nil { return nil, err } origins = append(origins, r.Host) - err := authenticateOrigin(r, origins) + err = authenticateOrigin(r, origins) if err != nil { http.Error(w, err.Error(), http.StatusForbidden) return nil, err @@ -112,7 +122,10 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn handleKey(w, r) - selectSubprotocol(w, r, subprotocols) + subproto := selectSubprotocol(r, subprotocols) + if subproto != "" { + w.Header().Set("Sec-WebSocket-Protocol", subproto) + } w.WriteHeader(http.StatusSwitchingProtocols) @@ -134,16 +147,18 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn return c, nil } -func selectSubprotocol(w http.ResponseWriter, r *http.Request, subprotocols []string) { - clientSubprotocols := strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ",") +func headerValuesContainsToken(h http.Header, key, val string) bool { + key = textproto.CanonicalMIMEHeaderKey(key) + return httpguts.HeaderValuesContainsToken(h[key], val) +} + +func selectSubprotocol(r *http.Request, subprotocols []string) string { for _, sp := range subprotocols { - for _, cp := range clientSubprotocols { - if sp == strings.TrimSpace(cp) { - w.Header().Set("Sec-WebSocket-Protocol", sp) - return - } + if headerValuesContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) { + return sp } } + return "" } var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") diff --git a/accept_test.go b/accept_test.go new file mode 100644 index 00000000..4b5214dd --- /dev/null +++ b/accept_test.go @@ -0,0 +1,191 @@ +package websocket + +import ( + "net/http/httptest" + "strings" + "testing" +) + +func Test_verifyClientHandshake(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + method string + h map[string]string + success bool + }{ + { + name: "badConnection", + h: map[string]string{ + "Connection": "notUpgrade", + }, + }, + { + name: "badUpgrade", + h: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "notWebSocket", + }, + }, + { + name: "badMethod", + method: "POST", + h: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + }, + }, + { + name: "badWebSocketVersion", + h: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "14", + }, + }, + { + name: "badWebSocketKey", + h: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": "", + }, + }, + { + name: "success", + h: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": "meow123", + }, + success: true, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + r := httptest.NewRequest(tc.method, "/", nil) + + for k, v := range tc.h { + r.Header.Set(k, v) + } + + err := verifyClientRequest(w, r) + if (err == nil) != tc.success { + t.Fatalf("unexpected error value: %+v", err) + } + }) + } +} + +func Test_selectSubprotocol(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + clientProtocols []string + serverProtocols []string + negotiated string + }{ + { + name: "empty", + clientProtocols: nil, + serverProtocols: nil, + negotiated: "", + }, + { + name: "basic", + clientProtocols: []string{"echo", "echo2"}, + serverProtocols: []string{"echo2", "echo"}, + negotiated: "echo2", + }, + { + name: "none", + clientProtocols: []string{"echo", "echo3"}, + serverProtocols: []string{"echo2", "echo4"}, + negotiated: "", + }, + { + name: "fallback", + clientProtocols: []string{"echo", "echo3"}, + serverProtocols: []string{"echo2", "echo3"}, + negotiated: "echo3", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Sec-WebSocket-Protocol", strings.Join(tc.clientProtocols, ",")) + + negotiated := selectSubprotocol(r, tc.serverProtocols) + if tc.negotiated != negotiated { + t.Fatalf("expected %q but got %q", tc.negotiated, negotiated) + } + }) + } +} + +func Test_authenticateOrigin(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + origin string + authorizedOrigins []string + success bool + }{ + { + name: "none", + success: true, + }, + { + name: "invalid", + origin: "$#)(*)$#@*$(#@*$)#@*%)#(@*%)#(@%#@$#@$#$#@$#@}{}{}", + success: false, + }, + { + name: "unauthorized", + origin: "https://example.com", + authorizedOrigins: []string{"example1.com"}, + success: false, + }, + { + name: "authorized", + origin: "https://example.com", + authorizedOrigins: []string{"example.com"}, + success: true, + }, + { + name: "authorizedCaseInsensitive", + origin: "https://examplE.com", + authorizedOrigins: []string{"example.com"}, + success: true, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Origin", tc.origin) + + err := authenticateOrigin(r, tc.authorizedOrigins) + if (err == nil) != tc.success { + t.Fatalf("unexpected error value: %+v", err) + } + }) + } +} diff --git a/datatype.go b/datatype.go index c2473dfc..a1d8d575 100644 --- a/datatype.go +++ b/datatype.go @@ -7,6 +7,6 @@ type DataType int // DataType constants. const ( - Text DataType = DataType(opText) - Binary DataType = DataType(opBinary) + DataText DataType = DataType(opText) + DataBinary DataType = DataType(opBinary) ) diff --git a/datatype_string.go b/datatype_string.go index 215340c3..60a85c31 100644 --- a/datatype_string.go +++ b/datatype_string.go @@ -8,8 +8,8 @@ func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. var x [1]struct{} - _ = x[Text-1] - _ = x[Binary-2] + _ = x[DataText-1] + _ = x[DataBinary-2] } const _DataType_name = "TextBinary" diff --git a/dial.go b/dial.go index 29ed9b21..99e3c06e 100644 --- a/dial.go +++ b/dial.go @@ -11,7 +11,6 @@ import ( "net/url" "strings" - "golang.org/x/net/http/httpguts" "golang.org/x/xerrors" ) @@ -112,22 +111,11 @@ func Dial(ctx context.Context, u string, opts ...DialOption) (_ *Conn, _ *http.R } }() - if resp.StatusCode != http.StatusSwitchingProtocols { - return nil, resp, xerrors.Errorf("websocket: expected status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) - } - - if !httpguts.HeaderValuesContainsToken(resp.Header["Connection"], "Upgrade") { - return nil, resp, xerrors.Errorf("websocket: protocol violation: Connection header does not contain Upgrade: %q", resp.Header.Get("Connection")) - } - - if !httpguts.HeaderValuesContainsToken(resp.Header["Upgrade"], "websocket") { - return nil, resp, xerrors.Errorf("websocket: protocol violation: Upgrade header does not contain websocket: %q", resp.Header.Get("Upgrade")) - + err = verifyServerResponse(resp) + if err != nil { + return nil, resp, err } - // We do not care about Sec-WebSocket-Accept because it does not matter. - // See the secWebSocketKey global variable. - rwc, ok := resp.Body.(io.ReadWriteCloser) if !ok { return nil, resp, xerrors.Errorf("websocket: body is not a read write closer but should be: %T", rwc) @@ -144,3 +132,22 @@ func Dial(ctx context.Context, u string, opts ...DialOption) (_ *Conn, _ *http.R return c, resp, nil } + +func verifyServerResponse(resp *http.Response) error { + if resp.StatusCode != http.StatusSwitchingProtocols { + return xerrors.Errorf("websocket: expected status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) + } + + if !headerValuesContainsToken(resp.Header, "Connection", "Upgrade") { + return xerrors.Errorf("websocket: protocol violation: Connection header does not contain Upgrade: %q", resp.Header.Get("Connection")) + } + + if !headerValuesContainsToken(resp.Header, "Upgrade", "WebSocket") { + return xerrors.Errorf("websocket: protocol violation: Upgrade header does not contain websocket: %q", resp.Header.Get("Upgrade")) + } + + // We do not care about Sec-WebSocket-Accept because it does not matter. + // See the secWebSocketKey global variable. + + return nil +} diff --git a/dial_test.go b/dial_test.go new file mode 100644 index 00000000..48c1c312 --- /dev/null +++ b/dial_test.go @@ -0,0 +1,65 @@ +package websocket + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func Test_verifyServerHandshake(t *testing.T) { + testCases := []struct { + name string + response func(w http.ResponseWriter) + success bool + }{ + { + name: "badStatus", + response: func(w http.ResponseWriter) { + w.WriteHeader(http.StatusOK) + }, + success: false, + }, + { + name: "badConnection", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "???") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "badUpgrade", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "???") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "success", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: true, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + tc.response(w) + resp := w.Result() + + err := verifyServerResponse(resp) + if (err == nil) != tc.success { + t.Fatalf("unexpected error: %+v", err) + } + }) + } +} diff --git a/header_test.go b/header_test.go index 04861e81..aefd98d6 100644 --- a/header_test.go +++ b/header_test.go @@ -18,6 +18,10 @@ func randBool() bool { } func TestHeader(t *testing.T) { + +} + +func TestFuzzHeader(t *testing.T) { t.Parallel() for i := 0; i < 1000; i++ { diff --git a/statuscode.go b/statuscode.go index 40f86090..596f78bc 100644 --- a/statuscode.go +++ b/statuscode.go @@ -68,11 +68,12 @@ func parseClosePayload(p []byte) (code StatusCode, reason string, err error) { // See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number // and https://tools.ietf.org/html/rfc6455#section-7.4.1 var validReceivedCloseCodes = map[StatusCode]bool{ - StatusNormalClosure: true, - StatusGoingAway: true, - StatusProtocolError: true, - StatusUnsupportedData: true, - StatusNoStatusRcvd: false, + StatusNormalClosure: true, + StatusGoingAway: true, + StatusProtocolError: true, + StatusUnsupportedData: true, + StatusNoStatusRcvd: false, + // TODO use StatusAbnormalClosure: false, StatusInvalidFramePayloadData: true, StatusPolicyViolation: true, @@ -90,6 +91,7 @@ func validCloseCode(code StatusCode) bool { const maxControlFramePayload = 125 +// TODO make method on CloseError func closePayload(code StatusCode, reason string) ([]byte, error) { if len(reason) > maxControlFramePayload-2 { return nil, xerrors.Errorf("reason string max is %v but got %q with length %v", maxControlFramePayload-2, reason, len(reason)) diff --git a/websocket.go b/websocket.go index ab1baf64..2ed4b5dc 100644 --- a/websocket.go +++ b/websocket.go @@ -50,10 +50,10 @@ type Conn struct { } func (c *Conn) getCloseErr() error { - if c.closeErr == nil { - return xerrors.New("websocket: use of closed connection") + if c.closeErr != nil { + return c.closeErr } - return c.closeErr + return nil } func (c *Conn) close(err error) { @@ -235,11 +235,9 @@ func (c *Conn) handleControl(h header) { } c.Close(code, reason) } else { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - c.writeControl(ctx, opClose, nil) - c.close(nil) + c.writeClose(nil, CloseError{ + Code: StatusNoStatusRcvd, + }) } default: panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h)) @@ -340,40 +338,8 @@ func (c *Conn) writePong(p []byte) error { return err } -// MessageWriter returns a writer bounded by the context that will write -// a WebSocket data frame of type dataType to the connection. -// Ensure you close the MessageWriter once you have written to entire message. -// Concurrent calls to MessageWriter are ok. -func (c *Conn) MessageWriter(dataType DataType) *MessageWriter { - return &MessageWriter{ - c: c, - ctx: context.Background(), - datatype: dataType, - } -} - -// ReadMessage will wait until there is a WebSocket data frame to read from the connection. -// It returns the type of the data, a reader to read it and also an error. -// Please use SetContext on the reader to bound the read operation. -// Your application must keep reading messages for the Conn to automatically respond to ping -// and close frames. -func (c *Conn) ReadMessage(ctx context.Context) (DataType, *MessageReader, error) { - select { - case <-c.closed: - return 0, nil, xerrors.Errorf("failed to read message: %w", c.getCloseErr()) - case opcode := <-c.read: - return DataType(opcode), &MessageReader{ - ctx: context.Background(), - c: c, - }, nil - case <-ctx.Done(): - return 0, nil, xerrors.Errorf("failed to read message: %w", ctx.Err()) - } -} - // Close closes the WebSocket connection with the given status code and reason. // It will write a WebSocket close frame with a timeout of 5 seconds. -// TODO close error should become c.closeErr to indicate we closed. func (c *Conn) Close(code StatusCode, reason string) error { // This function also will not wait for a close frame from the peer like the RFC // wants because that makes no sense and I don't think anyone actually follows that. @@ -383,20 +349,33 @@ func (c *Conn) Close(code StatusCode, reason string) error { p, _ = closePayload(StatusInternalError, fmt.Sprintf("websocket: application tried to send code %v but code or reason was invalid", code)) } - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - err = c.writeControl(ctx, opClose, p) + err2 := c.writeClose(p, CloseError{ + Code: code, + Reason: reason, + }) if err != nil { return err } + return err2 +} - c.close(nil) +func (c *Conn) writeClose(p []byte, cerr CloseError) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + err := c.writeControl(ctx, opClose, p) + + c.close(cerr) if err != nil { return err } - return c.closeErr + + if cerr != c.closeErr { + return c.closeErr + } + + return nil } func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { @@ -422,6 +401,18 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error } } +// MessageWriter returns a writer bounded by the context that will write +// a WebSocket data frame of type dataType to the connection. +// Ensure you close the MessageWriter once you have written to entire message. +// Concurrent calls to MessageWriter are ok. +func (c *Conn) MessageWriter(dataType DataType) *MessageWriter { + return &MessageWriter{ + c: c, + ctx: context.Background(), + datatype: dataType, + } +} + // MessageWriter enables writing to a WebSocket connection. // Ensure you close the MessageWriter once you have written to entire message. type MessageWriter struct { @@ -496,6 +487,25 @@ func (w *MessageWriter) Close() error { } } +// ReadMessage will wait until there is a WebSocket data frame to read from the connection. +// It returns the type of the data, a reader to read it and also an error. +// Please use SetContext on the reader to bound the read operation. +// Your application must keep reading messages for the Conn to automatically respond to ping +// and close frames. +func (c *Conn) ReadMessage(ctx context.Context) (DataType, *MessageReader, error) { + select { + case <-c.closed: + return 0, nil, xerrors.Errorf("failed to read message: %w", c.getCloseErr()) + case opcode := <-c.read: + return DataType(opcode), &MessageReader{ + ctx: context.Background(), + c: c, + }, nil + case <-ctx.Done(): + return 0, nil, xerrors.Errorf("failed to read message: %w", ctx.Err()) + } +} + // MessageReader enables reading a data frame from the WebSocket connection. type MessageReader struct { n int diff --git a/websocket_test.go b/websocket_test.go index 3119d1b4..61384af7 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -378,13 +378,11 @@ func echoLoop(ctx context.Context, c *websocket.Conn, t *testing.T) { for { err := echo() if err != nil { - // t.Logf("%v: failed to echo message: %+v", time.Now(), err) return } } } -// TODO // https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py func TestAutobahnClient(t *testing.T) { t.Parallel() From 0e788123439ef0bc7a4e0736b31956df5044ad26 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 30 Mar 2019 23:04:10 -0500 Subject: [PATCH 2/6] Improve JSON API Closes #50 --- README.md | 14 +++++++++++--- datatype_string.go | 4 ++-- example_test.go | 12 ++++++++++-- json.go | 43 ++++++++++++++++++++++++++++++++----------- websocket_test.go | 12 ++++++++++-- 5 files changed, 65 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 34bd4803..d2261bbf 100644 --- a/README.md +++ b/README.md @@ -45,13 +45,17 @@ fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { } defer c.Close(websocket.StatusInternalError, "") + jc := websocket.JSONConn{ + Conn: c, + } + ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) defer cancel() v := map[string]interface{}{ "my_field": "foo", } - err = websocket.WriteJSON(ctx, c, v) + err = jc.Write(ctx, v) if err != nil { log.Printf("failed to write json: %v", err) return @@ -73,7 +77,7 @@ For a production quality example that shows off the low level API, see the [echo ```go ctx := context.Background() -ctx, cancel := context.WithTimeout(ctx, time.Second*10) +ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() c, _, err := websocket.Dial(ctx, "ws://localhost:8080", @@ -84,8 +88,12 @@ if err != nil { } defer c.Close(websocket.StatusInternalError, "") +jc := websocket.JSONConn{ + Conn: c, +} + var v interface{} -err = websocket.ReadJSON(ctx, c, v) +err = jc.Read(ctx, v) if err != nil { log.Fatalf("failed to read json: %v", err) } diff --git a/datatype_string.go b/datatype_string.go index 60a85c31..1b4aaba5 100644 --- a/datatype_string.go +++ b/datatype_string.go @@ -12,9 +12,9 @@ func _() { _ = x[DataBinary-2] } -const _DataType_name = "TextBinary" +const _DataType_name = "DataTextDataBinary" -var _DataType_index = [...]uint8{0, 4, 10} +var _DataType_index = [...]uint8{0, 8, 18} func (i DataType) String() string { i -= 1 diff --git a/example_test.go b/example_test.go index 0b15fab7..5e6d0729 100644 --- a/example_test.go +++ b/example_test.go @@ -88,13 +88,17 @@ func ExampleAccept() { } defer c.Close(websocket.StatusInternalError, "") + jc := websocket.JSONConn{ + Conn: c, + } + ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) defer cancel() v := map[string]interface{}{ "my_field": "foo", } - err = websocket.WriteJSON(ctx, c, v) + err = jc.Write(ctx, v) if err != nil { log.Printf("failed to write json: %v", err) return @@ -123,8 +127,12 @@ func ExampleDial() { } defer c.Close(websocket.StatusInternalError, "") + jc := websocket.JSONConn{ + Conn: c, + } + var v interface{} - err = websocket.ReadJSON(ctx, c, v) + err = jc.Read(ctx, v) if err != nil { log.Fatalf("failed to read json: %v", err) } diff --git a/json.go b/json.go index ca4ac924..ebe0dfdd 100644 --- a/json.go +++ b/json.go @@ -7,15 +7,28 @@ import ( "golang.org/x/xerrors" ) -// ReadJSON reads a json message from c into v. -func ReadJSON(ctx context.Context, c *Conn, v interface{}) error { - typ, r, err := c.ReadMessage(ctx) +// JSONConn wraps around a Conn with JSON helpers. +type JSONConn struct { + Conn *Conn +} + +// Read reads a json message into v. +func (jc JSONConn) Read(ctx context.Context, v interface{}) error { + err := jc.read(ctx, v) if err != nil { return xerrors.Errorf("failed to read json: %w", err) } + return nil +} + +func (jc *JSONConn) read(ctx context.Context, v interface{}) error { + typ, r, err := jc.Conn.ReadMessage(ctx) + if err != nil { + return err + } - if typ != Text { - return xerrors.Errorf("unexpected frame type for json (expected TextFrame): %v", typ) + if typ != DataText { + return xerrors.Errorf("unexpected frame type for json (expected DataText): %v", typ) } r.Limit(131072) @@ -24,25 +37,33 @@ func ReadJSON(ctx context.Context, c *Conn, v interface{}) error { d := json.NewDecoder(r) err = d.Decode(v) if err != nil { - return xerrors.Errorf("failed to read json: %w", err) + return xerrors.Errorf("failed to decode json: %w", err) } return nil } -// WriteJSON writes the json message v into c. -func WriteJSON(ctx context.Context, c *Conn, v interface{}) error { - w := c.MessageWriter(Text) +// Write writes the json message v. +func (jc JSONConn) Write(ctx context.Context, v interface{}) error { + err := jc.write(ctx, v) + if err != nil { + return xerrors.Errorf("failed to write json: %w", err) + } + return nil +} + +func (jc JSONConn) write(ctx context.Context, v interface{}) error { + w := jc.Conn.MessageWriter(DataText) w.SetContext(ctx) e := json.NewEncoder(w) err := e.Encode(v) if err != nil { - return xerrors.Errorf("failed to write json: %w", err) + return xerrors.Errorf("failed to encode json: %w", err) } err = w.Close() if err != nil { - return xerrors.Errorf("failed to write json: %w", err) + return err } return nil } diff --git a/websocket_test.go b/websocket_test.go index 61384af7..e91e5b22 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -173,10 +173,14 @@ func TestHandshake(t *testing.T) { ctx, cancel := context.WithTimeout(r.Context(), time.Second*5) defer cancel() + jc := websocket.JSONConn{ + Conn: c, + } + v := map[string]interface{}{ "anmol": "wowow", } - err = websocket.WriteJSON(ctx, c, v) + err = jc.Write(ctx, v) if err != nil { return err } @@ -191,8 +195,12 @@ func TestHandshake(t *testing.T) { } defer c.Close(websocket.StatusInternalError, "") + jc := websocket.JSONConn{ + Conn: c, + } + var v interface{} - err = websocket.ReadJSON(ctx, c, &v) + err = jc.Read(ctx, &v) if err != nil { return err } From 9213cc7bf60ddd78d0e6da16c2641eb83e0cf0f5 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 30 Mar 2019 23:04:10 -0500 Subject: [PATCH 3/6] Significantly simplify core API and the godoc --- example_test.go | 9 ++----- json.go | 7 +++-- websocket.go | 65 +++++++++++++++-------------------------------- websocket_test.go | 5 +--- 4 files changed, 26 insertions(+), 60 deletions(-) diff --git a/example_test.go b/example_test.go index 5e6d0729..b3ed2a54 100644 --- a/example_test.go +++ b/example_test.go @@ -34,14 +34,9 @@ func ExampleAccept_echo() { return err } - ctx, cancel = context.WithTimeout(ctx, time.Second*10) - defer cancel() - - r.SetContext(ctx) - r.Limit(32768) + r = io.LimitReader(r, 32768) - w := c.MessageWriter(typ) - w.SetContext(ctx) + w := c.MessageWriter(ctx, typ) _, err = io.Copy(w, r) if err != nil { return err diff --git a/json.go b/json.go index ebe0dfdd..514be050 100644 --- a/json.go +++ b/json.go @@ -3,6 +3,7 @@ package websocket import ( "context" "encoding/json" + "io" "golang.org/x/xerrors" ) @@ -31,8 +32,7 @@ func (jc *JSONConn) read(ctx context.Context, v interface{}) error { return xerrors.Errorf("unexpected frame type for json (expected DataText): %v", typ) } - r.Limit(131072) - r.SetContext(ctx) + r = io.LimitReader(r, 131072) d := json.NewDecoder(r) err = d.Decode(v) @@ -52,8 +52,7 @@ func (jc JSONConn) Write(ctx context.Context, v interface{}) error { } func (jc JSONConn) write(ctx context.Context, v interface{}) error { - w := jc.Conn.MessageWriter(DataText) - w.SetContext(ctx) + w := jc.Conn.MessageWriter(ctx, DataText) e := json.NewEncoder(w) err := e.Encode(v) diff --git a/websocket.go b/websocket.go index 2ed4b5dc..99d28559 100644 --- a/websocket.go +++ b/websocket.go @@ -403,19 +403,19 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error // MessageWriter returns a writer bounded by the context that will write // a WebSocket data frame of type dataType to the connection. -// Ensure you close the MessageWriter once you have written to entire message. -// Concurrent calls to MessageWriter are ok. -func (c *Conn) MessageWriter(dataType DataType) *MessageWriter { - return &MessageWriter{ +// Ensure you close the messageWriter once you have written to entire message. +// Concurrent calls to messageWriter are ok. +func (c *Conn) MessageWriter(ctx context.Context, dataType DataType) io.WriteCloser { + return &messageWriter{ c: c, - ctx: context.Background(), + ctx: ctx, datatype: dataType, } } -// MessageWriter enables writing to a WebSocket connection. -// Ensure you close the MessageWriter once you have written to entire message. -type MessageWriter struct { +// messageWriter enables writing to a WebSocket connection. +// Ensure you close the messageWriter once you have written to entire message. +type messageWriter struct { datatype DataType ctx context.Context c *Conn @@ -429,7 +429,7 @@ type MessageWriter struct { // The frame will automatically be fragmented as appropriate // with the buffers obtained from http.Hijacker. // Please ensure you call Close once you have written the full message. -func (w *MessageWriter) Write(p []byte) (int, error) { +func (w *messageWriter) Write(p []byte) (int, error) { if !w.acquiredLock { select { case <-w.c.closed: @@ -458,14 +458,9 @@ func (w *MessageWriter) Write(p []byte) (int, error) { } } -// SetContext bounds the writer to the context. -func (w *MessageWriter) SetContext(ctx context.Context) { - w.ctx = ctx -} - // Close flushes the frame to the connection. -// This must be called for every MessageWriter. -func (w *MessageWriter) Close() error { +// This must be called for every messageWriter. +func (w *messageWriter) Close() error { if !w.acquiredLock { select { case <-w.c.closed: @@ -492,13 +487,13 @@ func (w *MessageWriter) Close() error { // Please use SetContext on the reader to bound the read operation. // Your application must keep reading messages for the Conn to automatically respond to ping // and close frames. -func (c *Conn) ReadMessage(ctx context.Context) (DataType, *MessageReader, error) { +func (c *Conn) ReadMessage(ctx context.Context) (DataType, io.Reader, error) { select { case <-c.closed: return 0, nil, xerrors.Errorf("failed to read message: %w", c.getCloseErr()) case opcode := <-c.read: - return DataType(opcode), &MessageReader{ - ctx: context.Background(), + return DataType(opcode), &messageReader{ + ctx: ctx, c: c, }, nil case <-ctx.Done(): @@ -506,36 +501,21 @@ func (c *Conn) ReadMessage(ctx context.Context) (DataType, *MessageReader, error } } -// MessageReader enables reading a data frame from the WebSocket connection. -type MessageReader struct { - n int - limit int - c *Conn - ctx context.Context +// messageReader enables reading a data frame from the WebSocket connection. +type messageReader struct { + ctx context.Context + c *Conn } // SetContext bounds the read operation to the ctx. // By default, the context is the one passed to conn.ReadMessage. // You still almost always want a separate context for reading the message though. -func (r *MessageReader) SetContext(ctx context.Context) { +func (r *messageReader) SetContext(ctx context.Context) { r.ctx = ctx } -// Limit limits the number of bytes read by the reader. -// -// Why not use io.LimitReader? io.LimitReader returns a io.EOF -// after the limit bytes which means its not possible to tell -// whether the message has been read or a limit has been hit. -// This results in unclear error and log messages. -// This function will cause the connection to be closed if the limit is hit -// with a close reason explaining the error and also an error -// indicating the limit was hit. -func (r *MessageReader) Limit(bytes int) { - r.limit = bytes -} - // Read reads as many bytes as possible into p. -func (r *MessageReader) Read(p []byte) (n int, err error) { +func (r *messageReader) Read(p []byte) (n int, err error) { select { case <-r.c.closed: return 0, r.c.getCloseErr() @@ -546,11 +526,6 @@ func (r *MessageReader) Read(p []byte) (n int, err error) { case <-r.c.closed: return 0, r.c.getCloseErr() case n := <-r.c.readDone: - r.n += n - // TODO make this better later and inside readLoop to prevent the read from actually occuring if over limit. - if r.limit > 0 && r.n > r.limit { - return 0, xerrors.New("message too big") - } return n, nil case <-r.ctx.Done(): return 0, r.ctx.Err() diff --git a/websocket_test.go b/websocket_test.go index e91e5b22..e5850823 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -365,10 +365,7 @@ func echoLoop(ctx context.Context, c *websocket.Conn, t *testing.T) { return err } - r.SetContext(ctx) - - w := c.MessageWriter(typ) - w.SetContext(ctx) + w := c.MessageWriter(ctx, typ) _, err = io.Copy(w, r) if err != nil { From 9f4fb599abc4e52a163189024d7ff81660c04cb4 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 30 Mar 2019 23:04:10 -0500 Subject: [PATCH 4/6] Add negative payload length test --- header_test.go | 74 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 45 insertions(+), 29 deletions(-) diff --git a/header_test.go b/header_test.go index aefd98d6..f7db9f39 100644 --- a/header_test.go +++ b/header_test.go @@ -18,41 +18,57 @@ func randBool() bool { } func TestHeader(t *testing.T) { - -} - -func TestFuzzHeader(t *testing.T) { t.Parallel() + + t.Run("negative", func(t *testing.T) { + t.Parallel() - for i := 0; i < 1000; i++ { - h := header{ - fin: randBool(), - rsv1: randBool(), - rsv2: randBool(), - rsv3: randBool(), - opcode: opcode(rand.Intn(1 << 4)), - - masked: randBool(), - payloadLength: rand.Int63(), - } + b := marshalHeader(header{ + payloadLength: 1<<16 + 1, + }) - if h.masked { - rand.Read(h.maskKey[:]) - } + // Make length negative + b[2] |= 1 << 7 - b := marshalHeader(h) r := bytes.NewReader(b) - h2, err := readHeader(r) - if err != nil { - t.Logf("header: %#v", h) - t.Logf("bytes: %b", b) - t.Fatalf("failed to read header: %v", err) + _, err := readHeader(r) + if err == nil { + t.Fatalf("unexpected error value: %+v", err) } + }) + t.Run("fuzz", func(t *testing.T) { + t.Parallel() + + for i := 0; i < 1000; i++ { + h := header{ + fin: randBool(), + rsv1: randBool(), + rsv2: randBool(), + rsv3: randBool(), + opcode: opcode(rand.Intn(1 << 4)), + + masked: randBool(), + payloadLength: rand.Int63(), + } + + if h.masked { + rand.Read(h.maskKey[:]) + } + + b := marshalHeader(h) + r := bytes.NewReader(b) + h2, err := readHeader(r) + if err != nil { + t.Logf("header: %#v", h) + t.Logf("bytes: %b", b) + t.Fatalf("failed to read header: %v", err) + } - if !cmp.Equal(h, h2, cmp.AllowUnexported(header{})) { - t.Logf("header: %#v", h) - t.Logf("bytes: %b", b) - t.Fatalf("parsed and read header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(header{}))) + if !cmp.Equal(h, h2, cmp.AllowUnexported(header{})) { + t.Logf("header: %#v", h) + t.Logf("bytes: %b", b) + t.Fatalf("parsed and read header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(header{}))) + } } - } + }) } From 64e74708676047263809b2d1750204a5e715aa2e Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 30 Mar 2019 23:04:10 -0500 Subject: [PATCH 5/6] Improve API brevity --- example_test.go | 4 ++-- header_test.go | 2 +- json.go | 6 +++--- test.sh | 1 + websocket.go | 6 +++--- websocket_test.go | 6 +++--- 6 files changed, 13 insertions(+), 12 deletions(-) diff --git a/example_test.go b/example_test.go index b3ed2a54..a810c5be 100644 --- a/example_test.go +++ b/example_test.go @@ -29,14 +29,14 @@ func ExampleAccept_echo() { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() - typ, r, err := c.ReadMessage(ctx) + typ, r, err := c.Read(ctx) if err != nil { return err } r = io.LimitReader(r, 32768) - w := c.MessageWriter(ctx, typ) + w := c.Write(ctx, typ) _, err = io.Copy(w, r) if err != nil { return err diff --git a/header_test.go b/header_test.go index f7db9f39..65812997 100644 --- a/header_test.go +++ b/header_test.go @@ -19,7 +19,7 @@ func randBool() bool { func TestHeader(t *testing.T) { t.Parallel() - + t.Run("negative", func(t *testing.T) { t.Parallel() diff --git a/json.go b/json.go index 514be050..53869b59 100644 --- a/json.go +++ b/json.go @@ -10,7 +10,7 @@ import ( // JSONConn wraps around a Conn with JSON helpers. type JSONConn struct { - Conn *Conn + *Conn } // Read reads a json message into v. @@ -23,7 +23,7 @@ func (jc JSONConn) Read(ctx context.Context, v interface{}) error { } func (jc *JSONConn) read(ctx context.Context, v interface{}) error { - typ, r, err := jc.Conn.ReadMessage(ctx) + typ, r, err := jc.Conn.Read(ctx) if err != nil { return err } @@ -52,7 +52,7 @@ func (jc JSONConn) Write(ctx context.Context, v interface{}) error { } func (jc JSONConn) write(ctx context.Context, v interface{}) error { - w := jc.Conn.MessageWriter(ctx, DataText) + w := jc.Conn.Write(ctx, DataText) e := json.NewEncoder(w) err := e.Encode(v) diff --git a/test.sh b/test.sh index 736dcda7..d6e8e00a 100755 --- a/test.sh +++ b/test.sh @@ -10,6 +10,7 @@ function docker_run() { local IMAGE IMAGE="$(docker build -q "$DIR")" docker run \ + -it \ -v "${PWD}:/repo" \ -v "$(go env GOPATH):/go" \ -v "$(go env GOCACHE):/root/.cache/go-build" \ diff --git a/websocket.go b/websocket.go index 99d28559..88520858 100644 --- a/websocket.go +++ b/websocket.go @@ -401,11 +401,11 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error } } -// MessageWriter returns a writer bounded by the context that will write +// Write returns a writer bounded by the context that will write // a WebSocket data frame of type dataType to the connection. // Ensure you close the messageWriter once you have written to entire message. // Concurrent calls to messageWriter are ok. -func (c *Conn) MessageWriter(ctx context.Context, dataType DataType) io.WriteCloser { +func (c *Conn) Write(ctx context.Context, dataType DataType) io.WriteCloser { return &messageWriter{ c: c, ctx: ctx, @@ -487,7 +487,7 @@ func (w *messageWriter) Close() error { // Please use SetContext on the reader to bound the read operation. // Your application must keep reading messages for the Conn to automatically respond to ping // and close frames. -func (c *Conn) ReadMessage(ctx context.Context) (DataType, io.Reader, error) { +func (c *Conn) Read(ctx context.Context) (DataType, io.Reader, error) { select { case <-c.closed: return 0, nil, xerrors.Errorf("failed to read message: %w", c.getCloseErr()) diff --git a/websocket_test.go b/websocket_test.go index e5850823..bedb22c3 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -360,12 +360,12 @@ func echoLoop(ctx context.Context, c *websocket.Conn, t *testing.T) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() - typ, r, err := c.ReadMessage(ctx) + typ, r, err := c.Read(ctx) if err != nil { return err } - w := c.MessageWriter(ctx, typ) + w := c.Write(ctx, typ) _, err = io.Copy(w, r) if err != nil { @@ -447,7 +447,7 @@ func TestAutobahnClient(t *testing.T) { } defer c.Close(websocket.StatusInternalError, "") - _, r, err := c.ReadMessage(ctx) + _, r, err := c.Read(ctx) if err != nil { t.Fatal(err) } From 571644565ba36eb335bbbf581b0c18f66d759cf9 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 30 Mar 2019 23:04:10 -0500 Subject: [PATCH 6/6] Simplify parts of websocket.go --- websocket.go | 76 +++++++++++++++++++++++----------------------------- 1 file changed, 33 insertions(+), 43 deletions(-) diff --git a/websocket.go b/websocket.go index 88520858..09a94e78 100644 --- a/websocket.go +++ b/websocket.go @@ -49,13 +49,6 @@ type Conn struct { readDone chan int } -func (c *Conn) getCloseErr() error { - if c.closeErr != nil { - return c.closeErr - } - return nil -} - func (c *Conn) close(err error) { if err != nil { err = xerrors.Errorf("websocket: connection broken: %w", err) @@ -160,8 +153,12 @@ messageLoop: masked: c.client, } c.writeFrame(h, control.payload) - c.writeDone <- struct{}{} - continue + select { + case <-c.closed: + return + case c.writeDone <- struct{}{}: + continue + } case b, ok := <-c.writeBytes: h := header{ fin: !ok, @@ -349,14 +346,14 @@ func (c *Conn) Close(code StatusCode, reason string) error { p, _ = closePayload(StatusInternalError, fmt.Sprintf("websocket: application tried to send code %v but code or reason was invalid", code)) } - err2 := c.writeClose(p, CloseError{ + cerr := c.writeClose(p, CloseError{ Code: code, Reason: reason, }) if err != nil { return err } - return err2 + return cerr } func (c *Conn) writeClose(p []byte, cerr CloseError) error { @@ -381,19 +378,19 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error { func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { select { case <-c.closed: - return c.getCloseErr() + return c.closeErr case c.control <- control{ opcode: opcode, payload: p, }: case <-ctx.Done(): c.close(xerrors.New("force closed: close frame write timed out")) - return c.getCloseErr() + return c.closeErr } select { case <-c.closed: - return c.getCloseErr() + return c.closeErr case <-c.writeDone: return nil case <-ctx.Done(): @@ -420,9 +417,6 @@ type messageWriter struct { ctx context.Context c *Conn acquiredLock bool - sentFirst bool - - done chan struct{} } // Write writes the given bytes to the WebSocket connection. @@ -430,24 +424,18 @@ type messageWriter struct { // with the buffers obtained from http.Hijacker. // Please ensure you call Close once you have written the full message. func (w *messageWriter) Write(p []byte) (int, error) { - if !w.acquiredLock { - select { - case <-w.c.closed: - return 0, w.c.getCloseErr() - case w.c.write <- w.datatype: - w.acquiredLock = true - case <-w.ctx.Done(): - return 0, w.ctx.Err() - } + err := w.acquire() + if err != nil { + return 0, err } select { case <-w.c.closed: - return 0, w.c.getCloseErr() + return 0, w.c.closeErr case w.c.writeBytes <- p: select { case <-w.c.closed: - return 0, w.c.getCloseErr() + return 0, w.c.closeErr case <-w.c.writeDone: return len(p), nil case <-w.ctx.Done(): @@ -458,23 +446,32 @@ func (w *messageWriter) Write(p []byte) (int, error) { } } -// Close flushes the frame to the connection. -// This must be called for every messageWriter. -func (w *messageWriter) Close() error { +func (w *messageWriter) acquire() error { if !w.acquiredLock { select { case <-w.c.closed: - return w.c.getCloseErr() + return w.c.closeErr case w.c.write <- w.datatype: w.acquiredLock = true case <-w.ctx.Done(): return w.ctx.Err() } } + return nil +} + +// Close flushes the frame to the connection. +// This must be called for every messageWriter. +func (w *messageWriter) Close() error { + err := w.acquire() + if err != nil { + return err + } + close(w.c.writeBytes) select { case <-w.c.closed: - return w.c.getCloseErr() + return w.c.closeErr case <-w.ctx.Done(): return w.ctx.Err() case <-w.c.writeDone: @@ -490,7 +487,7 @@ func (w *messageWriter) Close() error { func (c *Conn) Read(ctx context.Context) (DataType, io.Reader, error) { select { case <-c.closed: - return 0, nil, xerrors.Errorf("failed to read message: %w", c.getCloseErr()) + return 0, nil, xerrors.Errorf("failed to read message: %w", c.closeErr) case opcode := <-c.read: return DataType(opcode), &messageReader{ ctx: ctx, @@ -507,24 +504,17 @@ type messageReader struct { c *Conn } -// SetContext bounds the read operation to the ctx. -// By default, the context is the one passed to conn.ReadMessage. -// You still almost always want a separate context for reading the message though. -func (r *messageReader) SetContext(ctx context.Context) { - r.ctx = ctx -} - // Read reads as many bytes as possible into p. func (r *messageReader) Read(p []byte) (n int, err error) { select { case <-r.c.closed: - return 0, r.c.getCloseErr() + return 0, r.c.closeErr case <-r.c.readDone: return 0, io.EOF case r.c.readBytes <- p: select { case <-r.c.closed: - return 0, r.c.getCloseErr() + return 0, r.c.closeErr case n := <-r.c.readDone: return n, nil case <-r.ctx.Done():