diff --git a/.circleci/config.yml b/.circleci/config.yml index 65b17aa0..196ec671 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -2,7 +2,7 @@ version: 2 jobs: fmt: docker: - - image: nhooyr/websocket-ci + - image: nhooyr/websocket-ci@sha256:371ca985ce2548840aeb0f8434a551708cdfe0628be722c361958e65cdded945 steps: - checkout - restore_cache: @@ -19,7 +19,7 @@ jobs: lint: docker: - - image: nhooyr/websocket-ci + - image: nhooyr/websocket-ci@sha256:371ca985ce2548840aeb0f8434a551708cdfe0628be722c361958e65cdded945 steps: - checkout - restore_cache: @@ -36,7 +36,7 @@ jobs: test: docker: - - image: nhooyr/websocket-ci + - image: nhooyr/websocket-ci@sha256:371ca985ce2548840aeb0f8434a551708cdfe0628be722c361958e65cdded945 steps: - checkout - restore_cache: diff --git a/README.md b/README.md index cf20b877..d53046c8 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ For a production quality example that shows off the full API, see the [echo exam ```go http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { // ... } @@ -64,7 +64,7 @@ in net/http](https://github.com/golang/go/issues/26937#issuecomment-415855861) t ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() -c, _, err := websocket.Dial(ctx, "ws://localhost:8080", websocket.DialOptions{}) +c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) if err != nil { // ... } diff --git a/accept.go b/accept.go index 7b727d16..afad1be2 100644 --- a/accept.go +++ b/accept.go @@ -84,7 +84,7 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { // // If an error occurs, Accept will always write an appropriate response so you do not // have to. -func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) { +func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { c, err := accept(w, r, opts) if err != nil { return nil, xerrors.Errorf("failed to accept websocket connection: %w", err) @@ -92,7 +92,11 @@ func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, return c, nil } -func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) { +func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { + if opts == nil { + opts = &AcceptOptions{} + } + err := verifyClientRequest(w, r) if err != nil { return nil, err diff --git a/accept_test.go b/accept_test.go index 6f5c3fb9..6602a8d0 100644 --- a/accept_test.go +++ b/accept_test.go @@ -6,6 +6,39 @@ import ( "testing" ) +func TestAccept(t *testing.T) { + t.Parallel() + + t.Run("badClientHandshake", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + + _, err := Accept(w, r, nil) + if err == nil { + t.Fatalf("unexpected error value: %v", err) + } + + }) + + t.Run("requireHttpHijacker", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + + _, err := Accept(w, r, nil) + if err == nil || !strings.Contains(err.Error(), "http.Hijacker") { + t.Fatalf("unexpected error value: %v", err) + } + }) +} + func Test_verifyClientHandshake(t *testing.T) { t.Parallel() diff --git a/ci/image/Dockerfile b/ci/image/Dockerfile index d435e949..4477d646 100644 --- a/ci/image/Dockerfile +++ b/ci/image/Dockerfile @@ -6,8 +6,7 @@ ENV GOFLAGS="-mod=readonly" ENV PAGER=cat RUN apt-get update && \ - apt-get install -y shellcheck python-pip npm && \ - pip2 install autobahntestsuite && \ + apt-get install -y shellcheck npm && \ npm install -g prettier -RUN git config --global color.ui always \ No newline at end of file +RUN git config --global color.ui always diff --git a/ci/test.sh b/ci/test.sh index ab101e91..1d4a8b07 100755 --- a/ci/test.sh +++ b/ci/test.sh @@ -4,19 +4,34 @@ set -euo pipefail cd "$(dirname "${0}")" cd "$(git rev-parse --show-toplevel)" -mkdir -p ci/out/websocket -testFlags=( +argv=( + go run gotest.tools/gotestsum + # https://circleci.com/docs/2.0/collect-test-data/ + "--junitfile=ci/out/websocket/testReport.xml" + "--format=short-verbose" + -- -race "-vet=off" "-bench=." +) +# Interactive usage probably does not want to enable benchmarks, race detection +# turn off vet or use gotestsum by default. +if [[ $# -gt 0 ]]; then + argv=(go test "$@") +fi + +# We always want coverage. +argv+=( "-coverprofile=ci/out/coverage.prof" "-coverpkg=./..." ) -# https://circleci.com/docs/2.0/collect-test-data/ -go run gotest.tools/gotestsum \ - --junitfile ci/out/websocket/testReport.xml \ - --format=short-verbose \ - -- "${testFlags[@]}" + +mkdir -p ci/out/websocket +"${argv[@]}" + +# Removes coverage of generated files. +grep -v _string.go < ci/out/coverage.prof > ci/out/coverage2.prof +mv ci/out/coverage2.prof ci/out/coverage.prof go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html if [[ ${CI:-} ]]; then diff --git a/dial.go b/dial.go index ac632c11..461817f6 100644 --- a/dial.go +++ b/dial.go @@ -41,7 +41,7 @@ type DialOptions struct { // This function requires at least Go 1.12 to succeed as it uses a new feature // in net/http to perform WebSocket handshakes and get a writable body // from the transport. See https://github.com/golang/go/issues/26937#issuecomment-415855861 -func Dial(ctx context.Context, u string, opts DialOptions) (*Conn, *http.Response, error) { +func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { c, r, err := dial(ctx, u, opts) if err != nil { return nil, r, xerrors.Errorf("failed to websocket dial: %w", err) @@ -49,7 +49,15 @@ func Dial(ctx context.Context, u string, opts DialOptions) (*Conn, *http.Respons return c, r, nil } -func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Response, err error) { +func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { + if opts == nil { + opts = &DialOptions{} + } + + // Shallow copy to ensure defaults do not affect user passed options. + opts2 := *opts + opts = &opts2 + if opts.HTTPClient == nil { opts.HTTPClient = http.DefaultClient } diff --git a/dial_test.go b/dial_test.go index 6f0deef9..96537bdb 100644 --- a/dial_test.go +++ b/dial_test.go @@ -1,11 +1,60 @@ package websocket import ( + "context" "net/http" "net/http/httptest" "testing" + "time" ) +func TestBadDials(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + url string + opts *DialOptions + }{ + { + name: "badURL", + url: "://noscheme", + }, + { + name: "badURLScheme", + url: "ftp://nhooyr.io", + }, + { + name: "badHTTPClient", + url: "ws://nhooyr.io", + opts: &DialOptions{ + HTTPClient: &http.Client{ + Timeout: time.Minute, + }, + }, + }, + { + name: "badTLS", + url: "wss://totallyfake.nhooyr.io", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + _, _, err := Dial(ctx, tc.url, tc.opts) + if err == nil { + t.Fatalf("expected non nil error: %+v", err) + } + }) + } +} + func Test_verifyServerHandshake(t *testing.T) { t.Parallel() diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index a0c97261..f003e743 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -34,16 +34,21 @@ browse coverage. You can run CI locally. The various steps are located in `ci/*.sh`. -1. `ci/fmt.sh` requires node (specifically prettier). -1. `ci/lint.sh` requires [shellcheck](https://github.com/koalaman/shellcheck#installing). -1. `ci/test.sh` requires the [Autobahn Test suite pip package](https://github.com/crossbario/autobahn-testsuite). -1. `ci/run.sh` runs the above scripts in order. +1. `ci/fmt.sh` which requires node (specifically prettier). +1. `ci/lint.sh` which requires [shellcheck](https://github.com/koalaman/shellcheck#installing). +1. `ci/test.sh` +1. `ci/run.sh` which runs the above scripts in order. For coverage details locally, please see `ci/out/coverage.html` after running `ci/test.sh`. See [ci/image/Dockerfile](ci/image/Dockerfile) for the installation of the CI dependencies on Ubuntu. -You can also run tests normally with `go test` once you have the -[Autobahn Test suite pip package](https://github.com/crossbario/autobahn-testsuite) -installed. `ci/test.sh` just passes a default set of flags to `go test` to collect coverage, +You can also run tests normally with `go test`. +`ci/test.sh` just passes a default set of flags to `go test` to collect coverage, enable the race detector, run benchmarks and also prettifies the output. + +If you pass flags to `ci/test.sh`, it will pass those flags directly to `go test` but will also +collect coverage for you. This is nice for when you don't want to wait for benchmarks +or the race detector but want to have coverage. + +Coverage percentage from codecov and the CI scripts will be different because they are calculated differently. diff --git a/example_echo_test.go b/example_echo_test.go index 6923bc04..3e7e7f9d 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -68,7 +68,7 @@ func Example_echo() { func echoServer(w http.ResponseWriter, r *http.Request) error { log.Printf("serving %v", r.RemoteAddr) - c, err := websocket.Accept(w, r, websocket.AcceptOptions{ + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, }) if err != nil { @@ -128,7 +128,7 @@ func client(url string) error { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - c, _, err := websocket.Dial(ctx, url, websocket.DialOptions{ + c, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{ Subprotocols: []string{"echo"}, }) if err != nil { diff --git a/example_test.go b/example_test.go index 0b59e6a0..22c31202 100644 --- a/example_test.go +++ b/example_test.go @@ -14,7 +14,7 @@ import ( // message from the client and then closes the connection. func ExampleAccept() { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { log.Println(err) return @@ -46,7 +46,7 @@ func ExampleDial() { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - c, _, err := websocket.Dial(ctx, "ws://localhost:8080", websocket.DialOptions{}) + c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) if err != nil { log.Fatal(err) } @@ -64,7 +64,7 @@ func ExampleDial() { // on which you will only write and do not expect to read data messages. func Example_writeOnly() { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { log.Println(err) return diff --git a/export_test.go b/export_test.go index 22ad76fc..fc885bff 100644 --- a/export_test.go +++ b/export_test.go @@ -1,3 +1,27 @@ package websocket -var Compute = handleSecWebSocketKey +import ( + "context" +) + +type Addr = websocketAddr + +const OPClose = opClose +const OPBinary = opBinary +const OPPing = opPing +const OPContinuation = opContinuation + +func (c *Conn) WriteFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { + return c.writeFrame(ctx, fin, opcode, p) +} + +func (c *Conn) WriteHalfFrame(ctx context.Context) (int, error) { + return c.realWriteFrame(ctx, header{ + opcode: opBinary, + payloadLength: 5, + }, make([]byte, 10)) +} + +func (c *Conn) Flush() error { + return c.bw.Flush() +} diff --git a/go.mod b/go.mod index 58d79bf1..35d500dd 100644 --- a/go.mod +++ b/go.mod @@ -13,8 +13,6 @@ require ( golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 golang.org/x/tools v0.0.0-20190429184909-35c670923e21 golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 - gotest.tools/gotestsum v0.3.5 + gotest.tools/gotestsum v0.3.6-0.20190825182939-fc6cb5870c52 mvdan.cc/sh v2.6.4+incompatible ) - -replace gotest.tools/gotestsum => github.com/nhooyr/gotestsum v0.3.6-0.20190821172136-aaabbb33254b diff --git a/go.sum b/go.sum index 98b766bf..b9e3737c 100644 --- a/go.sum +++ b/go.sum @@ -22,8 +22,6 @@ github.com/mattn/go-colorable v0.0.9 h1:UVL0vNpWh04HeJXV0KLcaT7r06gOH2l4OW6ddYRU github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-isatty v0.0.3 h1:ns/ykhmWi7G9O+8a448SecJU3nSMBXJfqQkl0upE1jI= github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= -github.com/nhooyr/gotestsum v0.3.6-0.20190821172136-aaabbb33254b h1:t6DbmxEtGMM72Uhs638nBOyK7tjsrDwoMfYO1EfQdFE= -github.com/nhooyr/gotestsum v0.3.6-0.20190821172136-aaabbb33254b/go.mod h1:Mnf3e5FUzXbkCfynWBGOwLssY7gTQgCHObK9tMpAriY= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.8.0 h1:VkHVNpR4iVnU8XQR6DBm8BqYjN7CRzw+xKUbVVbbW9w= github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= @@ -85,5 +83,7 @@ gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gotest.tools v2.1.0+incompatible h1:5USw7CrJBYKqjg9R7QlA6jzqZKEAtvW82aNmsxxGPxw= gotest.tools v2.1.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= +gotest.tools/gotestsum v0.3.6-0.20190825182939-fc6cb5870c52 h1:Qr31uPFyjpOhAgRfKV4ATUnknnLT2X7HFjqwkstdbbE= +gotest.tools/gotestsum v0.3.6-0.20190825182939-fc6cb5870c52/go.mod h1:Mnf3e5FUzXbkCfynWBGOwLssY7gTQgCHObK9tMpAriY= mvdan.cc/sh v2.6.4+incompatible h1:eD6tDeh0pw+/TOTI1BBEryZ02rD2nMcFsgcvde7jffM= mvdan.cc/sh v2.6.4+incompatible/go.mod h1:IeeQbZq+x2SUGBensq/jge5lLQbS3XT2ktyp3wrt4x8= diff --git a/header_test.go b/header_test.go index b45854ea..45d0535a 100644 --- a/header_test.go +++ b/header_test.go @@ -2,6 +2,7 @@ package websocket import ( "bytes" + "io" "math/rand" "strconv" "testing" @@ -21,6 +22,51 @@ func randBool() bool { func TestHeader(t *testing.T) { t.Parallel() + t.Run("eof", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + bytes []byte + }{ + { + "start", + []byte{0xff}, + }, + { + "middle", + []byte{0xff, 0xff, 0xff}, + }, + } + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + b := bytes.NewBuffer(tc.bytes) + _, err := readHeader(nil, b) + if io.ErrUnexpectedEOF != err { + t.Fatalf("expected %v but got: %v", io.ErrUnexpectedEOF, err) + } + }) + } + }) + + t.Run("writeNegativeLength", func(t *testing.T) { + t.Parallel() + + defer func() { + r := recover() + if r == nil { + t.Fatal("failed to induce panic in writeHeader with negative payload length") + } + }() + + writeHeader(nil, header{ + payloadLength: -1, + }) + }) + t.Run("readNegativeLength", func(t *testing.T) { t.Parallel() diff --git a/netconn.go b/netconn.go index d28eeb84..a6f902da 100644 --- a/netconn.go +++ b/netconn.go @@ -101,8 +101,8 @@ func (c *netConn) Read(p []byte) (int, error) { return 0, err } if typ != c.msgType { - c.c.Close(StatusUnsupportedData, fmt.Sprintf("can only accept %v messages", c.msgType)) - return 0, xerrors.Errorf("unexpected frame type read for net conn adapter (expected %v): %v", c.msgType, typ) + c.c.Close(StatusUnsupportedData, fmt.Sprintf("unexpected frame type read (expected %v): %v", c.msgType, typ)) + return 0, c.c.closeErr } c.reader = r } diff --git a/statuscode.go b/statuscode.go index 42ae40c0..498437d0 100644 --- a/statuscode.go +++ b/statuscode.go @@ -35,7 +35,7 @@ const ( StatusTryAgainLater StatusBadGateway // statusTLSHandshake is unexported because we just return - // handshake error in dial. We do not return a conn + // the handshake error in dial. We do not return a conn // so there is nothing to use this on. At least until WASM. statusTLSHandshake ) diff --git a/statuscode_test.go b/statuscode_test.go index 38ee4c3f..b9637868 100644 --- a/statuscode_test.go +++ b/statuscode_test.go @@ -4,14 +4,13 @@ import ( "math" "strings" "testing" + + "github.com/google/go-cmp/cmp" ) func TestCloseError(t *testing.T) { t.Parallel() - // Other parts of close error are tested by websocket_test.go right now - // with the autobahn tests. - testCases := []struct { name string ce CloseError @@ -50,7 +49,108 @@ func TestCloseError(t *testing.T) { _, err := tc.ce.bytes() if (err == nil) != tc.success { - t.Fatalf("unexpected error value: %v", err) + t.Fatalf("unexpected error value: %+v", err) + } + }) + } +} + +func Test_parseClosePayload(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + p []byte + success bool + ce CloseError + }{ + { + name: "normal", + p: append([]byte{0x3, 0xE8}, []byte("hello")...), + success: true, + ce: CloseError{ + Code: StatusNormalClosure, + Reason: "hello", + }, + }, + { + name: "nothing", + success: true, + ce: CloseError{ + Code: StatusNoStatusRcvd, + }, + }, + { + name: "oneByte", + p: []byte{0}, + success: false, + }, + { + name: "badStatusCode", + p: []byte{0x17, 0x70}, + success: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ce, err := parseClosePayload(tc.p) + if (err == nil) != tc.success { + t.Fatalf("unexpected expected error value: %+v", err) + } + + if tc.success && tc.ce != ce { + t.Fatalf("unexpected close error: %v", cmp.Diff(tc.ce, ce)) + } + }) + } +} + +func Test_validWireCloseCode(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + code StatusCode + valid bool + }{ + { + name: "normal", + code: StatusNormalClosure, + valid: true, + }, + { + name: "noStatus", + code: StatusNoStatusRcvd, + valid: false, + }, + { + name: "3000", + code: 3000, + valid: true, + }, + { + name: "4999", + code: 4999, + valid: true, + }, + { + name: "unknown", + code: 5000, + valid: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if valid := validWireCloseCode(tc.code); tc.valid != valid { + t.Fatalf("expected %v for %v but got %v", tc.valid, tc.code, valid) } }) } diff --git a/websocket.go b/websocket.go index 393ea547..833c1209 100644 --- a/websocket.go +++ b/websocket.go @@ -7,8 +7,8 @@ import ( "fmt" "io" "io/ioutil" + "log" "math/rand" - "os" "runtime" "strconv" "sync" @@ -210,9 +210,8 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) { } if h.rsv1 || h.rsv2 || h.rsv3 { - err := xerrors.Errorf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) - c.Close(StatusProtocolError, err.Error()) - return header{}, err + c.Close(StatusProtocolError, fmt.Sprintf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)) + return header{}, c.closeErr } if h.opcode.controlOp() { @@ -227,9 +226,8 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) { case opBinary, opText, opContinuation: return h, nil default: - err := xerrors.Errorf("received unknown opcode %v", h.opcode) - c.Close(StatusProtocolError, err.Error()) - return header{}, err + c.Close(StatusProtocolError, fmt.Sprintf("received unknown opcode %v", h.opcode)) + return header{}, c.closeErr } } } @@ -273,15 +271,13 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { func (c *Conn) handleControl(ctx context.Context, h header) error { if h.payloadLength > maxControlFramePayload { - err := xerrors.Errorf("control frame too large at %v bytes", h.payloadLength) - c.Close(StatusProtocolError, err.Error()) - return err + c.Close(StatusProtocolError, fmt.Sprintf("control frame too large at %v bytes", h.payloadLength)) + return c.closeErr } if !h.fin { - err := xerrors.Errorf("received fragmented control frame") - c.Close(StatusProtocolError, err.Error()) - return err + c.Close(StatusProtocolError, "received fragmented control frame") + return c.closeErr } ctx, cancel := context.WithTimeout(ctx, time.Second*5) @@ -311,8 +307,9 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { case opClose: ce, err := parseClosePayload(b) if err != nil { - c.Close(StatusProtocolError, "received invalid close payload") - return xerrors.Errorf("received invalid close payload: %w", err) + err = xerrors.Errorf("received invalid close payload: %w", err) + c.Close(StatusProtocolError, err.Error()) + return c.closeErr } // This ensures the closeErr of the Conn is always the received CloseError // in case the echo close frame write fails. @@ -376,9 +373,8 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { if c.activeReader != nil && !c.activeReader.eof() { if h.opcode != opContinuation { - err := xerrors.Errorf("received new data message without finishing the previous message") - c.Close(StatusProtocolError, err.Error()) - return 0, nil, err + c.Close(StatusProtocolError, "received new data message without finishing the previous message") + return 0, nil, c.closeErr } if !h.fin || h.payloadLength > 0 { @@ -392,9 +388,8 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { return 0, nil, err } } else if h.opcode == opContinuation { - err := xerrors.Errorf("received continuation frame not after data or text frame") - c.Close(StatusProtocolError, err.Error()) - return 0, nil, err + c.Close(StatusProtocolError, "received continuation frame not after data or text frame") + return 0, nil, c.closeErr } c.readerMsgCtx = ctx @@ -460,9 +455,8 @@ func (r *messageReader) read(p []byte) (int, error) { } if r.c.readMsgLeft <= 0 { - err := xerrors.Errorf("read limited at %v bytes", r.c.msgReadLimit) - r.c.Close(StatusMessageTooBig, err.Error()) - return 0, err + r.c.Close(StatusMessageTooBig, fmt.Sprintf("read limited at %v bytes", r.c.msgReadLimit)) + return 0, r.c.closeErr } if int64(len(p)) > r.c.readMsgLeft { @@ -476,9 +470,8 @@ func (r *messageReader) read(p []byte) (int, error) { } if h.opcode != opContinuation { - err := xerrors.Errorf("received new data message without finishing the previous message") - r.c.Close(StatusProtocolError, err.Error()) - return 0, err + r.c.Close(StatusProtocolError, "received new data message without finishing the previous message") + return 0, r.c.closeErr } r.c.readerMsgHeader = h @@ -828,7 +821,7 @@ func (c *Conn) writePong(p []byte) error { func (c *Conn) Close(code StatusCode, reason string) error { err := c.exportedClose(code, reason) if err != nil { - return xerrors.Errorf("failed to close connection: %w", err) + return xerrors.Errorf("failed to close websocket connection: %w", err) } return nil } @@ -844,7 +837,7 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error { // Definitely worth seeing what popular browsers do later. p, err := ce.bytes() if err != nil { - fmt.Fprintf(os.Stderr, "websocket: failed to marshal close frame: %v\n", err) + log.Printf("websocket: failed to marshal close frame: %+v", err) ce = CloseError{ Code: StatusInternalError, } @@ -853,12 +846,13 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error { // CloseErrors sent are made opaque to prevent applications from thinking // they received a given status. - err = c.writeClose(p, xerrors.Errorf("sent close frame: %v", ce)) + sentErr := xerrors.Errorf("sent close frame: %v", ce) + err = c.writeClose(p, sentErr) if err != nil { return err } - if !xerrors.Is(c.closeErr, ce) { + if !xerrors.Is(c.closeErr, sentErr) { return c.closeErr } diff --git a/websocket_test.go b/websocket_test.go index cd6bdaf5..1963ce70 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "io/ioutil" + "math/rand" "net" "net/http" "net/http/cookiejar" @@ -20,8 +21,10 @@ import ( "testing" "time" + "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/duration" + "github.com/golang/protobuf/ptypes/timestamp" "github.com/google/go-cmp/cmp" "golang.org/x/xerrors" @@ -41,7 +44,7 @@ func TestHandshake(t *testing.T) { { name: "handshake", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{ + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"myproto"}, }) if err != nil { @@ -51,7 +54,7 @@ func TestHandshake(t *testing.T) { return nil }, client: func(ctx context.Context, u string) error { - c, resp, err := websocket.Dial(ctx, u, websocket.DialOptions{ + c, resp, err := websocket.Dial(ctx, u, &websocket.DialOptions{ Subprotocols: []string{"myproto"}, }) if err != nil { @@ -76,32 +79,227 @@ func TestHandshake(t *testing.T) { }, }, { - name: "closeError", + name: "defaultSubprotocol", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { return err } defer c.Close(websocket.StatusInternalError, "") - err = wsjson.Write(r.Context(), c, "hello") + if c.Subprotocol() != "" { + return xerrors.Errorf("unexpected subprotocol: %v", c.Subprotocol()) + } + return nil + }, + client: func(ctx context.Context, u string) error { + c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ + Subprotocols: []string{"meow"}, + }) if err != nil { return err } + defer c.Close(websocket.StatusInternalError, "") + if c.Subprotocol() != "" { + return xerrors.Errorf("unexpected subprotocol: %v", c.Subprotocol()) + } + return nil + }, + }, + { + name: "subprotocol", + server: func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"echo", "lar"}, + }) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + + if c.Subprotocol() != "echo" { + return xerrors.Errorf("unexpected subprotocol: %q", c.Subprotocol()) + } return nil }, client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ - Subprotocols: []string{"meow"}, + c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ + Subprotocols: []string{"poof", "echo"}, }) if err != nil { return err } defer c.Close(websocket.StatusInternalError, "") + if c.Subprotocol() != "echo" { + return xerrors.Errorf("unexpected subprotocol: %q", c.Subprotocol()) + } + return nil + }, + }, + { + name: "badOrigin", + server: func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, nil) + if err == nil { + c.Close(websocket.StatusInternalError, "") + return xerrors.New("expected error regarding bad origin") + } + if !strings.Contains(err.Error(), "not authorized") { + return xerrors.Errorf("expected error regarding bad origin: %+v", err) + } + return nil + }, + client: func(ctx context.Context, u string) error { + h := http.Header{} + h.Set("Origin", "http://unauthorized.com") + c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ + HTTPHeader: h, + }) + if err == nil { + c.Close(websocket.StatusInternalError, "") + return xerrors.New("expected handshake failure") + } + if !strings.Contains(err.Error(), "403") { + return xerrors.Errorf("expected handshake failure: %+v", err) + } + return nil + }, + }, + { + name: "acceptSecureOrigin", + server: func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, nil) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + return nil + }, + client: func(ctx context.Context, u string) error { + h := http.Header{} + h.Set("Origin", u) + c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ + HTTPHeader: h, + }) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + return nil + }, + }, + { + name: "acceptInsecureOrigin", + server: func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + InsecureSkipVerify: true, + }) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + return nil + }, + client: func(ctx context.Context, u string) error { + h := http.Header{} + h.Set("Origin", "https://example.com") + c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ + HTTPHeader: h, + }) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + return nil + }, + }, + { + name: "cookies", + server: func(w http.ResponseWriter, r *http.Request) error { + cookie, err := r.Cookie("mycookie") + if err != nil { + return xerrors.Errorf("request is missing mycookie: %w", err) + } + if cookie.Value != "myvalue" { + return xerrors.Errorf("expected %q but got %q", "myvalue", cookie.Value) + } + c, err := websocket.Accept(w, r, nil) + if err != nil { + return err + } + c.Close(websocket.StatusInternalError, "") + return nil + }, + client: func(ctx context.Context, u string) error { + jar, err := cookiejar.New(nil) + if err != nil { + return xerrors.Errorf("failed to create cookie jar: %w", err) + } + parsedURL, err := url.Parse(u) + if err != nil { + return xerrors.Errorf("failed to parse url: %w", err) + } + parsedURL.Scheme = "http" + jar.SetCookies(parsedURL, []*http.Cookie{ + { + Name: "mycookie", + Value: "myvalue", + }, + }) + hc := &http.Client{ + Jar: jar, + } + c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ + HTTPClient: hc, + }) + if err != nil { + return err + } + c.Close(websocket.StatusInternalError, "") + return nil + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + s, closeFn := testServer(t, tc.server, false) + defer closeFn() + + wsURL := strings.Replace(s.URL, "http", "ws", 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + err := tc.client(ctx, wsURL) + if err != nil { + t.Fatalf("client failed: %+v", err) + } + }) + } +} + +func TestConn(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + client func(ctx context.Context, c *websocket.Conn) error + server func(ctx context.Context, c *websocket.Conn) error + }{ + { + name: "closeError", + server: func(ctx context.Context, c *websocket.Conn) error { + return wsjson.Write(ctx, c, "hello") + }, + client: func(ctx context.Context, c *websocket.Conn) error { var m string - err = wsjson.Read(ctx, c, &m) + err := wsjson.Read(ctx, c, &m) if err != nil { return err } @@ -121,13 +319,7 @@ func TestHandshake(t *testing.T) { }, { name: "netConn", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - + server: func(ctx context.Context, c *websocket.Conn) error { nc := websocket.NetConn(c, websocket.MessageBinary) defer nc.Close() @@ -135,8 +327,15 @@ func TestHandshake(t *testing.T) { time.Sleep(1) nc.SetWriteDeadline(time.Now().Add(time.Second * 15)) + if nc.LocalAddr() != (websocket.Addr{}) { + return xerrors.Errorf("net conn local address is not equal to websocket.Addr") + } + if nc.RemoteAddr() != (websocket.Addr{}) { + return xerrors.Errorf("net conn remote address is not equal to websocket.Addr") + } + for i := 0; i < 3; i++ { - _, err = nc.Write([]byte("hello")) + _, err := nc.Write([]byte("hello")) if err != nil { return err } @@ -144,15 +343,7 @@ func TestHandshake(t *testing.T) { return nil }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ - Subprotocols: []string{"meow"}, - }) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - + client: func(ctx context.Context, c *websocket.Conn) error { nc := websocket.NetConn(c, websocket.MessageBinary) defer nc.Close() @@ -164,7 +355,7 @@ func TestHandshake(t *testing.T) { p := make([]byte, len("hello")) // We do not use io.ReadFull here as it masks EOFs. // See https://github.com/nhooyr/websocket/issues/100#issuecomment-508148024 - _, err = nc.Read(p) + _, err := nc.Read(p) if err != nil { return err } @@ -176,14 +367,14 @@ func TestHandshake(t *testing.T) { } for i := 0; i < 3; i++ { - err = read() + err := read() if err != nil { return err } } // Ensure the close frame is converted to an EOF and multiple read's after all return EOF. - err = read() + err := read() if err != io.EOF { return err } @@ -197,423 +388,756 @@ func TestHandshake(t *testing.T) { }, }, { - name: "defaultSubprotocol", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + name: "netConn/badReadMsgType", + server: func(ctx context.Context, c *websocket.Conn) error { + nc := websocket.NetConn(c, websocket.MessageBinary) + defer nc.Close() + + nc.SetDeadline(time.Now().Add(time.Second * 15)) + + _, err := nc.Read(make([]byte, 1)) + if err == nil || !strings.Contains(err.Error(), "unexpected frame type read") { + return xerrors.Errorf("expected error: %+v", err) + } + + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + err := wsjson.Write(ctx, c, "meow") if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") - if c.Subprotocol() != "" { - return xerrors.Errorf("unexpected subprotocol: %v", c.Subprotocol()) + _, _, err = c.Read(ctx) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusUnsupportedData { + return xerrors.Errorf("expected close error with code StatusUnsupportedData: %+v", err) } + return nil }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ - Subprotocols: []string{"meow"}, - }) + }, + { + name: "netConn/badRead", + server: func(ctx context.Context, c *websocket.Conn) error { + nc := websocket.NetConn(c, websocket.MessageBinary) + defer nc.Close() + + nc.SetDeadline(time.Now().Add(time.Second * 15)) + + _, err := nc.Read(make([]byte, 1)) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusBadGateway { + return xerrors.Errorf("expected close error with code StatusBadGateway: %+v", err) + } + + _, err = nc.Write([]byte{0xff}) + if err == nil || !strings.Contains(err.Error(), "websocket closed") { + return xerrors.Errorf("expected writes to fail after reading a close frame: %v", err) + } + + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + return c.Close(websocket.StatusBadGateway, "") + }, + }, + { + name: "jsonEcho", + server: func(ctx context.Context, c *websocket.Conn) error { + write := func() error { + v := map[string]interface{}{ + "anmol": "wowow", + } + err := wsjson.Write(ctx, c, v) + return err + } + err := write() + if err != nil { + return err + } + err = write() if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") - if c.Subprotocol() != "" { - return xerrors.Errorf("unexpected subprotocol: %v", c.Subprotocol()) + c.Close(websocket.StatusNormalClosure, "") + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + read := func() error { + var v interface{} + err := wsjson.Read(ctx, c, &v) + if err != nil { + return err + } + + exp := map[string]interface{}{ + "anmol": "wowow", + } + if !reflect.DeepEqual(exp, v) { + return xerrors.Errorf("expected %v but got %v", exp, v) + } + return nil + } + err := read() + if err != nil { + return err } + err = read() + if err != nil { + return err + } + + c.Close(websocket.StatusNormalClosure, "") return nil }, }, { - name: "subprotocol", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{ - Subprotocols: []string{"echo", "lar"}, - }) + name: "protobufEcho", + server: func(ctx context.Context, c *websocket.Conn) error { + write := func() error { + err := wspb.Write(ctx, c, ptypes.DurationProto(100)) + return err + } + err := write() if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") - if c.Subprotocol() != "echo" { - return xerrors.Errorf("unexpected subprotocol: %q", c.Subprotocol()) + c.Close(websocket.StatusNormalClosure, "") + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + read := func() error { + var v duration.Duration + err := wspb.Read(ctx, c, &v) + if err != nil { + return err + } + + d, err := ptypes.Duration(&v) + if err != nil { + return xerrors.Errorf("failed to convert duration.Duration to time.Duration: %w", err) + } + const exp = time.Duration(100) + if !reflect.DeepEqual(exp, d) { + return xerrors.Errorf("expected %v but got %v", exp, d) + } + return nil + } + err := read() + if err != nil { + return err } + + c.Close(websocket.StatusNormalClosure, "") return nil }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ - Subprotocols: []string{"poof", "echo"}, - }) + }, + { + name: "ping", + server: func(ctx context.Context, c *websocket.Conn) error { + errc := make(chan error, 1) + go func() { + _, _, err2 := c.Read(ctx) + errc <- err2 + }() + + err := c.Ping(ctx) if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") - if c.Subprotocol() != "echo" { - return xerrors.Errorf("unexpected subprotocol: %q", c.Subprotocol()) + err = c.Write(ctx, websocket.MessageText, []byte("hi")) + if err != nil { + return err + } + + err = <-errc + var ce websocket.CloseError + if xerrors.As(err, &ce) && ce.Code == websocket.StatusNormalClosure { + return nil } + return xerrors.Errorf("unexpected error: %w", err) + }, + client: func(ctx context.Context, c *websocket.Conn) error { + // We read a message from the connection and then keep reading until + // the Ping completes. + done := make(chan struct{}) + go func() { + _, _, err := c.Read(ctx) + if err != nil { + c.Close(websocket.StatusInternalError, err.Error()) + return + } + + close(done) + + c.Read(ctx) + }() + + err := c.Ping(ctx) + if err != nil { + return err + } + + <-done + + c.Close(websocket.StatusNormalClosure, "") return nil }, }, { - name: "badOrigin", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + name: "readLimit", + server: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil || !strings.Contains(err.Error(), "read limited at") { + return xerrors.Errorf("expected error but got nil: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + c.CloseRead(ctx) + + err := c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769))) + if err != nil { + return err + } + + err = c.Ping(ctx) + + var ce websocket.CloseError + if !xerrors.As(err, &ce) || ce.Code != websocket.StatusMessageTooBig { + return xerrors.Errorf("unexpected error: %w", err) + } + + return nil + }, + }, + { + name: "wsjson/binary", + server: func(ctx context.Context, c *websocket.Conn) error { + var v interface{} + err := wsjson.Read(ctx, c, &v) + if err == nil || !strings.Contains(err.Error(), "unexpected frame type") { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + return wspb.Write(ctx, c, ptypes.DurationProto(100)) + }, + }, + { + name: "wsjson/badRead", + server: func(ctx context.Context, c *websocket.Conn) error { + var v interface{} + err := wsjson.Read(ctx, c, &v) + if err == nil || !strings.Contains(err.Error(), "failed to unmarshal json") { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + return c.Write(ctx, websocket.MessageText, []byte("notjson")) + }, + }, + { + name: "wsjson/badWrite", + server: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil || !strings.Contains(err.Error(), "StatusInternalError") { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + err := wsjson.Write(ctx, c, fmt.Println) if err == nil { - c.Close(websocket.StatusInternalError, "") - return xerrors.New("expected error regarding bad origin") + return xerrors.Errorf("expected error: %v", err) } return nil }, - client: func(ctx context.Context, u string) error { - h := http.Header{} - h.Set("Origin", "http://unauthorized.com") - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ - HTTPHeader: h, - }) + }, + { + name: "wspb/text", + server: func(ctx context.Context, c *websocket.Conn) error { + var v proto.Message + err := wspb.Read(ctx, c, v) + if err == nil || !strings.Contains(err.Error(), "unexpected frame type") { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + return wsjson.Write(ctx, c, "hi") + }, + }, + { + name: "wspb/badRead", + server: func(ctx context.Context, c *websocket.Conn) error { + var v timestamp.Timestamp + err := wspb.Read(ctx, c, &v) + if err == nil || !strings.Contains(err.Error(), "failed to unmarshal protobuf") { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + return c.Write(ctx, websocket.MessageBinary, []byte("notpb")) + }, + }, + { + name: "wspb/badWrite", + server: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil || !strings.Contains(err.Error(), "StatusInternalError") { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + err := wspb.Write(ctx, c, nil) if err == nil { - c.Close(websocket.StatusInternalError, "") - return xerrors.New("expected handshake failure") + return xerrors.Errorf("expected error: %v", err) } return nil }, }, { - name: "acceptSecureOrigin", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + name: "badClose", + server: func(ctx context.Context, c *websocket.Conn) error { + return c.Close(9999, "") + }, + client: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusInternalError { + return xerrors.Errorf("expected close error with StatusInternalError: %+v", err) + } + return nil + }, + }, + { + name: "pingTimeout", + server: func(ctx context.Context, c *websocket.Conn) error { + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + err := c.Ping(ctx) + if err == nil || !xerrors.Is(err, context.DeadlineExceeded) { + return xerrors.Errorf("expected nil error: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + c.Read(ctx) + return nil + }, + }, + { + name: "writeTimeout", + server: func(ctx context.Context, c *websocket.Conn) error { + c.Writer(ctx, websocket.MessageBinary) + + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + err := c.Write(ctx, websocket.MessageBinary, []byte("meow")) + if !xerrors.Is(err, context.DeadlineExceeded) { + return xerrors.Errorf("expected deadline exceeded error: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + time.Sleep(time.Second) + return nil + }, + }, + { + name: "readTimeout", + server: func(ctx context.Context, c *websocket.Conn) error { + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + _, _, err := c.Read(ctx) + if !xerrors.Is(err, context.DeadlineExceeded) { + return xerrors.Errorf("expected deadline exceeded error: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + c.Read(ctx) + return nil + }, + }, + { + name: "badOpCode", + server: func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, true, 13, []byte("meow")) if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") + _, _, err = c.Read(ctx) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError { + return xerrors.Errorf("expected close error with StatusProtocolError: %+v", err) + } return nil }, - client: func(ctx context.Context, u string) error { - h := http.Header{} - h.Set("Origin", u) - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ - HTTPHeader: h, - }) + client: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil || !strings.Contains(err.Error(), "opcode") { + return xerrors.Errorf("expected error that contains opcode: %+v", err) + } + return nil + }, + }, + { + name: "noRsv", + server: func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, true, 99, []byte("meow")) + if err != nil { + return err + } + _, _, err = c.Read(ctx) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError { + return xerrors.Errorf("expected close error with StatusProtocolError: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil || !strings.Contains(err.Error(), "rsv") { + return xerrors.Errorf("expected error that contains rsv: %+v", err) + } + return nil + }, + }, + { + name: "largeControlFrame", + server: func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, true, websocket.OPClose, []byte(strings.Repeat("x", 4096))) if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") + _, _, err = c.Read(ctx) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError { + return xerrors.Errorf("expected close error with StatusProtocolError: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil || !strings.Contains(err.Error(), "too large") { + return xerrors.Errorf("expected error that contains too large: %+v", err) + } + return nil + }, + }, + { + name: "fragmentedControlFrame", + server: func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, false, websocket.OPPing, []byte(strings.Repeat("x", 32))) + if err != nil { + return err + } + err = c.Flush() + if err != nil { + return err + } + _, _, err = c.Read(ctx) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError { + return xerrors.Errorf("expected close error with StatusProtocolError: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil || !strings.Contains(err.Error(), "fragmented") { + return xerrors.Errorf("expected error that contains fragmented: %+v", err) + } return nil }, }, { - name: "acceptInsecureOrigin", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{ - InsecureSkipVerify: true, - }) + name: "invalidClosePayload", + server: func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, true, websocket.OPClose, []byte{0x17, 0x70}) if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") + _, _, err = c.Read(ctx) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError { + return xerrors.Errorf("expected close error with StatusProtocolError: %+v", err) + } return nil }, - client: func(ctx context.Context, u string) error { - h := http.Header{} - h.Set("Origin", "https://example.com") - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ - HTTPHeader: h, - }) - if err != nil { - return err + client: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil || !strings.Contains(err.Error(), "invalid status code") { + return xerrors.Errorf("expected error that contains invalid status code: %+v", err) } - defer c.Close(websocket.StatusInternalError, "") return nil }, }, { - name: "jsonEcho", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + name: "doubleReader", + server: func(ctx context.Context, c *websocket.Conn) error { + _, r, err := c.Reader(ctx) if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(r.Context(), time.Second*5) - defer cancel() - - write := func() error { - v := map[string]interface{}{ - "anmol": "wowow", - } - err := wsjson.Write(ctx, c, v) - return err - } - err = write() + p := make([]byte, 10) + _, err = io.ReadFull(r, p) if err != nil { return err } - err = write() - if err != nil { - return err + _, _, err = c.Reader(ctx) + if err == nil || !strings.Contains(err.Error(), "previous message not read to completion") { + return xerrors.Errorf("expected non nil error: %v", err) } - - c.Close(websocket.StatusNormalClosure, "") return nil }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - read := func() error { - var v interface{} - err := wsjson.Read(ctx, c, &v) - if err != nil { - return err - } - - exp := map[string]interface{}{ - "anmol": "wowow", - } - if !reflect.DeepEqual(exp, v) { - return xerrors.Errorf("expected %v but got %v", exp, v) - } - return nil - } - err = read() + client: func(ctx context.Context, c *websocket.Conn) error { + err := c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 11))) if err != nil { return err } - err = read() - if err != nil { - return err + _, _, err = c.Read(ctx) + if err == nil { + return xerrors.Errorf("expected non nil error: %v", err) } - - c.Close(websocket.StatusNormalClosure, "") return nil }, }, { - name: "protobufEcho", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + name: "doubleFragmentedReader", + server: func(ctx context.Context, c *websocket.Conn) error { + _, r, err := c.Reader(ctx) if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(r.Context(), time.Second*5) - defer cancel() - - write := func() error { - err := wspb.Write(ctx, c, ptypes.DurationProto(100)) - return err - } - err = write() + p := make([]byte, 10) + _, err = io.ReadFull(r, p) if err != nil { return err } - - c.Close(websocket.StatusNormalClosure, "") + _, _, err = c.Reader(ctx) + if err == nil || !strings.Contains(err.Error(), "previous message not read to completion") { + return xerrors.Errorf("expected non nil error: %v", err) + } return nil }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{}) + client: func(ctx context.Context, c *websocket.Conn) error { + w, err := c.Writer(ctx, websocket.MessageBinary) if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") - - read := func() error { - var v duration.Duration - err := wspb.Read(ctx, c, &v) - if err != nil { - return err - } - - d, err := ptypes.Duration(&v) - if err != nil { - return xerrors.Errorf("failed to convert duration.Duration to time.Duration: %w", err) - } - const exp = time.Duration(100) - if !reflect.DeepEqual(exp, d) { - return xerrors.Errorf("expected %v but got %v", exp, d) - } - return nil + _, err = w.Write([]byte(strings.Repeat("x", 10))) + if err != nil { + return xerrors.Errorf("expected non nil error") } - err = read() + err = c.Flush() if err != nil { - return err + return xerrors.Errorf("failed to flush: %w", err) + } + _, err = w.Write([]byte(strings.Repeat("x", 10))) + if err != nil { + return xerrors.Errorf("expected non nil error") + } + err = c.Flush() + if err != nil { + return xerrors.Errorf("failed to flush: %w", err) + } + _, _, err = c.Read(ctx) + if err == nil { + return xerrors.Errorf("expected non nil error: %v", err) } - - c.Close(websocket.StatusNormalClosure, "") return nil }, }, { - name: "cookies", - server: func(w http.ResponseWriter, r *http.Request) error { - cookie, err := r.Cookie("mycookie") + name: "newMessageInFragmentedMessage", + server: func(ctx context.Context, c *websocket.Conn) error { + _, r, err := c.Reader(ctx) if err != nil { - return xerrors.Errorf("request is missing mycookie: %w", err) - } - if cookie.Value != "myvalue" { - return xerrors.Errorf("expected %q but got %q", "myvalue", cookie.Value) + return err } - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + p := make([]byte, 10) + _, err = io.ReadFull(r, p) if err != nil { return err } - c.Close(websocket.StatusInternalError, "") + _, _, err = c.Reader(ctx) + if err == nil || !strings.Contains(err.Error(), "received new data message without finishing") { + return xerrors.Errorf("expected non nil error: %v", err) + } return nil }, - client: func(ctx context.Context, u string) error { - jar, err := cookiejar.New(nil) + client: func(ctx context.Context, c *websocket.Conn) error { + w, err := c.Writer(ctx, websocket.MessageBinary) if err != nil { - return xerrors.Errorf("failed to create cookie jar: %w", err) + return err } - parsedURL, err := url.Parse(u) + _, err = w.Write([]byte(strings.Repeat("x", 10))) if err != nil { - return xerrors.Errorf("failed to parse url: %w", err) + return xerrors.Errorf("expected non nil error") } - parsedURL.Scheme = "http" - jar.SetCookies(parsedURL, []*http.Cookie{ - { - Name: "mycookie", - Value: "myvalue", - }, - }) - hc := &http.Client{ - Jar: jar, + err = c.Flush() + if err != nil { + return xerrors.Errorf("failed to flush: %w", err) } - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ - HTTPClient: hc, - }) + _, err = c.WriteFrame(ctx, true, websocket.OPBinary, []byte(strings.Repeat("x", 10))) if err != nil { - return err + return xerrors.Errorf("expected non nil error") + } + _, _, err = c.Read(ctx) + if err == nil || !strings.Contains(err.Error(), "received new data message without finishing") { + return xerrors.Errorf("expected non nil error: %v", err) } - c.Close(websocket.StatusInternalError, "") return nil }, }, { - name: "ping", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + name: "continuationFrameWithoutDataFrame", + server: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Reader(ctx) + if err == nil || !strings.Contains(err.Error(), "received continuation frame not after data") { + return xerrors.Errorf("expected non nil error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, false, websocket.OPContinuation, []byte(strings.Repeat("x", 10))) + if err != nil { + return xerrors.Errorf("expected non nil error") + } + return nil + }, + }, + { + name: "readBeforeEOF", + server: func(ctx context.Context, c *websocket.Conn) error { + _, r, err := c.Reader(ctx) if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") - - errc := make(chan error, 1) - go func() { - _, _, err2 := c.Read(r.Context()) - errc <- err2 - }() - - err = c.Ping(r.Context()) + var v interface{} + d := json.NewDecoder(r) + err = d.Decode(&v) if err != nil { return err } - - err = c.Write(r.Context(), websocket.MessageText, []byte("hi")) + _, b, err := c.Read(ctx) if err != nil { return err } - - err = <-errc - var ce websocket.CloseError - if xerrors.As(err, &ce) && ce.Code == websocket.StatusNormalClosure { - return nil + if string(b) != "hi" { + return xerrors.Errorf("expected hi but got %q", string(b)) } - return xerrors.Errorf("unexpected error: %w", err) + return nil }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{}) + client: func(ctx context.Context, c *websocket.Conn) error { + err := wsjson.Write(ctx, c, "hi") if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") - - // We read a message from the connection and then keep reading until - // the Ping completes. - done := make(chan struct{}) - go func() { - _, _, err := c.Read(ctx) - if err != nil { - c.Close(websocket.StatusInternalError, err.Error()) - return - } - - close(done) - - c.Read(ctx) - }() - - err = c.Ping(ctx) + return c.Write(ctx, websocket.MessageBinary, []byte("hi")) + }, + }, + { + name: "newMessageInFragmentedMessage2", + server: func(ctx context.Context, c *websocket.Conn) error { + _, r, err := c.Reader(ctx) if err != nil { return err } - - <-done - - c.Close(websocket.StatusNormalClosure, "") + p := make([]byte, 11) + _, err = io.ReadFull(r, p) + if err == nil || !strings.Contains(err.Error(), "received new data message without finishing") { + return xerrors.Errorf("expected non nil error: %v", err) + } return nil }, - }, - { - name: "readLimit", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + client: func(ctx context.Context, c *websocket.Conn) error { + w, err := c.Writer(ctx, websocket.MessageBinary) if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") - - _, _, err = c.Read(r.Context()) + _, err = w.Write([]byte(strings.Repeat("x", 10))) + if err != nil { + return xerrors.Errorf("expected non nil error") + } + err = c.Flush() + if err != nil { + return xerrors.Errorf("failed to flush: %w", err) + } + _, err = c.WriteFrame(ctx, true, websocket.OPBinary, []byte(strings.Repeat("x", 10))) + if err != nil { + return xerrors.Errorf("expected non nil error") + } + _, _, err = c.Read(ctx) if err == nil { - return xerrors.Errorf("expected error but got nil") + return xerrors.Errorf("expected non nil error: %v", err) } return nil }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{}) + }, + { + name: "doubleRead", + server: func(ctx context.Context, c *websocket.Conn) error { + _, r, err := c.Reader(ctx) if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") - - go c.Reader(ctx) - - err = c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769))) + _, err = ioutil.ReadAll(r) if err != nil { return err } - - err = c.Ping(ctx) - - var ce websocket.CloseError - if !xerrors.As(err, &ce) || ce.Code != websocket.StatusMessageTooBig { - return xerrors.Errorf("unexpected error: %w", err) + _, err = r.Read(make([]byte, 1)) + if err == nil || !strings.Contains(err.Error(), "cannot use EOFed reader") { + return xerrors.Errorf("expected non nil error: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + return c.Write(ctx, websocket.MessageBinary, []byte("hi")) + }, + }, + { + name: "eofInPayload", + server: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil || !strings.Contains(err.Error(), "failed to read frame payload") { + return xerrors.Errorf("expected failed to read frame payload: %v", err) } - return nil }, + client: func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteHalfFrame(ctx) + return err + }, }, } - for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { - err := tc.server(w, r) + // Run random tests over TLS. + tls := rand.Intn(2) == 1 + + s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, nil) if err != nil { - t.Errorf("server failed: %+v", err) - return + return err } - }) + defer c.Close(websocket.StatusInternalError, "") + return tc.server(r.Context(), c) + }, tls) defer closeFn() wsURL := strings.Replace(s.URL, "http", "ws", 1) @@ -621,7 +1145,18 @@ func TestHandshake(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - err := tc.client(ctx, wsURL) + opts := &websocket.DialOptions{} + if tls { + opts.HTTPClient = s.Client() + } + + c, _, err := websocket.Dial(ctx, wsURL, opts) + if err != nil { + t.Fatal(err) + } + defer c.Close(websocket.StatusInternalError, "") + + err = tc.client(ctx, c) if err != nil { t.Fatalf("client failed: %+v", err) } @@ -629,14 +1164,31 @@ func TestHandshake(t *testing.T) { } } -func testServer(tb testing.TB, fn http.HandlerFunc) (s *httptest.Server, closeFn func()) { +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) error, tls bool) (s *httptest.Server, closeFn func()) { var conns int64 - s = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { atomic.AddInt64(&conns, 1) defer atomic.AddInt64(&conns, -1) - fn.ServeHTTP(w, r) - })) + ctx, cancel := context.WithTimeout(r.Context(), time.Second*30) + defer cancel() + + r = r.WithContext(ctx) + + err := fn(w, r) + if err != nil { + tb.Errorf("server failed: %+v", err) + } + }) + if tls { + s = httptest.NewTLSServer(h) + } else { + s = httptest.NewServer(h) + } return s, func() { s.Close() @@ -654,9 +1206,12 @@ func testServer(tb testing.TB, fn http.HandlerFunc) (s *httptest.Server, closeFn // https://github.com/crossbario/autobahn-python/tree/master/wstest func TestAutobahnServer(t *testing.T) { t.Parallel() + if os.Getenv("AUTOBAHN") == "" { + t.Skip("Set $AUTOBAHN to run the autobahn test suite.") + } s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{ + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, }) if err != nil { @@ -794,6 +1349,9 @@ func unusedListenAddr() (string, error) { // https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py func TestAutobahnClient(t *testing.T) { t.Parallel() + if os.Getenv("AUTOBAHN") == "" { + t.Skip("Set $AUTOBAHN to run the autobahn test suite.") + } serverAddr, err := unusedListenAddr() if err != nil { @@ -853,7 +1411,7 @@ func TestAutobahnClient(t *testing.T) { var cases int func() { - c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", websocket.DialOptions{}) + c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", nil) if err != nil { t.Fatal(err) } @@ -880,7 +1438,7 @@ func TestAutobahnClient(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, time.Second*45) defer cancel() - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), websocket.DialOptions{}) + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), nil) if err != nil { t.Fatal(err) } @@ -888,7 +1446,7 @@ func TestAutobahnClient(t *testing.T) { }() } - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), websocket.DialOptions{}) + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), nil) if err != nil { t.Fatal(err) } @@ -939,18 +1497,18 @@ func checkWSTestIndex(t *testing.T, path string) { } func benchConn(b *testing.B, echo, stream bool, size int) { - s, closeFn := testServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + s, closeFn := testServer(b, func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, nil) if err != nil { - b.Logf("server handshake failed: %+v", err) - return + return err } if echo { echoLoop(r.Context(), c) } else { discardLoop(r.Context(), c) } - })) + return nil + }, false) defer closeFn() wsURL := strings.Replace(s.URL, "http", "ws", 1) @@ -958,7 +1516,7 @@ func benchConn(b *testing.B, echo, stream bool, size int) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() - c, _, err := websocket.Dial(ctx, wsURL, websocket.DialOptions{}) + c, _, err := websocket.Dial(ctx, wsURL, nil) if err != nil { b.Fatal(err) }