@@ -12,6 +12,7 @@ import (
12
12
"os"
13
13
"regexp"
14
14
"strings"
15
+ "sync"
15
16
"time"
16
17
)
17
18
@@ -64,6 +65,9 @@ type Conn struct {
64
65
conn net.Conn
65
66
topics []string
66
67
timeout time.Duration
68
+
69
+ closeConn sync.Once
70
+ quit chan struct {}
67
71
}
68
72
69
73
func (c * Conn ) writeAll (buf []byte ) error {
@@ -193,7 +197,7 @@ func (c *Conn) subscribe(prefix string) error {
193
197
}
194
198
195
199
func (c * Conn ) readCommand () (string , []byte , error ) {
196
- flag , buf , err := c .readFrame ()
200
+ flag , buf , err := c .readFrame (true )
197
201
if err != nil {
198
202
return "" , nil , err
199
203
}
@@ -256,10 +260,20 @@ func (c *Conn) readReady() error {
256
260
}
257
261
258
262
// Read a frame from the socket, setting deadline before each read to prevent
259
- // timeouts during or between frames.
260
- func (c * Conn ) readFrame () (byte , []byte , error ) {
263
+ // timeouts during or between frames. The initialFrame should be used to denote
264
+ // whether this is the first frame we'll read for a _new_ message.
265
+ //
266
+ // NOTE: This is a blocking call if there is nothing to read from the
267
+ // connection.
268
+ func (c * Conn ) readFrame (initialFrame bool ) (byte , []byte , error ) {
269
+ // We'll only set a read deadline if this is not the first frame of a
270
+ // message. We do this to ensure we receive complete messages in a
271
+ // timely manner.
272
+ if ! initialFrame {
273
+ c .conn .SetReadDeadline (time .Now ().Add (c .timeout ))
274
+ }
275
+
261
276
var flagBuf [1 ]byte
262
- c .conn .SetReadDeadline (time .Now ().Add (c .timeout ))
263
277
if _ , err := io .ReadFull (c .conn , flagBuf [:1 ]); err != nil {
264
278
return 0 , nil , err
265
279
}
@@ -305,11 +319,22 @@ func (c *Conn) readFrame() (byte, []byte, error) {
305
319
return flag , buf , nil
306
320
}
307
321
308
- // Read a message from the socket.
322
+ // readMessage reads a new message from the connection.
323
+ //
324
+ // NOTE: This is a blocking call if there is nothing to read from the
325
+ // connection.
309
326
func (c * Conn ) readMessage () ([][]byte , error ) {
327
+ // We'll only set read deadlines on the underlying connection when
328
+ // reading messages of multiple frames after the first frame has been
329
+ // read. This is done to ensure we receive all of the frames of a
330
+ // message within a reasonable time frame. When reading the first frame,
331
+ // we want to avoid setting them as we don't know when a new message
332
+ // will be available for us to read.
333
+ initialFrame := true
334
+
310
335
var parts [][]byte
311
336
for {
312
- flag , buf , err := c .readFrame ()
337
+ flag , buf , err := c .readFrame (initialFrame )
313
338
if err != nil {
314
339
return nil , err
315
340
}
@@ -326,6 +351,8 @@ func (c *Conn) readMessage() ([][]byte, error) {
326
351
if len (parts ) > 16 {
327
352
return nil , errors .New ("message has too many parts" )
328
353
}
354
+
355
+ initialFrame = false
329
356
}
330
357
return parts , nil
331
358
}
@@ -339,7 +366,12 @@ func Subscribe(addr string, topics []string, timeout time.Duration) (*Conn, erro
339
366
340
367
conn .SetDeadline (time .Now ().Add (10 * time .Second ))
341
368
342
- c := & Conn {conn , topics , timeout }
369
+ c := & Conn {
370
+ conn : conn ,
371
+ topics : topics ,
372
+ timeout : timeout ,
373
+ quit : make (chan struct {}),
374
+ }
343
375
344
376
if err := c .writeGreeting (); err != nil {
345
377
conn .Close ()
@@ -373,36 +405,58 @@ func Subscribe(addr string, topics []string, timeout time.Duration) (*Conn, erro
373
405
}
374
406
375
407
// Receive a message from the publisher. It blocks until a new message is
376
- // received.
408
+ // received. If the connection times out and it was not explicitly terminated,
409
+ // then a timeout error is returned. Otherwise, if it was explicitly terminated,
410
+ // then io.EOF is returned.
377
411
func (c * Conn ) Receive () ([][]byte , error ) {
378
412
messages , err := c .readMessage ()
379
413
// If the error is either nil or a non-EOF error, we return it as-is.
380
414
if err != io .EOF {
381
415
return messages , err
382
416
}
383
- // We got an EOF, so our socket is disconnected. We attempt to
384
- // reconnect. If successful, replace the existing connection with the
385
- // new one. Either way, return a timeout error.
417
+
418
+ // We got an EOF, so our socket is disconnected. If the connection was
419
+ // explicitly terminated, we'll return the EOF error.
420
+ select {
421
+ case <- c .quit :
422
+ return nil , io .EOF
423
+ default :
424
+ }
425
+
426
+ // Otherwise, we'll attempt to reconnect. If successful, we'll replace
427
+ // the existing connection with the new one. Either way, return a
428
+ // timeout error.
386
429
errTimeout := & net.OpError {
387
430
Op : "read" ,
388
431
Net : c .conn .LocalAddr ().Network (),
389
432
Source : c .conn .LocalAddr (),
390
433
Addr : c .conn .RemoteAddr (),
391
434
Err : & reconnectError {err },
392
435
}
393
- newConn , err := Subscribe (c .conn .RemoteAddr ().String (), c .topics ,
394
- c .timeout )
436
+ newConn , err := Subscribe (
437
+ c .conn .RemoteAddr ().String (), c .topics , c .timeout ,
438
+ )
395
439
if err != nil {
396
440
// Prevent CPU overuse by refused reconnection attempts.
397
441
time .Sleep (c .timeout )
398
442
} else {
399
- c .Close ()
400
- * c = * newConn
443
+ c .conn . Close ()
444
+ c . conn = newConn . conn
401
445
}
402
446
return nil , errTimeout
403
447
}
404
448
405
449
// Close the underlying connection. Any further operations will fail.
406
450
func (c * Conn ) Close () error {
407
- return c .conn .Close ()
451
+ var err error
452
+ c .closeConn .Do (func () {
453
+ close (c .quit )
454
+ err = c .conn .Close ()
455
+ })
456
+ return err
457
+ }
458
+
459
+ // RemoteAddr returns the remote network address.
460
+ func (c * Conn ) RemoteAddr () net.Addr {
461
+ return c .conn .RemoteAddr ()
408
462
}
0 commit comments