From 97172f3339a9bef16fa82fde84b4b0c7a1357e56 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Tue, 25 Feb 2020 23:59:57 -0500 Subject: [PATCH 1/4] Add Grace to gracefully close WebSocket connections Closes #199 --- accept.go | 20 ++++++- conn_notjs.go | 5 ++ conn_test.go | 12 ++--- example_echo_test.go | 6 ++- example_test.go | 46 ++++++++++++++++ grace.go | 123 +++++++++++++++++++++++++++++++++++++++++++ ws_js.go | 2 + 7 files changed, 202 insertions(+), 12 deletions(-) create mode 100644 grace.go diff --git a/accept.go b/accept.go index 47e20b52..52a93459 100644 --- a/accept.go +++ b/accept.go @@ -75,6 +75,13 @@ func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) { defer errd.Wrap(&err, "failed to accept WebSocket connection") + g := graceFromRequest(r) + if g != nil && g.isClosing() { + err := errors.New("server closing") + http.Error(w, err.Error(), http.StatusServiceUnavailable) + return nil, err + } + if opts == nil { opts = &AcceptOptions{} } @@ -134,7 +141,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con b, _ := brw.Reader.Peek(brw.Reader.Buffered()) brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) - return newConn(connConfig{ + c := newConn(connConfig{ subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), rwc: netConn, client: false, @@ -143,7 +150,16 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con br: brw.Reader, bw: brw.Writer, - }), nil + }) + + if g != nil { + err = g.addConn(c) + if err != nil { + return nil, err + } + } + + return c, nil } func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { diff --git a/conn_notjs.go b/conn_notjs.go index bb2eb22f..f604898e 100644 --- a/conn_notjs.go +++ b/conn_notjs.go @@ -33,6 +33,7 @@ type Conn struct { flateThreshold int br *bufio.Reader bw *bufio.Writer + g *Grace readTimeout chan context.Context writeTimeout chan context.Context @@ -138,6 +139,10 @@ func (c *Conn) close(err error) { // closeErr. c.rwc.Close() + if c.g != nil { + c.g.delConn(c) + } + go func() { c.msgWriterState.close() diff --git a/conn_test.go b/conn_test.go index 28da3c07..af4fa4c0 100644 --- a/conn_test.go +++ b/conn_test.go @@ -13,7 +13,6 @@ import ( "os" "os/exec" "strings" - "sync" "testing" "time" @@ -272,11 +271,9 @@ func TestWasm(t *testing.T) { t.Skip("skipping on CI") } - var wg sync.WaitGroup - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - wg.Add(1) - defer wg.Done() - + var g websocket.Grace + defer g.Close() + s := httptest.NewServer(g.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, InsecureSkipVerify: true, @@ -294,8 +291,7 @@ func TestWasm(t *testing.T) { t.Errorf("echo server failed: %v", err) return } - })) - defer wg.Wait() + }))) defer s.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Minute) diff --git a/example_echo_test.go b/example_echo_test.go index cd195d2e..0c0b84ea 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -31,13 +31,15 @@ func Example_echo() { } defer l.Close() + var g websocket.Grace + defer g.Close() s := &http.Server{ - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + Handler: g.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := echoServer(w, r) if err != nil { log.Printf("echo server: %v", err) } - }), + })), ReadTimeout: time.Second * 15, WriteTimeout: time.Second * 15, } diff --git a/example_test.go b/example_test.go index c56e53f3..ce049bc3 100644 --- a/example_test.go +++ b/example_test.go @@ -6,6 +6,8 @@ import ( "context" "log" "net/http" + "os" + "os/signal" "time" "nhooyr.io/websocket" @@ -133,3 +135,47 @@ func Example_crossOrigin() { err := http.ListenAndServe("localhost:8080", fn) log.Fatal(err) } + +// This example demonstrates how to create a WebSocket server +// that gracefully exits when sent a signal. +// +// It starts a WebSocket server that keeps every connection open +// for 10 seconds. +// If you CTRL+C while a connection is open, it will wait at most 30s +// for all connections to terminate before shutting down. +func ExampleGrace() { + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, nil) + if err != nil { + log.Println(err) + return + } + defer c.Close(websocket.StatusInternalError, "the sky is falling") + + ctx := c.CloseRead(r.Context()) + select { + case <-ctx.Done(): + case <-time.After(time.Second * 10): + } + + c.Close(websocket.StatusNormalClosure, "") + }) + + var g websocket.Grace + s := &http.Server{ + Handler: g.Handler(fn), + ReadTimeout: time.Second * 15, + WriteTimeout: time.Second * 15, + } + go s.ListenAndServe() + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, os.Interrupt) + sig := <-sigs + log.Printf("recieved %v, shutting down", sig) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + s.Shutdown(ctx) + g.Shutdown(ctx) +} diff --git a/grace.go b/grace.go new file mode 100644 index 00000000..8dadc43d --- /dev/null +++ b/grace.go @@ -0,0 +1,123 @@ +package websocket + +import ( + "context" + "errors" + "fmt" + "net/http" + "sync" + "time" +) + +// Grace enables graceful shutdown of accepted WebSocket connections. +// +// Use Handler to wrap WebSocket handlers to record accepted connections +// and then use Close or Shutdown to gracefully close these connections. +// +// Grace is intended to be used in harmony with net/http.Server's Shutdown and Close methods. +type Grace struct { + mu sync.Mutex + closing bool + conns map[*Conn]struct{} +} + +// Handler returns a handler that wraps around h to record +// all WebSocket connections accepted. +// +// Use Close or Shutdown to gracefully close recorded connections. +func (g *Grace) Handler(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), gracefulContextKey{}, g) + r = r.WithContext(ctx) + h.ServeHTTP(w, r) + }) +} + +func (g *Grace) isClosing() bool { + g.mu.Lock() + defer g.mu.Unlock() + return g.closing +} + +func graceFromRequest(r *http.Request) *Grace { + g, _ := r.Context().Value(gracefulContextKey{}).(*Grace) + return g +} + +func (g *Grace) addConn(c *Conn) error { + g.mu.Lock() + defer g.mu.Unlock() + if g.closing { + c.Close(StatusGoingAway, "server shutting down") + return errors.New("server shutting down") + } + if g.conns == nil { + g.conns = make(map[*Conn]struct{}) + } + g.conns[c] = struct{}{} + c.g = g + return nil +} + +func (g *Grace) delConn(c *Conn) { + g.mu.Lock() + defer g.mu.Unlock() + delete(g.conns, c) +} + +type gracefulContextKey struct{} + +// Close prevents the acceptance of new connections with +// http.StatusServiceUnavailable and closes all accepted +// connections with StatusGoingAway. +func (g *Grace) Close() error { + g.mu.Lock() + g.closing = true + var wg sync.WaitGroup + for c := range g.conns { + wg.Add(1) + go func(c *Conn) { + defer wg.Done() + c.Close(StatusGoingAway, "server shutting down") + }(c) + + delete(g.conns, c) + } + g.mu.Unlock() + + wg.Wait() + + return nil +} + +// Shutdown prevents the acceptance of new connections and waits until +// all connections close. If the context is cancelled before that, it +// calls Close to close all connections immediately. +func (g *Grace) Shutdown(ctx context.Context) error { + defer g.Close() + + g.mu.Lock() + g.closing = true + g.mu.Unlock() + + // Same poll period used by net/http. + t := time.NewTicker(500 * time.Millisecond) + defer t.Stop() + for { + if g.zeroConns() { + return nil + } + + select { + case <-t.C: + case <-ctx.Done(): + return fmt.Errorf("failed to shutdown WebSockets: %w", ctx.Err()) + } + } +} + +func (g *Grace) zeroConns() bool { + g.mu.Lock() + defer g.mu.Unlock() + return len(g.conns) == 0 +} diff --git a/ws_js.go b/ws_js.go index 2b560ce8..a8c8b771 100644 --- a/ws_js.go +++ b/ws_js.go @@ -38,6 +38,8 @@ type Conn struct { readSignal chan struct{} readBufMu sync.Mutex readBuf []wsjs.MessageEvent + + g *Grace } func (c *Conn) close(err error, wasClean bool) { From e335b09210e47739545fe30c69f3a0f56ede98a0 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Wed, 26 Feb 2020 14:47:40 -0500 Subject: [PATCH 2/4] Use grace in chat example --- accept.go | 4 ++-- chat-example/go.mod | 4 +++- chat-example/go.sum | 10 ++++++++-- chat-example/index.css | 2 +- chat-example/index.js | 13 ++++++++++--- chat-example/main.go | 29 +++++++++++++++++++++++++++-- example_test.go | 14 +++++++++++--- grace.go | 20 ++++++++++++-------- 8 files changed, 74 insertions(+), 22 deletions(-) diff --git a/accept.go b/accept.go index 52a93459..dd96c9bd 100644 --- a/accept.go +++ b/accept.go @@ -76,8 +76,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con defer errd.Wrap(&err, "failed to accept WebSocket connection") g := graceFromRequest(r) - if g != nil && g.isClosing() { - err := errors.New("server closing") + if g != nil && g.isShuttingdown() { + err := errors.New("server shutting down") http.Error(w, err.Error(), http.StatusServiceUnavailable) return nil, err } diff --git a/chat-example/go.mod b/chat-example/go.mod index 34fa5a69..c47a5a2f 100644 --- a/chat-example/go.mod +++ b/chat-example/go.mod @@ -2,4 +2,6 @@ module nhooyr.io/websocket/example-chat go 1.13 -require nhooyr.io/websocket v1.8.2 +require nhooyr.io/websocket v0.0.0 + +replace nhooyr.io/websocket => ../ diff --git a/chat-example/go.sum b/chat-example/go.sum index 0755fca5..e4bbd62d 100644 --- a/chat-example/go.sum +++ b/chat-example/go.sum @@ -1,12 +1,18 @@ +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= +github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= +github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/klauspost/compress v1.10.0 h1:92XGj1AcYzA6UrVdd4qIIBrT8OroryvRvdmg/IfmC7Y= github.com/klauspost/compress v1.10.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -nhooyr.io/websocket v1.8.2 h1:LwdzfyyOZKtVFoXay6A39Acu03KmidSZ3YUUvPa13PA= -nhooyr.io/websocket v1.8.2/go.mod h1:LiqdCg1Cu7TPWxEvPjPa0TGYxCsy4pHNTN9gGluwBpQ= diff --git a/chat-example/index.css b/chat-example/index.css index 29804662..73a8e0f3 100644 --- a/chat-example/index.css +++ b/chat-example/index.css @@ -5,7 +5,7 @@ body { #root { padding: 40px 20px; - max-width: 480px; + max-width: 600px; margin: auto; height: 100vh; diff --git a/chat-example/index.js b/chat-example/index.js index 8fb3dfb8..a42c2d30 100644 --- a/chat-example/index.js +++ b/chat-example/index.js @@ -7,8 +7,11 @@ const conn = new WebSocket(`ws://${location.host}/subscribe`) conn.addEventListener("close", ev => { - console.info("websocket disconnected, reconnecting in 1000ms", ev) - setTimeout(dial, 1000) + appendLog(`WebSocket Disconnected code: ${ev.code}, reason: ${ev.reason}`, true) + if (ev.code !== 1001) { + appendLog("Reconnecting in 1s", true) + setTimeout(dial, 1000) + } }) conn.addEventListener("open", ev => { console.info("websocket connected") @@ -34,10 +37,14 @@ const messageInput = document.getElementById("message-input") // appendLog appends the passed text to messageLog. - function appendLog(text) { + function appendLog(text, error) { const p = document.createElement("p") // Adding a timestamp to each message makes the log easier to read. p.innerText = `${new Date().toLocaleTimeString()}: ${text}` + if (error) { + p.style.color = "red" + p.style.fontStyle = "bold" + } messageLog.append(p) return p } diff --git a/chat-example/main.go b/chat-example/main.go index 2a520924..f985d382 100644 --- a/chat-example/main.go +++ b/chat-example/main.go @@ -1,12 +1,16 @@ package main import ( + "context" "errors" "log" "net" "net/http" "os" + "os/signal" "time" + + "nhooyr.io/websocket" ) func main() { @@ -38,10 +42,31 @@ func run() error { m.HandleFunc("/subscribe", ws.subscribeHandler) m.HandleFunc("/publish", ws.publishHandler) + var g websocket.Grace s := http.Server{ - Handler: m, + Handler: g.Handler(m), ReadTimeout: time.Second * 10, WriteTimeout: time.Second * 10, } - return s.Serve(l) + errc := make(chan error, 1) + go func() { + errc <- s.Serve(l) + }() + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, os.Interrupt) + select { + case err := <-errc: + log.Printf("failed to serve: %v", err) + case sig := <-sigs: + log.Printf("terminating: %v", sig) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + s.Shutdown(ctx) + g.Shutdown(ctx) + + return nil } diff --git a/example_test.go b/example_test.go index ce049bc3..462de376 100644 --- a/example_test.go +++ b/example_test.go @@ -167,12 +167,20 @@ func ExampleGrace() { ReadTimeout: time.Second * 15, WriteTimeout: time.Second * 15, } - go s.ListenAndServe() + + errc := make(chan error, 1) + go func() { + errc <- s.ListenAndServe() + }() sigs := make(chan os.Signal, 1) signal.Notify(sigs, os.Interrupt) - sig := <-sigs - log.Printf("recieved %v, shutting down", sig) + select { + case err := <-errc: + log.Printf("failed to listen and serve: %v", err) + case sig := <-sigs: + log.Printf("terminating: %v", sig) + } ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() diff --git a/grace.go b/grace.go index 8dadc43d..c53cd40b 100644 --- a/grace.go +++ b/grace.go @@ -15,10 +15,13 @@ import ( // and then use Close or Shutdown to gracefully close these connections. // // Grace is intended to be used in harmony with net/http.Server's Shutdown and Close methods. +// It's required as net/http's Shutdown and Close methods do not keep track of WebSocket +// connections. type Grace struct { - mu sync.Mutex - closing bool - conns map[*Conn]struct{} + mu sync.Mutex + closed bool + shuttingDown bool + conns map[*Conn]struct{} } // Handler returns a handler that wraps around h to record @@ -33,10 +36,10 @@ func (g *Grace) Handler(h http.Handler) http.Handler { }) } -func (g *Grace) isClosing() bool { +func (g *Grace) isShuttingdown() bool { g.mu.Lock() defer g.mu.Unlock() - return g.closing + return g.shuttingDown } func graceFromRequest(r *http.Request) *Grace { @@ -47,7 +50,7 @@ func graceFromRequest(r *http.Request) *Grace { func (g *Grace) addConn(c *Conn) error { g.mu.Lock() defer g.mu.Unlock() - if g.closing { + if g.closed { c.Close(StatusGoingAway, "server shutting down") return errors.New("server shutting down") } @@ -72,7 +75,8 @@ type gracefulContextKey struct{} // connections with StatusGoingAway. func (g *Grace) Close() error { g.mu.Lock() - g.closing = true + g.shuttingDown = true + g.closed = true var wg sync.WaitGroup for c := range g.conns { wg.Add(1) @@ -97,7 +101,7 @@ func (g *Grace) Shutdown(ctx context.Context) error { defer g.Close() g.mu.Lock() - g.closing = true + g.shuttingDown = true g.mu.Unlock() // Same poll period used by net/http. From 190981dcf7f6af74049e8c6eab9dd500b0a9a47f Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Wed, 26 Feb 2020 15:31:29 -0500 Subject: [PATCH 3/4] Add automated test to chat example --- chat-example/chat.go | 61 ++++++++++++----- chat-example/chat_test.go | 137 ++++++++++++++++++++++++++++++++++++++ chat-example/go.mod | 7 -- chat-example/main.go | 10 +-- 4 files changed, 183 insertions(+), 32 deletions(-) create mode 100644 chat-example/chat_test.go delete mode 100644 chat-example/go.mod diff --git a/chat-example/chat.go b/chat-example/chat.go index e6e355d0..9b264195 100644 --- a/chat-example/chat.go +++ b/chat-example/chat.go @@ -15,8 +15,28 @@ import ( // chatServer enables broadcasting to a set of subscribers. type chatServer struct { + registerOnce sync.Once + m http.ServeMux + subscribersMu sync.RWMutex - subscribers map[chan<- []byte]struct{} + subscribers map[*subscriber]struct{} +} + +// subscriber represents a subscriber. +// Messages are sent on the msgs channel and if the client +// cannot keep up with the messages, closeSlow is called. +type subscriber struct { + msgs chan []byte + closeSlow func() +} + +func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + cs.registerOnce.Do(func() { + cs.m.Handle("/", http.FileServer(http.Dir("."))) + cs.m.HandleFunc("/subscribe", cs.subscribeHandler) + cs.m.HandleFunc("/publish", cs.publishHandler) + }) + cs.m.ServeHTTP(w, r) } // subscribeHandler accepts the WebSocket connection and then subscribes @@ -57,11 +77,13 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { } cs.publish(msg) + + w.WriteHeader(http.StatusAccepted) } // subscribe subscribes the given WebSocket to all broadcast messages. -// It creates a msgs chan with a buffer of 16 to give some room to slower -// connections and then registers it. It then listens for all messages +// It creates a subscriber with a buffered msgs chan to give some room to slower +// connections and then registers the subscriber. It then listens for all messages // and writes them to the WebSocket. If the context is cancelled or // an error occurs, it returns and deletes the subscription. // @@ -70,13 +92,18 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error { ctx = c.CloseRead(ctx) - msgs := make(chan []byte, 16) - cs.addSubscriber(msgs) - defer cs.deleteSubscriber(msgs) + s := &subscriber{ + msgs: make(chan []byte, 16), + closeSlow: func() { + c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") + }, + } + cs.addSubscriber(s) + defer cs.deleteSubscriber(s) for { select { - case msg := <-msgs: + case msg := <-s.msgs: err := writeTimeout(ctx, time.Second*5, c, msg) if err != nil { return err @@ -94,29 +121,29 @@ func (cs *chatServer) publish(msg []byte) { cs.subscribersMu.RLock() defer cs.subscribersMu.RUnlock() - for c := range cs.subscribers { + for s := range cs.subscribers { select { - case c <- msg: + case s.msgs <- msg: default: + go s.closeSlow() } } } -// addSubscriber registers a subscriber with a channel -// on which to send messages. -func (cs *chatServer) addSubscriber(msgs chan<- []byte) { +// addSubscriber registers a subscriber. +func (cs *chatServer) addSubscriber(s *subscriber) { cs.subscribersMu.Lock() if cs.subscribers == nil { - cs.subscribers = make(map[chan<- []byte]struct{}) + cs.subscribers = make(map[*subscriber]struct{}) } - cs.subscribers[msgs] = struct{}{} + cs.subscribers[s] = struct{}{} cs.subscribersMu.Unlock() } -// deleteSubscriber deletes the subscriber with the given msgs channel. -func (cs *chatServer) deleteSubscriber(msgs chan []byte) { +// deleteSubscriber deletes the given subscriber. +func (cs *chatServer) deleteSubscriber(s *subscriber) { cs.subscribersMu.Lock() - delete(cs.subscribers, msgs) + delete(cs.subscribers, s) cs.subscribersMu.Unlock() } diff --git a/chat-example/chat_test.go b/chat-example/chat_test.go new file mode 100644 index 00000000..d1772381 --- /dev/null +++ b/chat-example/chat_test.go @@ -0,0 +1,137 @@ +// +build !js + +package main + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "nhooyr.io/websocket" +) + +func TestGrace(t *testing.T) { + t.Parallel() + + var cs chatServer + var g websocket.Grace + s := httptest.NewServer(g.Handler(&cs)) + defer s.Close() + defer g.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + cl1, err := newClient(ctx, s.URL) + assertSuccess(t, err) + defer cl1.Close() + + cl2, err := newClient(ctx, s.URL) + assertSuccess(t, err) + defer cl2.Close() + + err = cl1.publish(ctx, "hello") + assertSuccess(t, err) + + assertReceivedMessage(ctx, cl1, "hello") + assertReceivedMessage(ctx, cl2, "hello") +} + +type client struct { + msgs chan string + url string + c *websocket.Conn +} + +func newClient(ctx context.Context, url string) (*client, error) { + wsURL := strings.ReplaceAll(url, "http://", "ws://") + c, _, err := websocket.Dial(ctx, wsURL+"/subscribe", nil) + if err != nil { + return nil, err + } + + cl := &client{ + msgs: make(chan string, 16), + url: url, + c: c, + } + go cl.readLoop() + + return cl, nil +} + +func (cl *client) readLoop() { + defer cl.c.Close(websocket.StatusInternalError, "") + defer close(cl.msgs) + + for { + typ, b, err := cl.c.Read(context.Background()) + if err != nil { + return + } + + if typ != websocket.MessageText { + cl.c.Close(websocket.StatusUnsupportedData, "expected text message") + return + } + + select { + case cl.msgs <- string(b): + default: + cl.c.Close(websocket.StatusInternalError, "messages coming in too fast to handle") + return + } + } +} + +func (cl *client) receive(ctx context.Context) (string, error) { + select { + case msg, ok := <-cl.msgs: + if !ok { + return "", errors.New("client closed") + } + return msg, nil + case <-ctx.Done(): + return "", ctx.Err() + } +} + +func (cl *client) publish(ctx context.Context, msg string) error { + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, cl.url+"/publish", strings.NewReader(msg)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusAccepted { + return fmt.Errorf("publish request failed: %v", resp.StatusCode) + } + return nil +} + +func (cl *client) Close() error { + return cl.c.Close(websocket.StatusNormalClosure, "") +} + +func assertSuccess(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } +} + +func assertReceivedMessage(ctx context.Context, cl *client, msg string) error { + msg, err := cl.receive(ctx) + if err != nil { + return err + } + if msg != "hello" { + return fmt.Errorf("expected hello but got %q", msg) + } + return nil +} diff --git a/chat-example/go.mod b/chat-example/go.mod deleted file mode 100644 index c47a5a2f..00000000 --- a/chat-example/go.mod +++ /dev/null @@ -1,7 +0,0 @@ -module nhooyr.io/websocket/example-chat - -go 1.13 - -require nhooyr.io/websocket v0.0.0 - -replace nhooyr.io/websocket => ../ diff --git a/chat-example/main.go b/chat-example/main.go index f985d382..a265f60c 100644 --- a/chat-example/main.go +++ b/chat-example/main.go @@ -35,16 +35,10 @@ func run() error { } log.Printf("listening on http://%v", l.Addr()) - var ws chatServer - - m := http.NewServeMux() - m.Handle("/", http.FileServer(http.Dir("."))) - m.HandleFunc("/subscribe", ws.subscribeHandler) - m.HandleFunc("/publish", ws.publishHandler) - + var cs chatServer var g websocket.Grace s := http.Server{ - Handler: g.Handler(m), + Handler: g.Handler(&cs), ReadTimeout: time.Second * 10, WriteTimeout: time.Second * 10, } From da3aa8cfcc08909ea3cd41153637e5c1697bac59 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Wed, 26 Feb 2020 20:39:59 -0500 Subject: [PATCH 4/4] Improve chat example test --- chat-example/README.md | 6 + chat-example/chat.go | 67 ++++++--- chat-example/chat_test.go | 285 ++++++++++++++++++++++++++++---------- chat-example/index.js | 17 ++- chat-example/main.go | 4 +- ci/test.mk | 3 +- 6 files changed, 284 insertions(+), 98 deletions(-) diff --git a/chat-example/README.md b/chat-example/README.md index ef06275d..a4c99a93 100644 --- a/chat-example/README.md +++ b/chat-example/README.md @@ -25,3 +25,9 @@ assets, the `/subscribe` WebSocket endpoint and the HTTP POST `/publish` endpoin The code is well commented. I would recommend starting in `main.go` and then `chat.go` followed by `index.html` and then `index.js`. + +There are two automated tests for the server included in `chat_test.go`. The first is a simple one +client echo test. It publishes a single message and ensures it's received. + +The second is a complex concurrency test where 10 clients send 128 unique messages +of max 128 bytes concurrently. The test ensures all messages are seen by every client. diff --git a/chat-example/chat.go b/chat-example/chat.go index 9b264195..532e50f5 100644 --- a/chat-example/chat.go +++ b/chat-example/chat.go @@ -3,25 +3,57 @@ package main import ( "context" "errors" - "io" "io/ioutil" "log" "net/http" "sync" "time" + "golang.org/x/time/rate" + "nhooyr.io/websocket" ) // chatServer enables broadcasting to a set of subscribers. type chatServer struct { - registerOnce sync.Once - m http.ServeMux - - subscribersMu sync.RWMutex + // subscriberMessageBuffer controls the max number + // of messages that can be queued for a subscriber + // before it is kicked. + // + // Defaults to 16. + subscriberMessageBuffer int + + // publishLimiter controls the rate limit applied to the publish endpoint. + // + // Defaults to one publish every 100ms with a burst of 8. + publishLimiter *rate.Limiter + + // logf controls where logs are sent. + // Defaults to log.Printf. + logf func(f string, v ...interface{}) + + // serveMux routes the various endpoints to the appropriate handler. + serveMux http.ServeMux + + subscribersMu sync.Mutex subscribers map[*subscriber]struct{} } +// newChatServer constructs a chatServer with the defaults. +func newChatServer() *chatServer { + cs := &chatServer{ + subscriberMessageBuffer: 16, + logf: log.Printf, + subscribers: make(map[*subscriber]struct{}), + publishLimiter: rate.NewLimiter(rate.Every(time.Millisecond*100), 8), + } + cs.serveMux.Handle("/", http.FileServer(http.Dir("."))) + cs.serveMux.HandleFunc("/subscribe", cs.subscribeHandler) + cs.serveMux.HandleFunc("/publish", cs.publishHandler) + + return cs +} + // subscriber represents a subscriber. // Messages are sent on the msgs channel and if the client // cannot keep up with the messages, closeSlow is called. @@ -31,12 +63,7 @@ type subscriber struct { } func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - cs.registerOnce.Do(func() { - cs.m.Handle("/", http.FileServer(http.Dir("."))) - cs.m.HandleFunc("/subscribe", cs.subscribeHandler) - cs.m.HandleFunc("/publish", cs.publishHandler) - }) - cs.m.ServeHTTP(w, r) + cs.serveMux.ServeHTTP(w, r) } // subscribeHandler accepts the WebSocket connection and then subscribes @@ -44,7 +71,7 @@ func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, nil) if err != nil { - log.Print(err) + cs.logf("%v", err) return } defer c.Close(websocket.StatusInternalError, "") @@ -58,7 +85,8 @@ func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { return } if err != nil { - log.Print(err) + cs.logf("%v", err) + return } } @@ -69,7 +97,7 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) return } - body := io.LimitReader(r.Body, 8192) + body := http.MaxBytesReader(w, r.Body, 8192) msg, err := ioutil.ReadAll(body) if err != nil { http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) @@ -93,7 +121,7 @@ func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error { ctx = c.CloseRead(ctx) s := &subscriber{ - msgs: make(chan []byte, 16), + msgs: make(chan []byte, cs.subscriberMessageBuffer), closeSlow: func() { c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") }, @@ -118,8 +146,10 @@ func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error { // It never blocks and so messages to slow subscribers // are dropped. func (cs *chatServer) publish(msg []byte) { - cs.subscribersMu.RLock() - defer cs.subscribersMu.RUnlock() + cs.subscribersMu.Lock() + defer cs.subscribersMu.Unlock() + + cs.publishLimiter.Wait(context.Background()) for s := range cs.subscribers { select { @@ -133,9 +163,6 @@ func (cs *chatServer) publish(msg []byte) { // addSubscriber registers a subscriber. func (cs *chatServer) addSubscriber(s *subscriber) { cs.subscribersMu.Lock() - if cs.subscribers == nil { - cs.subscribers = make(map[*subscriber]struct{}) - } cs.subscribers[s] = struct{}{} cs.subscribersMu.Unlock() } diff --git a/chat-example/chat_test.go b/chat-example/chat_test.go index d1772381..491499cc 100644 --- a/chat-example/chat_test.go +++ b/chat-example/chat_test.go @@ -4,104 +4,214 @@ package main import ( "context" - "errors" + "crypto/rand" "fmt" + "math/big" "net/http" "net/http/httptest" "strings" + "sync" "testing" "time" + "golang.org/x/time/rate" + "nhooyr.io/websocket" ) -func TestGrace(t *testing.T) { +func Test_chatServer(t *testing.T) { t.Parallel() - var cs chatServer + // This is a simple echo test with a single client. + // The client sends a message and ensures it receives + // it on its WebSocket. + t.Run("simple", func(t *testing.T) { + t.Parallel() + + url, closeFn := setupTest(t) + defer closeFn() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + cl, err := newClient(ctx, url) + assertSuccess(t, err) + defer cl.Close() + + expMsg := randString(512) + err = cl.publish(ctx, expMsg) + assertSuccess(t, err) + + msg, err := cl.nextMessage() + assertSuccess(t, err) + + if expMsg != msg { + t.Fatalf("expected %v but got %v", expMsg, msg) + } + }) + + // This test is a complex concurrency test. + // 10 clients are started that send 128 different + // messages of max 128 bytes concurrently. + // + // The test verifies that every message is seen by ever client + // and no errors occur anywhere. + t.Run("concurrency", func(t *testing.T) { + t.Parallel() + + const nmessages = 128 + const maxMessageSize = 128 + const nclients = 10 + + url, closeFn := setupTest(t) + defer closeFn() + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + var clients []*client + var clientMsgs []map[string]struct{} + for i := 0; i < nclients; i++ { + cl, err := newClient(ctx, url) + assertSuccess(t, err) + defer cl.Close() + + clients = append(clients, cl) + clientMsgs = append(clientMsgs, randMessages(nmessages, maxMessageSize)) + } + + allMessages := make(map[string]struct{}) + for _, msgs := range clientMsgs { + for m := range msgs { + allMessages[m] = struct{}{} + } + } + + var wg sync.WaitGroup + for i, cl := range clients { + i := i + cl := cl + + wg.Add(1) + go func() { + defer wg.Done() + err := cl.publishMsgs(ctx, clientMsgs[i]) + if err != nil { + t.Errorf("client %d failed to publish all messages: %v", i, err) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := testAllMessagesReceived(cl, nclients*nmessages, allMessages) + if err != nil { + t.Errorf("client %d failed to receive all messages: %v", i, err) + } + }() + } + + wg.Wait() + }) +} + +// setupTest sets up chatServer that can be used +// via the returned url. +// +// Defer closeFn to ensure everything is cleaned up at +// the end of the test. +// +// chatServer logs will be logged via t.Logf. +func setupTest(t *testing.T) (url string, closeFn func()) { + cs := newChatServer() + cs.logf = t.Logf + + // To ensure tests run quickly under even -race. + cs.subscriberMessageBuffer = 4096 + cs.publishLimiter.SetLimit(rate.Inf) + var g websocket.Grace - s := httptest.NewServer(g.Handler(&cs)) - defer s.Close() - defer g.Close() + s := httptest.NewServer(g.Handler(cs)) + return s.URL, func() { + s.Close() + g.Close() + } +} - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() +// testAllMessagesReceived ensures that after n reads, all msgs in msgs +// have been read. +func testAllMessagesReceived(cl *client, n int, msgs map[string]struct{}) error { + msgs = cloneMessages(msgs) - cl1, err := newClient(ctx, s.URL) - assertSuccess(t, err) - defer cl1.Close() + for i := 0; i < n; i++ { + msg, err := cl.nextMessage() + if err != nil { + return err + } + delete(msgs, msg) + } - cl2, err := newClient(ctx, s.URL) - assertSuccess(t, err) - defer cl2.Close() + if len(msgs) != 0 { + return fmt.Errorf("did not receive all expected messages: %q", msgs) + } + return nil +} - err = cl1.publish(ctx, "hello") - assertSuccess(t, err) +func cloneMessages(msgs map[string]struct{}) map[string]struct{} { + msgs2 := make(map[string]struct{}, len(msgs)) + for m := range msgs { + msgs2[m] = struct{}{} + } + return msgs2 +} - assertReceivedMessage(ctx, cl1, "hello") - assertReceivedMessage(ctx, cl2, "hello") +func randMessages(n, maxMessageLength int) map[string]struct{} { + msgs := make(map[string]struct{}) + for i := 0; i < n; i++ { + m := randString(randInt(maxMessageLength)) + if _, ok := msgs[m]; ok { + i-- + continue + } + msgs[m] = struct{}{} + } + return msgs +} + +func assertSuccess(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } } type client struct { - msgs chan string - url string - c *websocket.Conn + url string + c *websocket.Conn } func newClient(ctx context.Context, url string) (*client, error) { - wsURL := strings.ReplaceAll(url, "http://", "ws://") + wsURL := strings.Replace(url, "http://", "ws://", 1) c, _, err := websocket.Dial(ctx, wsURL+"/subscribe", nil) if err != nil { return nil, err } cl := &client{ - msgs: make(chan string, 16), - url: url, - c: c, + url: url, + c: c, } - go cl.readLoop() return cl, nil } -func (cl *client) readLoop() { - defer cl.c.Close(websocket.StatusInternalError, "") - defer close(cl.msgs) - - for { - typ, b, err := cl.c.Read(context.Background()) +func (cl *client) publish(ctx context.Context, msg string) (err error) { + defer func() { if err != nil { - return + cl.c.Close(websocket.StatusInternalError, "publish failed") } + }() - if typ != websocket.MessageText { - cl.c.Close(websocket.StatusUnsupportedData, "expected text message") - return - } - - select { - case cl.msgs <- string(b): - default: - cl.c.Close(websocket.StatusInternalError, "messages coming in too fast to handle") - return - } - } -} - -func (cl *client) receive(ctx context.Context) (string, error) { - select { - case msg, ok := <-cl.msgs: - if !ok { - return "", errors.New("client closed") - } - return msg, nil - case <-ctx.Done(): - return "", ctx.Err() - } -} - -func (cl *client) publish(ctx context.Context, msg string) error { req, _ := http.NewRequestWithContext(ctx, http.MethodPost, cl.url+"/publish", strings.NewReader(msg)) resp, err := http.DefaultClient.Do(req) if err != nil { @@ -114,24 +224,59 @@ func (cl *client) publish(ctx context.Context, msg string) error { return nil } +func (cl *client) publishMsgs(ctx context.Context, msgs map[string]struct{}) error { + for m := range msgs { + err := cl.publish(ctx, m) + if err != nil { + return err + } + } + return nil +} + +func (cl *client) nextMessage() (string, error) { + typ, b, err := cl.c.Read(context.Background()) + if err != nil { + return "", err + } + + if typ != websocket.MessageText { + cl.c.Close(websocket.StatusUnsupportedData, "expected text message") + return "", fmt.Errorf("expected text message but got %v", typ) + } + return string(b), nil +} + func (cl *client) Close() error { return cl.c.Close(websocket.StatusNormalClosure, "") } -func assertSuccess(t *testing.T, err error) { - t.Helper() +// randString generates a random string with length n. +func randString(n int) string { + b := make([]byte, n) + _, err := rand.Reader.Read(b) if err != nil { - t.Fatal(err) + panic(fmt.Sprintf("failed to generate rand bytes: %v", err)) + } + + s := strings.ToValidUTF8(string(b), "_") + s = strings.ReplaceAll(s, "\x00", "_") + if len(s) > n { + return s[:n] } + if len(s) < n { + // Pad with = + extra := n - len(s) + return s + strings.Repeat("=", extra) + } + return s } -func assertReceivedMessage(ctx context.Context, cl *client, msg string) error { - msg, err := cl.receive(ctx) +// randInt returns a randomly generated integer between [0, max). +func randInt(max int) int { + x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) if err != nil { - return err + panic(fmt.Sprintf("failed to get random int: %v", err)) } - if msg != "hello" { - return fmt.Errorf("expected hello but got %q", msg) - } - return nil + return int(x.Int64()) } diff --git a/chat-example/index.js b/chat-example/index.js index a42c2d30..5868e7ca 100644 --- a/chat-example/index.js +++ b/chat-example/index.js @@ -51,7 +51,7 @@ appendLog("Submit a message to get started!") // onsubmit publishes the message from the user when the form is submitted. - publishForm.onsubmit = ev => { + publishForm.onsubmit = async ev => { ev.preventDefault() const msg = messageInput.value @@ -61,9 +61,16 @@ messageInput.value = "" expectingMessage = true - fetch("/publish", { - method: "POST", - body: msg, - }) + try { + const resp = await fetch("/publish", { + method: "POST", + body: msg, + }) + if (resp.status !== 202) { + throw new Error(`Unexpected HTTP Status ${resp.status} ${resp.statusText}`) + } + } catch (err) { + appendLog(`Publish failed: ${err.message}`, true) + } } })() diff --git a/chat-example/main.go b/chat-example/main.go index a265f60c..1b6f3266 100644 --- a/chat-example/main.go +++ b/chat-example/main.go @@ -35,10 +35,10 @@ func run() error { } log.Printf("listening on http://%v", l.Addr()) - var cs chatServer + cs := newChatServer() var g websocket.Grace s := http.Server{ - Handler: g.Handler(&cs), + Handler: g.Handler(cs), ReadTimeout: time.Second * 10, WriteTimeout: time.Second * 10, } diff --git a/ci/test.mk b/ci/test.mk index c62a25b6..291d6beb 100644 --- a/ci/test.mk +++ b/ci/test.mk @@ -11,6 +11,7 @@ coveralls: gotest goveralls -coverprofile=ci/out/coverage.prof gotest: - go test -timeout=30m -covermode=count -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./... + go test -timeout=30m -covermode=atomic -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./... sed -i '/stringer\.go/d' ci/out/coverage.prof sed -i '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof + sed -i '/chat-example/d' ci/out/coverage.prof