@@ -3,28 +3,75 @@ package main
3
3
import (
4
4
"context"
5
5
"errors"
6
- "io"
7
6
"io/ioutil"
8
7
"log"
9
8
"net/http"
10
9
"sync"
11
10
"time"
12
11
12
+ "golang.org/x/time/rate"
13
+
13
14
"nhooyr.io/websocket"
14
15
)
15
16
16
17
// chatServer enables broadcasting to a set of subscribers.
17
18
type chatServer struct {
18
- subscribersMu sync.RWMutex
19
- subscribers map [chan <- []byte ]struct {}
19
+ // subscriberMessageBuffer controls the max number
20
+ // of messages that can be queued for a subscriber
21
+ // before it is kicked.
22
+ //
23
+ // Defaults to 16.
24
+ subscriberMessageBuffer int
25
+
26
+ // publishLimiter controls the rate limit applied to the publish endpoint.
27
+ //
28
+ // Defaults to one publish every 100ms with a burst of 8.
29
+ publishLimiter * rate.Limiter
30
+
31
+ // logf controls where logs are sent.
32
+ // Defaults to log.Printf.
33
+ logf func (f string , v ... interface {})
34
+
35
+ // serveMux routes the various endpoints to the appropriate handler.
36
+ serveMux http.ServeMux
37
+
38
+ subscribersMu sync.Mutex
39
+ subscribers map [* subscriber ]struct {}
40
+ }
41
+
42
+ // newChatServer constructs a chatServer with the defaults.
43
+ func newChatServer () * chatServer {
44
+ cs := & chatServer {
45
+ subscriberMessageBuffer : 16 ,
46
+ logf : log .Printf ,
47
+ subscribers : make (map [* subscriber ]struct {}),
48
+ publishLimiter : rate .NewLimiter (rate .Every (time .Millisecond * 100 ), 8 ),
49
+ }
50
+ cs .serveMux .Handle ("/" , http .FileServer (http .Dir ("." )))
51
+ cs .serveMux .HandleFunc ("/subscribe" , cs .subscribeHandler )
52
+ cs .serveMux .HandleFunc ("/publish" , cs .publishHandler )
53
+
54
+ return cs
55
+ }
56
+
57
+ // subscriber represents a subscriber.
58
+ // Messages are sent on the msgs channel and if the client
59
+ // cannot keep up with the messages, closeSlow is called.
60
+ type subscriber struct {
61
+ msgs chan []byte
62
+ closeSlow func ()
63
+ }
64
+
65
+ func (cs * chatServer ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
66
+ cs .serveMux .ServeHTTP (w , r )
20
67
}
21
68
22
69
// subscribeHandler accepts the WebSocket connection and then subscribes
23
70
// it to all future messages.
24
71
func (cs * chatServer ) subscribeHandler (w http.ResponseWriter , r * http.Request ) {
25
72
c , err := websocket .Accept (w , r , nil )
26
73
if err != nil {
27
- log . Print ( err )
74
+ cs . logf ( "%v" , err )
28
75
return
29
76
}
30
77
defer c .Close (websocket .StatusInternalError , "" )
@@ -38,7 +85,8 @@ func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
38
85
return
39
86
}
40
87
if err != nil {
41
- log .Print (err )
88
+ cs .logf ("%v" , err )
89
+ return
42
90
}
43
91
}
44
92
@@ -49,19 +97,21 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
49
97
http .Error (w , http .StatusText (http .StatusMethodNotAllowed ), http .StatusMethodNotAllowed )
50
98
return
51
99
}
52
- body := io . LimitReader ( r .Body , 8192 )
100
+ body := http . MaxBytesReader ( w , r .Body , 8192 )
53
101
msg , err := ioutil .ReadAll (body )
54
102
if err != nil {
55
103
http .Error (w , http .StatusText (http .StatusRequestEntityTooLarge ), http .StatusRequestEntityTooLarge )
56
104
return
57
105
}
58
106
59
107
cs .publish (msg )
108
+
109
+ w .WriteHeader (http .StatusAccepted )
60
110
}
61
111
62
112
// subscribe subscribes the given WebSocket to all broadcast messages.
63
- // It creates a msgs chan with a buffer of 16 to give some room to slower
64
- // connections and then registers it . It then listens for all messages
113
+ // It creates a subscriber with a buffered msgs chan to give some room to slower
114
+ // connections and then registers the subscriber . It then listens for all messages
65
115
// and writes them to the WebSocket. If the context is cancelled or
66
116
// an error occurs, it returns and deletes the subscription.
67
117
//
@@ -70,13 +120,18 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
70
120
func (cs * chatServer ) subscribe (ctx context.Context , c * websocket.Conn ) error {
71
121
ctx = c .CloseRead (ctx )
72
122
73
- msgs := make (chan []byte , 16 )
74
- cs .addSubscriber (msgs )
75
- defer cs .deleteSubscriber (msgs )
123
+ s := & subscriber {
124
+ msgs : make (chan []byte , cs .subscriberMessageBuffer ),
125
+ closeSlow : func () {
126
+ c .Close (websocket .StatusPolicyViolation , "connection too slow to keep up with messages" )
127
+ },
128
+ }
129
+ cs .addSubscriber (s )
130
+ defer cs .deleteSubscriber (s )
76
131
77
132
for {
78
133
select {
79
- case msg := <- msgs :
134
+ case msg := <- s . msgs :
80
135
err := writeTimeout (ctx , time .Second * 5 , c , msg )
81
136
if err != nil {
82
137
return err
@@ -91,32 +146,31 @@ func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
91
146
// It never blocks and so messages to slow subscribers
92
147
// are dropped.
93
148
func (cs * chatServer ) publish (msg []byte ) {
94
- cs .subscribersMu .RLock ()
95
- defer cs .subscribersMu .RUnlock ()
149
+ cs .subscribersMu .Lock ()
150
+ defer cs .subscribersMu .Unlock ()
151
+
152
+ cs .publishLimiter .Wait (context .Background ())
96
153
97
- for c := range cs .subscribers {
154
+ for s := range cs .subscribers {
98
155
select {
99
- case c <- msg :
156
+ case s . msgs <- msg :
100
157
default :
158
+ go s .closeSlow ()
101
159
}
102
160
}
103
161
}
104
162
105
- // addSubscriber registers a subscriber with a channel
106
- // on which to send messages.
107
- func (cs * chatServer ) addSubscriber (msgs chan <- []byte ) {
163
+ // addSubscriber registers a subscriber.
164
+ func (cs * chatServer ) addSubscriber (s * subscriber ) {
108
165
cs .subscribersMu .Lock ()
109
- if cs .subscribers == nil {
110
- cs .subscribers = make (map [chan <- []byte ]struct {})
111
- }
112
- cs .subscribers [msgs ] = struct {}{}
166
+ cs .subscribers [s ] = struct {}{}
113
167
cs .subscribersMu .Unlock ()
114
168
}
115
169
116
- // deleteSubscriber deletes the subscriber with the given msgs channel .
117
- func (cs * chatServer ) deleteSubscriber (msgs chan [] byte ) {
170
+ // deleteSubscriber deletes the given subscriber .
171
+ func (cs * chatServer ) deleteSubscriber (s * subscriber ) {
118
172
cs .subscribersMu .Lock ()
119
- delete (cs .subscribers , msgs )
173
+ delete (cs .subscribers , s )
120
174
cs .subscribersMu .Unlock ()
121
175
}
122
176
0 commit comments