From a3a891bf62a60365ea2084d19c029fdf8802a55e Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Wed, 28 Aug 2019 13:07:01 -0400 Subject: [PATCH 1/5] Improve coverage in dial.go and header.go --- ci/test.sh | 2 +- dial_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ go.mod | 4 +--- go.sum | 4 ++-- header_test.go | 15 +++++++++++++++ websocket_test.go | 2 ++ 6 files changed, 63 insertions(+), 6 deletions(-) diff --git a/ci/test.sh b/ci/test.sh index ab101e91..875216f1 100755 --- a/ci/test.sh +++ b/ci/test.sh @@ -8,7 +8,7 @@ mkdir -p ci/out/websocket testFlags=( -race "-vet=off" - "-bench=." + # "-bench=." "-coverprofile=ci/out/coverage.prof" "-coverpkg=./..." ) diff --git a/dial_test.go b/dial_test.go index 6f0deef9..6400c223 100644 --- a/dial_test.go +++ b/dial_test.go @@ -1,11 +1,53 @@ package websocket import ( + "context" "net/http" "net/http/httptest" "testing" + "time" ) +func TestBadDials(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + url string + opts DialOptions + }{ + { + name: "badURL", + url: "://noscheme", + }, + { + name: "badURLScheme", + url: "ftp://nhooyr.io", + }, + { + name: "badHTTPClient", + url: "ws://nhooyr.io", + opts: DialOptions{ + HTTPClient: &http.Client{ + Timeout: time.Minute, + }, + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + _, _, err := Dial(context.Background(), tc.url, tc.opts) + if err == nil { + t.Fatalf("expected non nil error: %+v", err) + } + }) + } +} + func Test_verifyServerHandshake(t *testing.T) { t.Parallel() diff --git a/go.mod b/go.mod index 58d79bf1..35d500dd 100644 --- a/go.mod +++ b/go.mod @@ -13,8 +13,6 @@ require ( golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 golang.org/x/tools v0.0.0-20190429184909-35c670923e21 golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 - gotest.tools/gotestsum v0.3.5 + gotest.tools/gotestsum v0.3.6-0.20190825182939-fc6cb5870c52 mvdan.cc/sh v2.6.4+incompatible ) - -replace gotest.tools/gotestsum => github.com/nhooyr/gotestsum v0.3.6-0.20190821172136-aaabbb33254b diff --git a/go.sum b/go.sum index 98b766bf..b9e3737c 100644 --- a/go.sum +++ b/go.sum @@ -22,8 +22,6 @@ github.com/mattn/go-colorable v0.0.9 h1:UVL0vNpWh04HeJXV0KLcaT7r06gOH2l4OW6ddYRU github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-isatty v0.0.3 h1:ns/ykhmWi7G9O+8a448SecJU3nSMBXJfqQkl0upE1jI= github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= -github.com/nhooyr/gotestsum v0.3.6-0.20190821172136-aaabbb33254b h1:t6DbmxEtGMM72Uhs638nBOyK7tjsrDwoMfYO1EfQdFE= -github.com/nhooyr/gotestsum v0.3.6-0.20190821172136-aaabbb33254b/go.mod h1:Mnf3e5FUzXbkCfynWBGOwLssY7gTQgCHObK9tMpAriY= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.8.0 h1:VkHVNpR4iVnU8XQR6DBm8BqYjN7CRzw+xKUbVVbbW9w= github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= @@ -85,5 +83,7 @@ gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gotest.tools v2.1.0+incompatible h1:5USw7CrJBYKqjg9R7QlA6jzqZKEAtvW82aNmsxxGPxw= gotest.tools v2.1.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= +gotest.tools/gotestsum v0.3.6-0.20190825182939-fc6cb5870c52 h1:Qr31uPFyjpOhAgRfKV4ATUnknnLT2X7HFjqwkstdbbE= +gotest.tools/gotestsum v0.3.6-0.20190825182939-fc6cb5870c52/go.mod h1:Mnf3e5FUzXbkCfynWBGOwLssY7gTQgCHObK9tMpAriY= mvdan.cc/sh v2.6.4+incompatible h1:eD6tDeh0pw+/TOTI1BBEryZ02rD2nMcFsgcvde7jffM= mvdan.cc/sh v2.6.4+incompatible/go.mod h1:IeeQbZq+x2SUGBensq/jge5lLQbS3XT2ktyp3wrt4x8= diff --git a/header_test.go b/header_test.go index b45854ea..4457c356 100644 --- a/header_test.go +++ b/header_test.go @@ -21,6 +21,21 @@ func randBool() bool { func TestHeader(t *testing.T) { t.Parallel() + t.Run("writeNegativeLength", func(t *testing.T) { + t.Parallel() + + defer func() { + r := recover() + if r == nil { + t.Fatal("failed to induce panic in writeHeader with negative payload length") + } + }() + + writeHeader(nil, header{ + payloadLength: -1, + }) + }) + t.Run("readNegativeLength", func(t *testing.T) { t.Parallel() diff --git a/websocket_test.go b/websocket_test.go index cd6bdaf5..2ef25cdd 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -654,6 +654,7 @@ func testServer(tb testing.TB, fn http.HandlerFunc) (s *httptest.Server, closeFn // https://github.com/crossbario/autobahn-python/tree/master/wstest func TestAutobahnServer(t *testing.T) { t.Parallel() + t.Skip() s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, websocket.AcceptOptions{ @@ -794,6 +795,7 @@ func unusedListenAddr() (string, error) { // https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py func TestAutobahnClient(t *testing.T) { t.Parallel() + t.Skip() serverAddr, err := unusedListenAddr() if err != nil { From 679ddb825d5cd5ce4cc7136734fff5effe3a2910 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 29 Aug 2019 15:37:26 -0500 Subject: [PATCH 2/5] Drastically improve non autobahn test coverage Also simplified and refactored the Conn tests. More changes soon. --- accept_test.go | 33 ++ ci/test.sh | 31 +- dial_test.go | 9 +- export_test.go | 12 +- header_test.go | 31 ++ netconn.go | 4 +- statuscode.go | 2 +- statuscode_test.go | 108 ++++++- websocket.go | 56 ++-- websocket_test.go | 781 ++++++++++++++++++++++++++++++--------------- 10 files changed, 761 insertions(+), 306 deletions(-) diff --git a/accept_test.go b/accept_test.go index 6f5c3fb9..8634066b 100644 --- a/accept_test.go +++ b/accept_test.go @@ -6,6 +6,39 @@ import ( "testing" ) +func TestAccept(t *testing.T) { + t.Parallel() + + t.Run("badClientHandshake", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + + _, err := Accept(w, r, AcceptOptions{}) + if err == nil { + t.Fatalf("unexpected error value: %v", err) + } + + }) + + t.Run("requireHttpHijacker", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + + _, err := Accept(w, r, AcceptOptions{}) + if err == nil || !strings.Contains(err.Error(), "http.Hijacker") { + t.Fatalf("unexpected error value: %v", err) + } + }) +} + func Test_verifyClientHandshake(t *testing.T) { t.Parallel() diff --git a/ci/test.sh b/ci/test.sh index 875216f1..1d4a8b07 100755 --- a/ci/test.sh +++ b/ci/test.sh @@ -4,19 +4,34 @@ set -euo pipefail cd "$(dirname "${0}")" cd "$(git rev-parse --show-toplevel)" -mkdir -p ci/out/websocket -testFlags=( +argv=( + go run gotest.tools/gotestsum + # https://circleci.com/docs/2.0/collect-test-data/ + "--junitfile=ci/out/websocket/testReport.xml" + "--format=short-verbose" + -- -race "-vet=off" - # "-bench=." + "-bench=." +) +# Interactive usage probably does not want to enable benchmarks, race detection +# turn off vet or use gotestsum by default. +if [[ $# -gt 0 ]]; then + argv=(go test "$@") +fi + +# We always want coverage. +argv+=( "-coverprofile=ci/out/coverage.prof" "-coverpkg=./..." ) -# https://circleci.com/docs/2.0/collect-test-data/ -go run gotest.tools/gotestsum \ - --junitfile ci/out/websocket/testReport.xml \ - --format=short-verbose \ - -- "${testFlags[@]}" + +mkdir -p ci/out/websocket +"${argv[@]}" + +# Removes coverage of generated files. +grep -v _string.go < ci/out/coverage.prof > ci/out/coverage2.prof +mv ci/out/coverage2.prof ci/out/coverage.prof go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html if [[ ${CI:-} ]]; then diff --git a/dial_test.go b/dial_test.go index 6400c223..4607493b 100644 --- a/dial_test.go +++ b/dial_test.go @@ -33,6 +33,10 @@ func TestBadDials(t *testing.T) { }, }, }, + { + name: "badTLS", + url: "wss://totallyfake.nhooyr.io", + }, } for _, tc := range testCases { @@ -40,7 +44,10 @@ func TestBadDials(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - _, _, err := Dial(context.Background(), tc.url, tc.opts) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + _, _, err := Dial(ctx, tc.url, tc.opts) if err == nil { t.Fatalf("expected non nil error: %+v", err) } diff --git a/export_test.go b/export_test.go index 22ad76fc..ab766f14 100644 --- a/export_test.go +++ b/export_test.go @@ -1,3 +1,13 @@ package websocket -var Compute = handleSecWebSocketKey +import ( + "context" +) + +type Addr = websocketAddr + +type Header = header + +func (c *Conn) WriteFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { + return c.writeFrame(ctx, fin, opcode, p) +} diff --git a/header_test.go b/header_test.go index 4457c356..45d0535a 100644 --- a/header_test.go +++ b/header_test.go @@ -2,6 +2,7 @@ package websocket import ( "bytes" + "io" "math/rand" "strconv" "testing" @@ -21,6 +22,36 @@ func randBool() bool { func TestHeader(t *testing.T) { t.Parallel() + t.Run("eof", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + bytes []byte + }{ + { + "start", + []byte{0xff}, + }, + { + "middle", + []byte{0xff, 0xff, 0xff}, + }, + } + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + b := bytes.NewBuffer(tc.bytes) + _, err := readHeader(nil, b) + if io.ErrUnexpectedEOF != err { + t.Fatalf("expected %v but got: %v", io.ErrUnexpectedEOF, err) + } + }) + } + }) + t.Run("writeNegativeLength", func(t *testing.T) { t.Parallel() diff --git a/netconn.go b/netconn.go index d28eeb84..a6f902da 100644 --- a/netconn.go +++ b/netconn.go @@ -101,8 +101,8 @@ func (c *netConn) Read(p []byte) (int, error) { return 0, err } if typ != c.msgType { - c.c.Close(StatusUnsupportedData, fmt.Sprintf("can only accept %v messages", c.msgType)) - return 0, xerrors.Errorf("unexpected frame type read for net conn adapter (expected %v): %v", c.msgType, typ) + c.c.Close(StatusUnsupportedData, fmt.Sprintf("unexpected frame type read (expected %v): %v", c.msgType, typ)) + return 0, c.c.closeErr } c.reader = r } diff --git a/statuscode.go b/statuscode.go index 42ae40c0..498437d0 100644 --- a/statuscode.go +++ b/statuscode.go @@ -35,7 +35,7 @@ const ( StatusTryAgainLater StatusBadGateway // statusTLSHandshake is unexported because we just return - // handshake error in dial. We do not return a conn + // the handshake error in dial. We do not return a conn // so there is nothing to use this on. At least until WASM. statusTLSHandshake ) diff --git a/statuscode_test.go b/statuscode_test.go index 38ee4c3f..b9637868 100644 --- a/statuscode_test.go +++ b/statuscode_test.go @@ -4,14 +4,13 @@ import ( "math" "strings" "testing" + + "github.com/google/go-cmp/cmp" ) func TestCloseError(t *testing.T) { t.Parallel() - // Other parts of close error are tested by websocket_test.go right now - // with the autobahn tests. - testCases := []struct { name string ce CloseError @@ -50,7 +49,108 @@ func TestCloseError(t *testing.T) { _, err := tc.ce.bytes() if (err == nil) != tc.success { - t.Fatalf("unexpected error value: %v", err) + t.Fatalf("unexpected error value: %+v", err) + } + }) + } +} + +func Test_parseClosePayload(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + p []byte + success bool + ce CloseError + }{ + { + name: "normal", + p: append([]byte{0x3, 0xE8}, []byte("hello")...), + success: true, + ce: CloseError{ + Code: StatusNormalClosure, + Reason: "hello", + }, + }, + { + name: "nothing", + success: true, + ce: CloseError{ + Code: StatusNoStatusRcvd, + }, + }, + { + name: "oneByte", + p: []byte{0}, + success: false, + }, + { + name: "badStatusCode", + p: []byte{0x17, 0x70}, + success: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ce, err := parseClosePayload(tc.p) + if (err == nil) != tc.success { + t.Fatalf("unexpected expected error value: %+v", err) + } + + if tc.success && tc.ce != ce { + t.Fatalf("unexpected close error: %v", cmp.Diff(tc.ce, ce)) + } + }) + } +} + +func Test_validWireCloseCode(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + code StatusCode + valid bool + }{ + { + name: "normal", + code: StatusNormalClosure, + valid: true, + }, + { + name: "noStatus", + code: StatusNoStatusRcvd, + valid: false, + }, + { + name: "3000", + code: 3000, + valid: true, + }, + { + name: "4999", + code: 4999, + valid: true, + }, + { + name: "unknown", + code: 5000, + valid: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if valid := validWireCloseCode(tc.code); tc.valid != valid { + t.Fatalf("expected %v for %v but got %v", tc.valid, tc.code, valid) } }) } diff --git a/websocket.go b/websocket.go index 393ea547..833c1209 100644 --- a/websocket.go +++ b/websocket.go @@ -7,8 +7,8 @@ import ( "fmt" "io" "io/ioutil" + "log" "math/rand" - "os" "runtime" "strconv" "sync" @@ -210,9 +210,8 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) { } if h.rsv1 || h.rsv2 || h.rsv3 { - err := xerrors.Errorf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) - c.Close(StatusProtocolError, err.Error()) - return header{}, err + c.Close(StatusProtocolError, fmt.Sprintf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)) + return header{}, c.closeErr } if h.opcode.controlOp() { @@ -227,9 +226,8 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) { case opBinary, opText, opContinuation: return h, nil default: - err := xerrors.Errorf("received unknown opcode %v", h.opcode) - c.Close(StatusProtocolError, err.Error()) - return header{}, err + c.Close(StatusProtocolError, fmt.Sprintf("received unknown opcode %v", h.opcode)) + return header{}, c.closeErr } } } @@ -273,15 +271,13 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { func (c *Conn) handleControl(ctx context.Context, h header) error { if h.payloadLength > maxControlFramePayload { - err := xerrors.Errorf("control frame too large at %v bytes", h.payloadLength) - c.Close(StatusProtocolError, err.Error()) - return err + c.Close(StatusProtocolError, fmt.Sprintf("control frame too large at %v bytes", h.payloadLength)) + return c.closeErr } if !h.fin { - err := xerrors.Errorf("received fragmented control frame") - c.Close(StatusProtocolError, err.Error()) - return err + c.Close(StatusProtocolError, "received fragmented control frame") + return c.closeErr } ctx, cancel := context.WithTimeout(ctx, time.Second*5) @@ -311,8 +307,9 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { case opClose: ce, err := parseClosePayload(b) if err != nil { - c.Close(StatusProtocolError, "received invalid close payload") - return xerrors.Errorf("received invalid close payload: %w", err) + err = xerrors.Errorf("received invalid close payload: %w", err) + c.Close(StatusProtocolError, err.Error()) + return c.closeErr } // This ensures the closeErr of the Conn is always the received CloseError // in case the echo close frame write fails. @@ -376,9 +373,8 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { if c.activeReader != nil && !c.activeReader.eof() { if h.opcode != opContinuation { - err := xerrors.Errorf("received new data message without finishing the previous message") - c.Close(StatusProtocolError, err.Error()) - return 0, nil, err + c.Close(StatusProtocolError, "received new data message without finishing the previous message") + return 0, nil, c.closeErr } if !h.fin || h.payloadLength > 0 { @@ -392,9 +388,8 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { return 0, nil, err } } else if h.opcode == opContinuation { - err := xerrors.Errorf("received continuation frame not after data or text frame") - c.Close(StatusProtocolError, err.Error()) - return 0, nil, err + c.Close(StatusProtocolError, "received continuation frame not after data or text frame") + return 0, nil, c.closeErr } c.readerMsgCtx = ctx @@ -460,9 +455,8 @@ func (r *messageReader) read(p []byte) (int, error) { } if r.c.readMsgLeft <= 0 { - err := xerrors.Errorf("read limited at %v bytes", r.c.msgReadLimit) - r.c.Close(StatusMessageTooBig, err.Error()) - return 0, err + r.c.Close(StatusMessageTooBig, fmt.Sprintf("read limited at %v bytes", r.c.msgReadLimit)) + return 0, r.c.closeErr } if int64(len(p)) > r.c.readMsgLeft { @@ -476,9 +470,8 @@ func (r *messageReader) read(p []byte) (int, error) { } if h.opcode != opContinuation { - err := xerrors.Errorf("received new data message without finishing the previous message") - r.c.Close(StatusProtocolError, err.Error()) - return 0, err + r.c.Close(StatusProtocolError, "received new data message without finishing the previous message") + return 0, r.c.closeErr } r.c.readerMsgHeader = h @@ -828,7 +821,7 @@ func (c *Conn) writePong(p []byte) error { func (c *Conn) Close(code StatusCode, reason string) error { err := c.exportedClose(code, reason) if err != nil { - return xerrors.Errorf("failed to close connection: %w", err) + return xerrors.Errorf("failed to close websocket connection: %w", err) } return nil } @@ -844,7 +837,7 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error { // Definitely worth seeing what popular browsers do later. p, err := ce.bytes() if err != nil { - fmt.Fprintf(os.Stderr, "websocket: failed to marshal close frame: %v\n", err) + log.Printf("websocket: failed to marshal close frame: %+v", err) ce = CloseError{ Code: StatusInternalError, } @@ -853,12 +846,13 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error { // CloseErrors sent are made opaque to prevent applications from thinking // they received a given status. - err = c.writeClose(p, xerrors.Errorf("sent close frame: %v", ce)) + sentErr := xerrors.Errorf("sent close frame: %v", ce) + err = c.writeClose(p, sentErr) if err != nil { return err } - if !xerrors.Is(c.closeErr, ce) { + if !xerrors.Is(c.closeErr, sentErr) { return c.closeErr } diff --git a/websocket_test.go b/websocket_test.go index 2ef25cdd..b45f024f 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -4,8 +4,11 @@ import ( "context" "encoding/json" "fmt" + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes/timestamp" "io" "io/ioutil" + "math/rand" "net" "net/http" "net/http/cookiejar" @@ -75,127 +78,6 @@ func TestHandshake(t *testing.T) { return nil }, }, - { - name: "closeError", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - err = wsjson.Write(r.Context(), c, "hello") - if err != nil { - return err - } - - return nil - }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ - Subprotocols: []string{"meow"}, - }) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - var m string - err = wsjson.Read(ctx, c, &m) - if err != nil { - return err - } - - if m != "hello" { - return xerrors.Errorf("recieved unexpected msg but expected hello: %+v", m) - } - - _, _, err = c.Reader(ctx) - var cerr websocket.CloseError - if !xerrors.As(err, &cerr) || cerr.Code != websocket.StatusInternalError { - return xerrors.Errorf("unexpected error: %+v", err) - } - - return nil - }, - }, - { - name: "netConn", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - nc := websocket.NetConn(c, websocket.MessageBinary) - defer nc.Close() - - nc.SetWriteDeadline(time.Time{}) - time.Sleep(1) - nc.SetWriteDeadline(time.Now().Add(time.Second * 15)) - - for i := 0; i < 3; i++ { - _, err = nc.Write([]byte("hello")) - if err != nil { - return err - } - } - - return nil - }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ - Subprotocols: []string{"meow"}, - }) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - nc := websocket.NetConn(c, websocket.MessageBinary) - defer nc.Close() - - nc.SetReadDeadline(time.Time{}) - time.Sleep(1) - nc.SetReadDeadline(time.Now().Add(time.Second * 15)) - - read := func() error { - p := make([]byte, len("hello")) - // We do not use io.ReadFull here as it masks EOFs. - // See https://github.com/nhooyr/websocket/issues/100#issuecomment-508148024 - _, err = nc.Read(p) - if err != nil { - return err - } - - if string(p) != "hello" { - return xerrors.Errorf("unexpected payload %q received", string(p)) - } - return nil - } - - for i := 0; i < 3; i++ { - err = read() - if err != nil { - return err - } - } - - // Ensure the close frame is converted to an EOF and multiple read's after all return EOF. - err = read() - if err != io.EOF { - return err - } - - err = read() - if err != io.EOF { - return err - } - - return nil - }, - }, { name: "defaultSubprotocol", server: func(w http.ResponseWriter, r *http.Request) error { @@ -323,22 +205,240 @@ func TestHandshake(t *testing.T) { if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") + defer c.Close(websocket.StatusInternalError, "") + return nil + }, + }, + { + name: "cookies", + server: func(w http.ResponseWriter, r *http.Request) error { + cookie, err := r.Cookie("mycookie") + if err != nil { + return xerrors.Errorf("request is missing mycookie: %w", err) + } + if cookie.Value != "myvalue" { + return xerrors.Errorf("expected %q but got %q", "myvalue", cookie.Value) + } + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + if err != nil { + return err + } + c.Close(websocket.StatusInternalError, "") + return nil + }, + client: func(ctx context.Context, u string) error { + jar, err := cookiejar.New(nil) + if err != nil { + return xerrors.Errorf("failed to create cookie jar: %w", err) + } + parsedURL, err := url.Parse(u) + if err != nil { + return xerrors.Errorf("failed to parse url: %w", err) + } + parsedURL.Scheme = "http" + jar.SetCookies(parsedURL, []*http.Cookie{ + { + Name: "mycookie", + Value: "myvalue", + }, + }) + hc := &http.Client{ + Jar: jar, + } + c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + HTTPClient: hc, + }) + if err != nil { + return err + } + c.Close(websocket.StatusInternalError, "") + return nil + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + s, closeFn := testServer(t, tc.server, false) + defer closeFn() + + wsURL := strings.Replace(s.URL, "http", "ws", 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + err := tc.client(ctx, wsURL) + if err != nil { + t.Fatalf("client failed: %+v", err) + } + }) + } +} + +func TestConn(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + client func(ctx context.Context, c *websocket.Conn) error + server func(ctx context.Context, c *websocket.Conn) error + }{ + { + name: "closeError", + server: func(ctx context.Context, c *websocket.Conn) error { + return wsjson.Write(ctx, c, "hello") + }, + client: func(ctx context.Context, c *websocket.Conn) error { + var m string + err := wsjson.Read(ctx, c, &m) + if err != nil { + return err + } + + if m != "hello" { + return xerrors.Errorf("recieved unexpected msg but expected hello: %+v", m) + } + + _, _, err = c.Reader(ctx) + var cerr websocket.CloseError + if !xerrors.As(err, &cerr) || cerr.Code != websocket.StatusInternalError { + return xerrors.Errorf("unexpected error: %+v", err) + } + + return nil + }, + }, + { + name: "netConn", + server: func(ctx context.Context, c *websocket.Conn) error { + nc := websocket.NetConn(c, websocket.MessageBinary) + defer nc.Close() + + nc.SetWriteDeadline(time.Time{}) + time.Sleep(1) + nc.SetWriteDeadline(time.Now().Add(time.Second * 15)) + + if nc.LocalAddr() != (websocket.Addr{}) { + return xerrors.Errorf("net conn local address is not equal to websocket.Addr") + } + if nc.RemoteAddr() != (websocket.Addr{}) { + return xerrors.Errorf("net conn remote address is not equal to websocket.Addr") + } + + for i := 0; i < 3; i++ { + _, err := nc.Write([]byte("hello")) + if err != nil { + return err + } + } + + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + nc := websocket.NetConn(c, websocket.MessageBinary) + defer nc.Close() + + nc.SetReadDeadline(time.Time{}) + time.Sleep(1) + nc.SetReadDeadline(time.Now().Add(time.Second * 15)) + + read := func() error { + p := make([]byte, len("hello")) + // We do not use io.ReadFull here as it masks EOFs. + // See https://github.com/nhooyr/websocket/issues/100#issuecomment-508148024 + _, err := nc.Read(p) + if err != nil { + return err + } + + if string(p) != "hello" { + return xerrors.Errorf("unexpected payload %q received", string(p)) + } + return nil + } + + for i := 0; i < 3; i++ { + err := read() + if err != nil { + return err + } + } + + // Ensure the close frame is converted to an EOF and multiple read's after all return EOF. + err := read() + if err != io.EOF { + return err + } + + err = read() + if err != io.EOF { + return err + } + + return nil + }, + }, + { + name: "netConn/badReadMsgType", + server: func(ctx context.Context, c *websocket.Conn) error { + nc := websocket.NetConn(c, websocket.MessageBinary) + defer nc.Close() + + nc.SetDeadline(time.Now().Add(time.Second * 15)) + + _, err := nc.Read(make([]byte, 1)) + if err == nil { + return xerrors.Errorf("expected error") + } + + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + err := wsjson.Write(ctx, c, "meow") + if err != nil { + return err + } + + _, _, err = c.Read(ctx) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusUnsupportedData { + return xerrors.Errorf("expected close error with code StatusUnsupportedData: %+v", err) + } + + return nil + }, + }, + { + name: "netConn/badRead", + server: func(ctx context.Context, c *websocket.Conn) error { + nc := websocket.NetConn(c, websocket.MessageBinary) + defer nc.Close() + + nc.SetDeadline(time.Now().Add(time.Second * 15)) + + _, err := nc.Read(make([]byte, 1)) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusBadGateway { + return xerrors.Errorf("expected close error with code StatusBadGateway: %+v", err) + } + + _, err = nc.Write([]byte{0xff}) + if err == nil { + return xerrors.Errorf("expected writes to fail after reading a close frame: %v", err) + } + return nil }, + client: func(ctx context.Context, c *websocket.Conn) error { + return c.Close(websocket.StatusBadGateway, "") + }, }, { name: "jsonEcho", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(r.Context(), time.Second*5) - defer cancel() - + server: func(ctx context.Context, c *websocket.Conn) error { write := func() error { v := map[string]interface{}{ "anmol": "wowow", @@ -346,7 +446,7 @@ func TestHandshake(t *testing.T) { err := wsjson.Write(ctx, c, v) return err } - err = write() + err := write() if err != nil { return err } @@ -358,13 +458,7 @@ func TestHandshake(t *testing.T) { c.Close(websocket.StatusNormalClosure, "") return nil }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - + client: func(ctx context.Context, c *websocket.Conn) error { read := func() error { var v interface{} err := wsjson.Read(ctx, c, &v) @@ -380,7 +474,7 @@ func TestHandshake(t *testing.T) { } return nil } - err = read() + err := read() if err != nil { return err } @@ -395,21 +489,12 @@ func TestHandshake(t *testing.T) { }, { name: "protobufEcho", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(r.Context(), time.Second*5) - defer cancel() - + server: func(ctx context.Context, c *websocket.Conn) error { write := func() error { err := wspb.Write(ctx, c, ptypes.DurationProto(100)) return err } - err = write() + err := write() if err != nil { return err } @@ -417,13 +502,7 @@ func TestHandshake(t *testing.T) { c.Close(websocket.StatusNormalClosure, "") return nil }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - + client: func(ctx context.Context, c *websocket.Conn) error { read := func() error { var v duration.Duration err := wspb.Read(ctx, c, &v) @@ -441,7 +520,7 @@ func TestHandshake(t *testing.T) { } return nil } - err = read() + err := read() if err != nil { return err } @@ -450,73 +529,21 @@ func TestHandshake(t *testing.T) { return nil }, }, - { - name: "cookies", - server: func(w http.ResponseWriter, r *http.Request) error { - cookie, err := r.Cookie("mycookie") - if err != nil { - return xerrors.Errorf("request is missing mycookie: %w", err) - } - if cookie.Value != "myvalue" { - return xerrors.Errorf("expected %q but got %q", "myvalue", cookie.Value) - } - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) - if err != nil { - return err - } - c.Close(websocket.StatusInternalError, "") - return nil - }, - client: func(ctx context.Context, u string) error { - jar, err := cookiejar.New(nil) - if err != nil { - return xerrors.Errorf("failed to create cookie jar: %w", err) - } - parsedURL, err := url.Parse(u) - if err != nil { - return xerrors.Errorf("failed to parse url: %w", err) - } - parsedURL.Scheme = "http" - jar.SetCookies(parsedURL, []*http.Cookie{ - { - Name: "mycookie", - Value: "myvalue", - }, - }) - hc := &http.Client{ - Jar: jar, - } - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ - HTTPClient: hc, - }) - if err != nil { - return err - } - c.Close(websocket.StatusInternalError, "") - return nil - }, - }, { name: "ping", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - + server: func(ctx context.Context, c *websocket.Conn) error { errc := make(chan error, 1) go func() { - _, _, err2 := c.Read(r.Context()) + _, _, err2 := c.Read(ctx) errc <- err2 }() - err = c.Ping(r.Context()) + err := c.Ping(ctx) if err != nil { return err } - err = c.Write(r.Context(), websocket.MessageText, []byte("hi")) + err = c.Write(ctx, websocket.MessageText, []byte("hi")) if err != nil { return err } @@ -528,13 +555,7 @@ func TestHandshake(t *testing.T) { } return xerrors.Errorf("unexpected error: %w", err) }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - + client: func(ctx context.Context, c *websocket.Conn) error { // We read a message from the connection and then keep reading until // the Ping completes. done := make(chan struct{}) @@ -550,7 +571,7 @@ func TestHandshake(t *testing.T) { c.Read(ctx) }() - err = c.Ping(ctx) + err := c.Ping(ctx) if err != nil { return err } @@ -563,29 +584,17 @@ func TestHandshake(t *testing.T) { }, { name: "readLimit", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - _, _, err = c.Read(r.Context()) + server: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) if err == nil { return xerrors.Errorf("expected error but got nil") } return nil }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") + client: func(ctx context.Context, c *websocket.Conn) error { + go c.CloseRead(ctx) - go c.Reader(ctx) - - err = c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769))) + err := c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769))) if err != nil { return err } @@ -600,20 +609,244 @@ func TestHandshake(t *testing.T) { return nil }, }, - } + { + name: "wsjson/binary", + server: func(ctx context.Context, c *websocket.Conn) error { + var v interface{} + err := wsjson.Read(ctx, c, &v) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + return wspb.Write(ctx, c, ptypes.DurationProto(100)) + }, + }, + { + name: "wsjson/badRead", + server: func(ctx context.Context, c *websocket.Conn) error { + var v interface{} + err := wsjson.Read(ctx, c, &v) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + return c.Write(ctx, websocket.MessageText, []byte("notjson")) + }, + }, + { + name: "wsjson/badWrite", + server: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + err := wsjson.Write(ctx, c, fmt.Println) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + }, + { + name: "wspb/text", + server: func(ctx context.Context, c *websocket.Conn) error { + var v proto.Message + err := wspb.Read(ctx, c, v) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + return wsjson.Write(ctx, c, "hi") + }, + }, + { + name: "wspb/badRead", + server: func(ctx context.Context, c *websocket.Conn) error { + var v timestamp.Timestamp + err := wspb.Read(ctx, c, &v) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + return c.Write(ctx, websocket.MessageBinary, []byte("notpb")) + }, + }, + { + name: "wspb/badWrite", + server: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + err := wspb.Write(ctx, c, nil) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + }, + { + name: "wspb/badWrite", + server: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + err := wspb.Write(ctx, c, nil) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + }, + { + name: "badClose", + server: func(ctx context.Context, c *websocket.Conn) error { + return c.Close(9999, "") + }, + client: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusInternalError { + return xerrors.Errorf("expected close error with StatusInternalError: %+v", err) + } + return nil + }, + }, + { + name: "pingTimeout", + server: func(ctx context.Context, c *websocket.Conn) error { + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + err := c.Ping(ctx) + if err == nil { + return xerrors.Errorf("expected nil error: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + time.Sleep(time.Second) + return nil + }, + }, + { + name: "writeTimeout", + server: func(ctx context.Context, c *websocket.Conn) error { + c.Writer(ctx, websocket.MessageBinary) + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + err := c.Write(ctx, websocket.MessageBinary, []byte("meow")) + if !xerrors.Is(err, context.DeadlineExceeded) { + return xerrors.Errorf("expected deadline exceeded error: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + time.Sleep(time.Second) + return nil + }, + }, + { + name: "readTimeout", + server: func(ctx context.Context, c *websocket.Conn) error { + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + _, r, err := c.Reader(ctx) + if err != nil { + return err + } + <-ctx.Done() + _, err = r.Read(make([]byte, 1)) + if !xerrors.Is(err, context.DeadlineExceeded){ + return xerrors.Errorf("expected deadline exceeded error: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + time.Sleep(time.Second) + return nil + }, + }, + { + name: "badOpCode", + server: func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, true, 13, []byte("meow")) + if err != nil { + return err + } + _, _, err = c.Read(ctx) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError { + return xerrors.Errorf("expected close error with StatusProtocolError: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil || strings.Contains(err.Error(), "opcode") { + return xerrors.Errorf("expected error that contains opcode: %+v", err) + } + return nil + }, + }, + { + name: "noRsv", + server: func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, true, 99, []byte("meow")) + if err != nil { + return err + } + _, _, err = c.Read(ctx) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError { + return xerrors.Errorf("expected close error with StatusProtocolError: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil || !strings.Contains(err.Error(), "rsv") { + return xerrors.Errorf("expected error that contains rsv: %+v", err) + } + return nil + }, + }, + } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { - err := tc.server(w, r) + // Run random tests over TLS. + tls := rand.Intn(2) == 1 + + s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) if err != nil { - t.Errorf("server failed: %+v", err) - return + return err } - }) + defer c.Close(websocket.StatusInternalError, "") + tc.server(r.Context(), c) + return nil + }, tls) defer closeFn() wsURL := strings.Replace(s.URL, "http", "ws", 1) @@ -621,7 +854,18 @@ func TestHandshake(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - err := tc.client(ctx, wsURL) + opts := websocket.DialOptions{} + if tls { + opts.HTTPClient = s.Client() + } + + c, _, err := websocket.Dial(ctx, wsURL, opts) + if err != nil { + t.Fatal(err) + } + defer c.Close(websocket.StatusInternalError, "") + + err = tc.client(ctx, c) if err != nil { t.Fatalf("client failed: %+v", err) } @@ -629,14 +873,31 @@ func TestHandshake(t *testing.T) { } } -func testServer(tb testing.TB, fn http.HandlerFunc) (s *httptest.Server, closeFn func()) { +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) error, tls bool) (s *httptest.Server, closeFn func()) { var conns int64 - s = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { atomic.AddInt64(&conns, 1) defer atomic.AddInt64(&conns, -1) - fn.ServeHTTP(w, r) - })) + ctx, cancel := context.WithTimeout(r.Context(), time.Second*30) + defer cancel() + + r = r.WithContext(ctx) + + err := fn(w, r) + if err != nil { + tb.Errorf("server failed: %+v", err) + } + }) + if tls { + s = httptest.NewTLSServer(h) + } else { + s = httptest.NewServer(h) + } return s, func() { s.Close() @@ -654,7 +915,9 @@ func testServer(tb testing.TB, fn http.HandlerFunc) (s *httptest.Server, closeFn // https://github.com/crossbario/autobahn-python/tree/master/wstest func TestAutobahnServer(t *testing.T) { t.Parallel() - t.Skip() + if os.Getenv("AUTOBAHN") == "" { + t.Skip("Set $AUTOBAHN to run the autobahn test suite.") + } s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, websocket.AcceptOptions{ @@ -795,7 +1058,9 @@ func unusedListenAddr() (string, error) { // https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py func TestAutobahnClient(t *testing.T) { t.Parallel() - t.Skip() + if os.Getenv("AUTOBAHN") == "" { + t.Skip("Set $AUTOBAHN to run the autobahn test suite.") + } serverAddr, err := unusedListenAddr() if err != nil { @@ -941,18 +1206,18 @@ func checkWSTestIndex(t *testing.T, path string) { } func benchConn(b *testing.B, echo, stream bool, size int) { - s, closeFn := testServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s, closeFn := testServer(b, func(w http.ResponseWriter, r *http.Request) error { c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) if err != nil { - b.Logf("server handshake failed: %+v", err) - return + return err } if echo { echoLoop(r.Context(), c) } else { discardLoop(r.Context(), c) } - })) + return nil + }, false) defer closeFn() wsURL := strings.Replace(s.URL, "http", "ws", 1) From 537b26b9c25f621a1e6299b8397ed9684838c12a Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 29 Aug 2019 17:07:20 -0500 Subject: [PATCH 3/5] Change options to be pointer structures Closes #122 --- README.md | 4 ++-- accept.go | 8 ++++++-- accept_test.go | 4 ++-- dial.go | 12 ++++++++++-- dial_test.go | 4 ++-- example_echo_test.go | 4 ++-- example_test.go | 6 +++--- websocket_test.go | 46 ++++++++++++++++++++++---------------------- 8 files changed, 50 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index cf20b877..d53046c8 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ For a production quality example that shows off the full API, see the [echo exam ```go http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { // ... } @@ -64,7 +64,7 @@ in net/http](https://github.com/golang/go/issues/26937#issuecomment-415855861) t ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() -c, _, err := websocket.Dial(ctx, "ws://localhost:8080", websocket.DialOptions{}) +c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) if err != nil { // ... } diff --git a/accept.go b/accept.go index 7b727d16..afad1be2 100644 --- a/accept.go +++ b/accept.go @@ -84,7 +84,7 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { // // If an error occurs, Accept will always write an appropriate response so you do not // have to. -func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) { +func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { c, err := accept(w, r, opts) if err != nil { return nil, xerrors.Errorf("failed to accept websocket connection: %w", err) @@ -92,7 +92,11 @@ func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, return c, nil } -func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) { +func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { + if opts == nil { + opts = &AcceptOptions{} + } + err := verifyClientRequest(w, r) if err != nil { return nil, err diff --git a/accept_test.go b/accept_test.go index 8634066b..6602a8d0 100644 --- a/accept_test.go +++ b/accept_test.go @@ -15,7 +15,7 @@ func TestAccept(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) - _, err := Accept(w, r, AcceptOptions{}) + _, err := Accept(w, r, nil) if err == nil { t.Fatalf("unexpected error value: %v", err) } @@ -32,7 +32,7 @@ func TestAccept(t *testing.T) { r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Key", "meow123") - _, err := Accept(w, r, AcceptOptions{}) + _, err := Accept(w, r, nil) if err == nil || !strings.Contains(err.Error(), "http.Hijacker") { t.Fatalf("unexpected error value: %v", err) } diff --git a/dial.go b/dial.go index ac632c11..461817f6 100644 --- a/dial.go +++ b/dial.go @@ -41,7 +41,7 @@ type DialOptions struct { // This function requires at least Go 1.12 to succeed as it uses a new feature // in net/http to perform WebSocket handshakes and get a writable body // from the transport. See https://github.com/golang/go/issues/26937#issuecomment-415855861 -func Dial(ctx context.Context, u string, opts DialOptions) (*Conn, *http.Response, error) { +func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { c, r, err := dial(ctx, u, opts) if err != nil { return nil, r, xerrors.Errorf("failed to websocket dial: %w", err) @@ -49,7 +49,15 @@ func Dial(ctx context.Context, u string, opts DialOptions) (*Conn, *http.Respons return c, r, nil } -func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Response, err error) { +func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { + if opts == nil { + opts = &DialOptions{} + } + + // Shallow copy to ensure defaults do not affect user passed options. + opts2 := *opts + opts = &opts2 + if opts.HTTPClient == nil { opts.HTTPClient = http.DefaultClient } diff --git a/dial_test.go b/dial_test.go index 4607493b..96537bdb 100644 --- a/dial_test.go +++ b/dial_test.go @@ -14,7 +14,7 @@ func TestBadDials(t *testing.T) { testCases := []struct { name string url string - opts DialOptions + opts *DialOptions }{ { name: "badURL", @@ -27,7 +27,7 @@ func TestBadDials(t *testing.T) { { name: "badHTTPClient", url: "ws://nhooyr.io", - opts: DialOptions{ + opts: &DialOptions{ HTTPClient: &http.Client{ Timeout: time.Minute, }, diff --git a/example_echo_test.go b/example_echo_test.go index 6923bc04..3e7e7f9d 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -68,7 +68,7 @@ func Example_echo() { func echoServer(w http.ResponseWriter, r *http.Request) error { log.Printf("serving %v", r.RemoteAddr) - c, err := websocket.Accept(w, r, websocket.AcceptOptions{ + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, }) if err != nil { @@ -128,7 +128,7 @@ func client(url string) error { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - c, _, err := websocket.Dial(ctx, url, websocket.DialOptions{ + c, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{ Subprotocols: []string{"echo"}, }) if err != nil { diff --git a/example_test.go b/example_test.go index 0b59e6a0..22c31202 100644 --- a/example_test.go +++ b/example_test.go @@ -14,7 +14,7 @@ import ( // message from the client and then closes the connection. func ExampleAccept() { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { log.Println(err) return @@ -46,7 +46,7 @@ func ExampleDial() { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - c, _, err := websocket.Dial(ctx, "ws://localhost:8080", websocket.DialOptions{}) + c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) if err != nil { log.Fatal(err) } @@ -64,7 +64,7 @@ func ExampleDial() { // on which you will only write and do not expect to read data messages. func Example_writeOnly() { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { log.Println(err) return diff --git a/websocket_test.go b/websocket_test.go index b45f024f..1f1b5245 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -44,7 +44,7 @@ func TestHandshake(t *testing.T) { { name: "handshake", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{ + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"myproto"}, }) if err != nil { @@ -54,7 +54,7 @@ func TestHandshake(t *testing.T) { return nil }, client: func(ctx context.Context, u string) error { - c, resp, err := websocket.Dial(ctx, u, websocket.DialOptions{ + c, resp, err := websocket.Dial(ctx, u, &websocket.DialOptions{ Subprotocols: []string{"myproto"}, }) if err != nil { @@ -81,7 +81,7 @@ func TestHandshake(t *testing.T) { { name: "defaultSubprotocol", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { return err } @@ -93,7 +93,7 @@ func TestHandshake(t *testing.T) { return nil }, client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ Subprotocols: []string{"meow"}, }) if err != nil { @@ -110,7 +110,7 @@ func TestHandshake(t *testing.T) { { name: "subprotocol", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{ + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo", "lar"}, }) if err != nil { @@ -124,7 +124,7 @@ func TestHandshake(t *testing.T) { return nil }, client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ Subprotocols: []string{"poof", "echo"}, }) if err != nil { @@ -141,7 +141,7 @@ func TestHandshake(t *testing.T) { { name: "badOrigin", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err == nil { c.Close(websocket.StatusInternalError, "") return xerrors.New("expected error regarding bad origin") @@ -151,7 +151,7 @@ func TestHandshake(t *testing.T) { client: func(ctx context.Context, u string) error { h := http.Header{} h.Set("Origin", "http://unauthorized.com") - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ HTTPHeader: h, }) if err == nil { @@ -164,7 +164,7 @@ func TestHandshake(t *testing.T) { { name: "acceptSecureOrigin", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { return err } @@ -174,7 +174,7 @@ func TestHandshake(t *testing.T) { client: func(ctx context.Context, u string) error { h := http.Header{} h.Set("Origin", u) - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ HTTPHeader: h, }) if err != nil { @@ -187,7 +187,7 @@ func TestHandshake(t *testing.T) { { name: "acceptInsecureOrigin", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{ + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ InsecureSkipVerify: true, }) if err != nil { @@ -199,7 +199,7 @@ func TestHandshake(t *testing.T) { client: func(ctx context.Context, u string) error { h := http.Header{} h.Set("Origin", "https://example.com") - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ HTTPHeader: h, }) if err != nil { @@ -219,7 +219,7 @@ func TestHandshake(t *testing.T) { if cookie.Value != "myvalue" { return xerrors.Errorf("expected %q but got %q", "myvalue", cookie.Value) } - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { return err } @@ -245,7 +245,7 @@ func TestHandshake(t *testing.T) { hc := &http.Client{ Jar: jar, } - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ HTTPClient: hc, }) if err != nil { @@ -801,7 +801,7 @@ func TestConn(t *testing.T) { }, client: func(ctx context.Context, c *websocket.Conn) error { _, _, err := c.Read(ctx) - if err == nil || strings.Contains(err.Error(), "opcode") { + if err == nil || !strings.Contains(err.Error(), "opcode") { return xerrors.Errorf("expected error that contains opcode: %+v", err) } return nil @@ -839,7 +839,7 @@ func TestConn(t *testing.T) { tls := rand.Intn(2) == 1 s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { return err } @@ -854,7 +854,7 @@ func TestConn(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - opts := websocket.DialOptions{} + opts := &websocket.DialOptions{} if tls { opts.HTTPClient = s.Client() } @@ -920,7 +920,7 @@ func TestAutobahnServer(t *testing.T) { } s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{ + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, }) if err != nil { @@ -1120,7 +1120,7 @@ func TestAutobahnClient(t *testing.T) { var cases int func() { - c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", websocket.DialOptions{}) + c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", nil) if err != nil { t.Fatal(err) } @@ -1147,7 +1147,7 @@ func TestAutobahnClient(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, time.Second*45) defer cancel() - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), websocket.DialOptions{}) + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), nil) if err != nil { t.Fatal(err) } @@ -1155,7 +1155,7 @@ func TestAutobahnClient(t *testing.T) { }() } - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), websocket.DialOptions{}) + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), nil) if err != nil { t.Fatal(err) } @@ -1207,7 +1207,7 @@ func checkWSTestIndex(t *testing.T, path string) { func benchConn(b *testing.B, echo, stream bool, size int) { s, closeFn := testServer(b, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { return err } @@ -1225,7 +1225,7 @@ func benchConn(b *testing.B, echo, stream bool, size int) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() - c, _, err := websocket.Dial(ctx, wsURL, websocket.DialOptions{}) + c, _, err := websocket.Dial(ctx, wsURL, nil) if err != nil { b.Fatal(err) } From de687ea0f90d0873473092dde1ed0ae1f6b9424c Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Fri, 30 Aug 2019 13:34:36 -0500 Subject: [PATCH 4/5] More test coverage and updated CONTRIBUTING.md --- ci/image/Dockerfile | 5 +- docs/CONTRIBUTING.md | 19 ++++-- export_test.go | 7 ++ websocket_test.go | 154 +++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 171 insertions(+), 14 deletions(-) diff --git a/ci/image/Dockerfile b/ci/image/Dockerfile index d435e949..4477d646 100644 --- a/ci/image/Dockerfile +++ b/ci/image/Dockerfile @@ -6,8 +6,7 @@ ENV GOFLAGS="-mod=readonly" ENV PAGER=cat RUN apt-get update && \ - apt-get install -y shellcheck python-pip npm && \ - pip2 install autobahntestsuite && \ + apt-get install -y shellcheck npm && \ npm install -g prettier -RUN git config --global color.ui always \ No newline at end of file +RUN git config --global color.ui always diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index a0c97261..f003e743 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -34,16 +34,21 @@ browse coverage. You can run CI locally. The various steps are located in `ci/*.sh`. -1. `ci/fmt.sh` requires node (specifically prettier). -1. `ci/lint.sh` requires [shellcheck](https://github.com/koalaman/shellcheck#installing). -1. `ci/test.sh` requires the [Autobahn Test suite pip package](https://github.com/crossbario/autobahn-testsuite). -1. `ci/run.sh` runs the above scripts in order. +1. `ci/fmt.sh` which requires node (specifically prettier). +1. `ci/lint.sh` which requires [shellcheck](https://github.com/koalaman/shellcheck#installing). +1. `ci/test.sh` +1. `ci/run.sh` which runs the above scripts in order. For coverage details locally, please see `ci/out/coverage.html` after running `ci/test.sh`. See [ci/image/Dockerfile](ci/image/Dockerfile) for the installation of the CI dependencies on Ubuntu. -You can also run tests normally with `go test` once you have the -[Autobahn Test suite pip package](https://github.com/crossbario/autobahn-testsuite) -installed. `ci/test.sh` just passes a default set of flags to `go test` to collect coverage, +You can also run tests normally with `go test`. +`ci/test.sh` just passes a default set of flags to `go test` to collect coverage, enable the race detector, run benchmarks and also prettifies the output. + +If you pass flags to `ci/test.sh`, it will pass those flags directly to `go test` but will also +collect coverage for you. This is nice for when you don't want to wait for benchmarks +or the race detector but want to have coverage. + +Coverage percentage from codecov and the CI scripts will be different because they are calculated differently. diff --git a/export_test.go b/export_test.go index ab766f14..9c65360a 100644 --- a/export_test.go +++ b/export_test.go @@ -8,6 +8,13 @@ type Addr = websocketAddr type Header = header +const OPClose = opClose +const OPPing = opPing + func (c *Conn) WriteFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { return c.writeFrame(ctx, fin, opcode, p) } + +func (c *Conn) Flush() error { + return c.bw.Flush() +} diff --git a/websocket_test.go b/websocket_test.go index 1f1b5245..73020f5e 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -4,8 +4,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes/timestamp" "io" "io/ioutil" "math/rand" @@ -23,8 +21,10 @@ import ( "testing" "time" + "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/duration" + "github.com/golang/protobuf/ptypes/timestamp" "github.com/google/go-cmp/cmp" "golang.org/x/xerrors" @@ -592,7 +592,7 @@ func TestConn(t *testing.T) { return nil }, client: func(ctx context.Context, c *websocket.Conn) error { - go c.CloseRead(ctx) + c.CloseRead(ctx) err := c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769))) if err != nil { @@ -775,7 +775,7 @@ func TestConn(t *testing.T) { } <-ctx.Done() _, err = r.Read(make([]byte, 1)) - if !xerrors.Is(err, context.DeadlineExceeded){ + if !xerrors.Is(err, context.DeadlineExceeded) { return xerrors.Errorf("expected deadline exceeded error: %+v", err) } return nil @@ -829,6 +829,152 @@ func TestConn(t *testing.T) { return nil }, }, + { + name: "largeControlFrame", + server: func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, true, websocket.OPClose, []byte(strings.Repeat("x", 4096))) + if err != nil { + return err + } + _, _, err = c.Read(ctx) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError { + return xerrors.Errorf("expected close error with StatusProtocolError: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil || !strings.Contains(err.Error(), "too large") { + return xerrors.Errorf("expected error that contains too large: %+v", err) + } + return nil + }, + }, + { + name: "fragmentedControlFrame", + server: func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, false, websocket.OPPing, []byte(strings.Repeat("x", 32))) + if err != nil { + return err + } + err = c.Flush() + if err != nil { + return err + } + _, _, err = c.Read(ctx) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError { + return xerrors.Errorf("expected close error with StatusProtocolError: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil || !strings.Contains(err.Error(), "fragmented") { + return xerrors.Errorf("expected error that contains fragmented: %+v", err) + } + return nil + }, + }, + { + name: "invalidClosePayload", + server: func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, true, websocket.OPClose, []byte{0x17, 0x70}) + if err != nil { + return err + } + _, _, err = c.Read(ctx) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError { + return xerrors.Errorf("expected close error with StatusProtocolError: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil || !strings.Contains(err.Error(), "invalid status code") { + return xerrors.Errorf("expected error that contains invalid status code: %+v", err) + } + return nil + }, + }, + { + name: "doubleReader", + server: func(ctx context.Context, c *websocket.Conn) error { + _, r, err := c.Reader(ctx) + if err != nil { + return err + } + p := make([]byte, 10) + _, err = io.ReadFull(r, p) + if err != nil { + return err + } + _, _, err = c.Reader(ctx) + if err == nil { + return xerrors.Errorf("expected non nil error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + err := c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 11))) + if err != nil { + return err + } + _, _, err = c.Read(ctx) + if err == nil { + return xerrors.Errorf("expected non nil error: %v", err) + } + return nil + }, + }, + { + name: "doubleFragmentedReader", + server: func(ctx context.Context, c *websocket.Conn) error { + _, r, err := c.Reader(ctx) + if err != nil { + return err + } + p := make([]byte, 10) + _, err = io.ReadFull(r, p) + if err != nil { + return err + } + _, _, err = c.Reader(ctx) + if err == nil { + return xerrors.Errorf("expected non nil error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + w, err := c.Writer(ctx, websocket.MessageBinary) + if err != nil { + return err + } + _, err = w.Write([]byte(strings.Repeat("x", 10))) + if err != nil { + return xerrors.Errorf("expected non nil error") + } + err = c.Flush() + if err != nil { + return xerrors.Errorf("failed to flush: %w", err) + } + _, err = w.Write([]byte(strings.Repeat("x", 10))) + if err != nil { + return xerrors.Errorf("expected non nil error") + } + err = c.Flush() + if err != nil { + return xerrors.Errorf("failed to flush: %w", err) + } + _, _, err = c.Read(ctx) + if err == nil { + return xerrors.Errorf("expected non nil error: %v", err) + } + return nil + }, + }, } for _, tc := range testCases { tc := tc From 8cfcf43b9b111d7fd17ebf128a2468015f98ad74 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Fri, 30 Aug 2019 19:25:46 -0500 Subject: [PATCH 5/5] Add more tests and prepare for a rewrite of the tests tomorrow --- .circleci/config.yml | 6 +- export_test.go | 11 ++- websocket_test.go | 225 +++++++++++++++++++++++++++++++++++-------- 3 files changed, 197 insertions(+), 45 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 65b17aa0..196ec671 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -2,7 +2,7 @@ version: 2 jobs: fmt: docker: - - image: nhooyr/websocket-ci + - image: nhooyr/websocket-ci@sha256:371ca985ce2548840aeb0f8434a551708cdfe0628be722c361958e65cdded945 steps: - checkout - restore_cache: @@ -19,7 +19,7 @@ jobs: lint: docker: - - image: nhooyr/websocket-ci + - image: nhooyr/websocket-ci@sha256:371ca985ce2548840aeb0f8434a551708cdfe0628be722c361958e65cdded945 steps: - checkout - restore_cache: @@ -36,7 +36,7 @@ jobs: test: docker: - - image: nhooyr/websocket-ci + - image: nhooyr/websocket-ci@sha256:371ca985ce2548840aeb0f8434a551708cdfe0628be722c361958e65cdded945 steps: - checkout - restore_cache: diff --git a/export_test.go b/export_test.go index 9c65360a..fc885bff 100644 --- a/export_test.go +++ b/export_test.go @@ -6,15 +6,22 @@ import ( type Addr = websocketAddr -type Header = header - const OPClose = opClose +const OPBinary = opBinary const OPPing = opPing +const OPContinuation = opContinuation func (c *Conn) WriteFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { return c.writeFrame(ctx, fin, opcode, p) } +func (c *Conn) WriteHalfFrame(ctx context.Context) (int, error) { + return c.realWriteFrame(ctx, header{ + opcode: opBinary, + payloadLength: 5, + }, make([]byte, 10)) +} + func (c *Conn) Flush() error { return c.bw.Flush() } diff --git a/websocket_test.go b/websocket_test.go index 73020f5e..1963ce70 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -146,6 +146,9 @@ func TestHandshake(t *testing.T) { c.Close(websocket.StatusInternalError, "") return xerrors.New("expected error regarding bad origin") } + if !strings.Contains(err.Error(), "not authorized") { + return xerrors.Errorf("expected error regarding bad origin: %+v", err) + } return nil }, client: func(ctx context.Context, u string) error { @@ -158,6 +161,9 @@ func TestHandshake(t *testing.T) { c.Close(websocket.StatusInternalError, "") return xerrors.New("expected handshake failure") } + if !strings.Contains(err.Error(), "403") { + return xerrors.Errorf("expected handshake failure: %+v", err) + } return nil }, }, @@ -390,8 +396,8 @@ func TestConn(t *testing.T) { nc.SetDeadline(time.Now().Add(time.Second * 15)) _, err := nc.Read(make([]byte, 1)) - if err == nil { - return xerrors.Errorf("expected error") + if err == nil || !strings.Contains(err.Error(), "unexpected frame type read") { + return xerrors.Errorf("expected error: %+v", err) } return nil @@ -426,7 +432,7 @@ func TestConn(t *testing.T) { } _, err = nc.Write([]byte{0xff}) - if err == nil { + if err == nil || !strings.Contains(err.Error(), "websocket closed") { return xerrors.Errorf("expected writes to fail after reading a close frame: %v", err) } @@ -586,8 +592,8 @@ func TestConn(t *testing.T) { name: "readLimit", server: func(ctx context.Context, c *websocket.Conn) error { _, _, err := c.Read(ctx) - if err == nil { - return xerrors.Errorf("expected error but got nil") + if err == nil || !strings.Contains(err.Error(), "read limited at") { + return xerrors.Errorf("expected error but got nil: %+v", err) } return nil }, @@ -614,7 +620,7 @@ func TestConn(t *testing.T) { server: func(ctx context.Context, c *websocket.Conn) error { var v interface{} err := wsjson.Read(ctx, c, &v) - if err == nil { + if err == nil || !strings.Contains(err.Error(), "unexpected frame type") { return xerrors.Errorf("expected error: %v", err) } return nil @@ -628,7 +634,7 @@ func TestConn(t *testing.T) { server: func(ctx context.Context, c *websocket.Conn) error { var v interface{} err := wsjson.Read(ctx, c, &v) - if err == nil { + if err == nil || !strings.Contains(err.Error(), "failed to unmarshal json") { return xerrors.Errorf("expected error: %v", err) } return nil @@ -641,7 +647,7 @@ func TestConn(t *testing.T) { name: "wsjson/badWrite", server: func(ctx context.Context, c *websocket.Conn) error { _, _, err := c.Read(ctx) - if err == nil { + if err == nil || !strings.Contains(err.Error(), "StatusInternalError") { return xerrors.Errorf("expected error: %v", err) } return nil @@ -659,7 +665,7 @@ func TestConn(t *testing.T) { server: func(ctx context.Context, c *websocket.Conn) error { var v proto.Message err := wspb.Read(ctx, c, v) - if err == nil { + if err == nil || !strings.Contains(err.Error(), "unexpected frame type") { return xerrors.Errorf("expected error: %v", err) } return nil @@ -673,7 +679,7 @@ func TestConn(t *testing.T) { server: func(ctx context.Context, c *websocket.Conn) error { var v timestamp.Timestamp err := wspb.Read(ctx, c, &v) - if err == nil { + if err == nil || !strings.Contains(err.Error(), "failed to unmarshal protobuf") { return xerrors.Errorf("expected error: %v", err) } return nil @@ -686,24 +692,7 @@ func TestConn(t *testing.T) { name: "wspb/badWrite", server: func(ctx context.Context, c *websocket.Conn) error { _, _, err := c.Read(ctx) - if err == nil { - return xerrors.Errorf("expected error: %v", err) - } - return nil - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := wspb.Write(ctx, c, nil) - if err == nil { - return xerrors.Errorf("expected error: %v", err) - } - return nil - }, - }, - { - name: "wspb/badWrite", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - if err == nil { + if err == nil || !strings.Contains(err.Error(), "StatusInternalError") { return xerrors.Errorf("expected error: %v", err) } return nil @@ -736,13 +725,13 @@ func TestConn(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() err := c.Ping(ctx) - if err == nil { + if err == nil || !xerrors.Is(err, context.DeadlineExceeded) { return xerrors.Errorf("expected nil error: %+v", err) } return nil }, client: func(ctx context.Context, c *websocket.Conn) error { - time.Sleep(time.Second) + c.Read(ctx) return nil }, }, @@ -769,19 +758,14 @@ func TestConn(t *testing.T) { server: func(ctx context.Context, c *websocket.Conn) error { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - <-ctx.Done() - _, err = r.Read(make([]byte, 1)) + _, _, err := c.Read(ctx) if !xerrors.Is(err, context.DeadlineExceeded) { return xerrors.Errorf("expected deadline exceeded error: %+v", err) } return nil }, client: func(ctx context.Context, c *websocket.Conn) error { - time.Sleep(time.Second) + c.Read(ctx) return nil }, }, @@ -912,7 +896,7 @@ func TestConn(t *testing.T) { return err } _, _, err = c.Reader(ctx) - if err == nil { + if err == nil || !strings.Contains(err.Error(), "previous message not read to completion") { return xerrors.Errorf("expected non nil error: %v", err) } return nil @@ -942,11 +926,57 @@ func TestConn(t *testing.T) { return err } _, _, err = c.Reader(ctx) + if err == nil || !strings.Contains(err.Error(), "previous message not read to completion") { + return xerrors.Errorf("expected non nil error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + w, err := c.Writer(ctx, websocket.MessageBinary) + if err != nil { + return err + } + _, err = w.Write([]byte(strings.Repeat("x", 10))) + if err != nil { + return xerrors.Errorf("expected non nil error") + } + err = c.Flush() + if err != nil { + return xerrors.Errorf("failed to flush: %w", err) + } + _, err = w.Write([]byte(strings.Repeat("x", 10))) + if err != nil { + return xerrors.Errorf("expected non nil error") + } + err = c.Flush() + if err != nil { + return xerrors.Errorf("failed to flush: %w", err) + } + _, _, err = c.Read(ctx) if err == nil { return xerrors.Errorf("expected non nil error: %v", err) } return nil }, + }, + { + name: "newMessageInFragmentedMessage", + server: func(ctx context.Context, c *websocket.Conn) error { + _, r, err := c.Reader(ctx) + if err != nil { + return err + } + p := make([]byte, 10) + _, err = io.ReadFull(r, p) + if err != nil { + return err + } + _, _, err = c.Reader(ctx) + if err == nil || !strings.Contains(err.Error(), "received new data message without finishing") { + return xerrors.Errorf("expected non nil error: %v", err) + } + return nil + }, client: func(ctx context.Context, c *websocket.Conn) error { w, err := c.Writer(ctx, websocket.MessageBinary) if err != nil { @@ -960,6 +990,83 @@ func TestConn(t *testing.T) { if err != nil { return xerrors.Errorf("failed to flush: %w", err) } + _, err = c.WriteFrame(ctx, true, websocket.OPBinary, []byte(strings.Repeat("x", 10))) + if err != nil { + return xerrors.Errorf("expected non nil error") + } + _, _, err = c.Read(ctx) + if err == nil || !strings.Contains(err.Error(), "received new data message without finishing") { + return xerrors.Errorf("expected non nil error: %v", err) + } + return nil + }, + }, + { + name: "continuationFrameWithoutDataFrame", + server: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Reader(ctx) + if err == nil || !strings.Contains(err.Error(), "received continuation frame not after data") { + return xerrors.Errorf("expected non nil error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, false, websocket.OPContinuation, []byte(strings.Repeat("x", 10))) + if err != nil { + return xerrors.Errorf("expected non nil error") + } + return nil + }, + }, + { + name: "readBeforeEOF", + server: func(ctx context.Context, c *websocket.Conn) error { + _, r, err := c.Reader(ctx) + if err != nil { + return err + } + var v interface{} + d := json.NewDecoder(r) + err = d.Decode(&v) + if err != nil { + return err + } + _, b, err := c.Read(ctx) + if err != nil { + return err + } + if string(b) != "hi" { + return xerrors.Errorf("expected hi but got %q", string(b)) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + err := wsjson.Write(ctx, c, "hi") + if err != nil { + return err + } + return c.Write(ctx, websocket.MessageBinary, []byte("hi")) + }, + }, + { + name: "newMessageInFragmentedMessage2", + server: func(ctx context.Context, c *websocket.Conn) error { + _, r, err := c.Reader(ctx) + if err != nil { + return err + } + p := make([]byte, 11) + _, err = io.ReadFull(r, p) + if err == nil || !strings.Contains(err.Error(), "received new data message without finishing") { + return xerrors.Errorf("expected non nil error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + w, err := c.Writer(ctx, websocket.MessageBinary) + if err != nil { + return err + } _, err = w.Write([]byte(strings.Repeat("x", 10))) if err != nil { return xerrors.Errorf("expected non nil error") @@ -968,6 +1075,10 @@ func TestConn(t *testing.T) { if err != nil { return xerrors.Errorf("failed to flush: %w", err) } + _, err = c.WriteFrame(ctx, true, websocket.OPBinary, []byte(strings.Repeat("x", 10))) + if err != nil { + return xerrors.Errorf("expected non nil error") + } _, _, err = c.Read(ctx) if err == nil { return xerrors.Errorf("expected non nil error: %v", err) @@ -975,6 +1086,41 @@ func TestConn(t *testing.T) { return nil }, }, + { + name: "doubleRead", + server: func(ctx context.Context, c *websocket.Conn) error { + _, r, err := c.Reader(ctx) + if err != nil { + return err + } + _, err = ioutil.ReadAll(r) + if err != nil { + return err + } + _, err = r.Read(make([]byte, 1)) + if err == nil || !strings.Contains(err.Error(), "cannot use EOFed reader") { + return xerrors.Errorf("expected non nil error: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + return c.Write(ctx, websocket.MessageBinary, []byte("hi")) + }, + }, + { + name: "eofInPayload", + server: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil || !strings.Contains(err.Error(), "failed to read frame payload") { + return xerrors.Errorf("expected failed to read frame payload: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteHalfFrame(ctx) + return err + }, + }, } for _, tc := range testCases { tc := tc @@ -990,8 +1136,7 @@ func TestConn(t *testing.T) { return err } defer c.Close(websocket.StatusInternalError, "") - tc.server(r.Context(), c) - return nil + return tc.server(r.Context(), c) }, tls) defer closeFn()