Skip to content

Commit 0d9471d

Browse files
authored
Merge pull request #200 from nhooyr/server
Add Grace to gracefully close WebSocket connections
2 parents deb14cf + da3aa8c commit 0d9471d

File tree

16 files changed

+643
-64
lines changed

16 files changed

+643
-64
lines changed

accept.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@ func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
7575
func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
7676
defer errd.Wrap(&err, "failed to accept WebSocket connection")
7777

78+
g := graceFromRequest(r)
79+
if g != nil && g.isShuttingdown() {
80+
err := errors.New("server shutting down")
81+
http.Error(w, err.Error(), http.StatusServiceUnavailable)
82+
return nil, err
83+
}
84+
7885
if opts == nil {
7986
opts = &AcceptOptions{}
8087
}
@@ -134,7 +141,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
134141
b, _ := brw.Reader.Peek(brw.Reader.Buffered())
135142
brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
136143

137-
return newConn(connConfig{
144+
c := newConn(connConfig{
138145
subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
139146
rwc: netConn,
140147
client: false,
@@ -143,7 +150,16 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
143150

144151
br: brw.Reader,
145152
bw: brw.Writer,
146-
}), nil
153+
})
154+
155+
if g != nil {
156+
err = g.addConn(c)
157+
if err != nil {
158+
return nil, err
159+
}
160+
}
161+
162+
return c, nil
147163
}
148164

149165
func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {

chat-example/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,9 @@ assets, the `/subscribe` WebSocket endpoint and the HTTP POST `/publish` endpoin
2525

2626
The code is well commented. I would recommend starting in `main.go` and then `chat.go` followed by
2727
`index.html` and then `index.js`.
28+
29+
There are two automated tests for the server included in `chat_test.go`. The first is a simple one
30+
client echo test. It publishes a single message and ensures it's received.
31+
32+
The second is a complex concurrency test where 10 clients send 128 unique messages
33+
of max 128 bytes concurrently. The test ensures all messages are seen by every client.

chat-example/chat.go

Lines changed: 80 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,75 @@ package main
33
import (
44
"context"
55
"errors"
6-
"io"
76
"io/ioutil"
87
"log"
98
"net/http"
109
"sync"
1110
"time"
1211

12+
"golang.org/x/time/rate"
13+
1314
"nhooyr.io/websocket"
1415
)
1516

1617
// chatServer enables broadcasting to a set of subscribers.
1718
type chatServer struct {
18-
subscribersMu sync.RWMutex
19-
subscribers map[chan<- []byte]struct{}
19+
// subscriberMessageBuffer controls the max number
20+
// of messages that can be queued for a subscriber
21+
// before it is kicked.
22+
//
23+
// Defaults to 16.
24+
subscriberMessageBuffer int
25+
26+
// publishLimiter controls the rate limit applied to the publish endpoint.
27+
//
28+
// Defaults to one publish every 100ms with a burst of 8.
29+
publishLimiter *rate.Limiter
30+
31+
// logf controls where logs are sent.
32+
// Defaults to log.Printf.
33+
logf func(f string, v ...interface{})
34+
35+
// serveMux routes the various endpoints to the appropriate handler.
36+
serveMux http.ServeMux
37+
38+
subscribersMu sync.Mutex
39+
subscribers map[*subscriber]struct{}
40+
}
41+
42+
// newChatServer constructs a chatServer with the defaults.
43+
func newChatServer() *chatServer {
44+
cs := &chatServer{
45+
subscriberMessageBuffer: 16,
46+
logf: log.Printf,
47+
subscribers: make(map[*subscriber]struct{}),
48+
publishLimiter: rate.NewLimiter(rate.Every(time.Millisecond*100), 8),
49+
}
50+
cs.serveMux.Handle("/", http.FileServer(http.Dir(".")))
51+
cs.serveMux.HandleFunc("/subscribe", cs.subscribeHandler)
52+
cs.serveMux.HandleFunc("/publish", cs.publishHandler)
53+
54+
return cs
55+
}
56+
57+
// subscriber represents a subscriber.
58+
// Messages are sent on the msgs channel and if the client
59+
// cannot keep up with the messages, closeSlow is called.
60+
type subscriber struct {
61+
msgs chan []byte
62+
closeSlow func()
63+
}
64+
65+
func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
66+
cs.serveMux.ServeHTTP(w, r)
2067
}
2168

2269
// subscribeHandler accepts the WebSocket connection and then subscribes
2370
// it to all future messages.
2471
func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
2572
c, err := websocket.Accept(w, r, nil)
2673
if err != nil {
27-
log.Print(err)
74+
cs.logf("%v", err)
2875
return
2976
}
3077
defer c.Close(websocket.StatusInternalError, "")
@@ -38,7 +85,8 @@ func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
3885
return
3986
}
4087
if err != nil {
41-
log.Print(err)
88+
cs.logf("%v", err)
89+
return
4290
}
4391
}
4492

@@ -49,19 +97,21 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
4997
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
5098
return
5199
}
52-
body := io.LimitReader(r.Body, 8192)
100+
body := http.MaxBytesReader(w, r.Body, 8192)
53101
msg, err := ioutil.ReadAll(body)
54102
if err != nil {
55103
http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge)
56104
return
57105
}
58106

59107
cs.publish(msg)
108+
109+
w.WriteHeader(http.StatusAccepted)
60110
}
61111

62112
// subscribe subscribes the given WebSocket to all broadcast messages.
63-
// It creates a msgs chan with a buffer of 16 to give some room to slower
64-
// connections and then registers it. It then listens for all messages
113+
// It creates a subscriber with a buffered msgs chan to give some room to slower
114+
// connections and then registers the subscriber. It then listens for all messages
65115
// and writes them to the WebSocket. If the context is cancelled or
66116
// an error occurs, it returns and deletes the subscription.
67117
//
@@ -70,13 +120,18 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
70120
func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
71121
ctx = c.CloseRead(ctx)
72122

73-
msgs := make(chan []byte, 16)
74-
cs.addSubscriber(msgs)
75-
defer cs.deleteSubscriber(msgs)
123+
s := &subscriber{
124+
msgs: make(chan []byte, cs.subscriberMessageBuffer),
125+
closeSlow: func() {
126+
c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages")
127+
},
128+
}
129+
cs.addSubscriber(s)
130+
defer cs.deleteSubscriber(s)
76131

77132
for {
78133
select {
79-
case msg := <-msgs:
134+
case msg := <-s.msgs:
80135
err := writeTimeout(ctx, time.Second*5, c, msg)
81136
if err != nil {
82137
return err
@@ -91,32 +146,31 @@ func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
91146
// It never blocks and so messages to slow subscribers
92147
// are dropped.
93148
func (cs *chatServer) publish(msg []byte) {
94-
cs.subscribersMu.RLock()
95-
defer cs.subscribersMu.RUnlock()
149+
cs.subscribersMu.Lock()
150+
defer cs.subscribersMu.Unlock()
151+
152+
cs.publishLimiter.Wait(context.Background())
96153

97-
for c := range cs.subscribers {
154+
for s := range cs.subscribers {
98155
select {
99-
case c <- msg:
156+
case s.msgs <- msg:
100157
default:
158+
go s.closeSlow()
101159
}
102160
}
103161
}
104162

105-
// addSubscriber registers a subscriber with a channel
106-
// on which to send messages.
107-
func (cs *chatServer) addSubscriber(msgs chan<- []byte) {
163+
// addSubscriber registers a subscriber.
164+
func (cs *chatServer) addSubscriber(s *subscriber) {
108165
cs.subscribersMu.Lock()
109-
if cs.subscribers == nil {
110-
cs.subscribers = make(map[chan<- []byte]struct{})
111-
}
112-
cs.subscribers[msgs] = struct{}{}
166+
cs.subscribers[s] = struct{}{}
113167
cs.subscribersMu.Unlock()
114168
}
115169

116-
// deleteSubscriber deletes the subscriber with the given msgs channel.
117-
func (cs *chatServer) deleteSubscriber(msgs chan []byte) {
170+
// deleteSubscriber deletes the given subscriber.
171+
func (cs *chatServer) deleteSubscriber(s *subscriber) {
118172
cs.subscribersMu.Lock()
119-
delete(cs.subscribers, msgs)
173+
delete(cs.subscribers, s)
120174
cs.subscribersMu.Unlock()
121175
}
122176

0 commit comments

Comments
 (0)