Skip to content

Commit cea2a03

Browse files
authored
Merge pull request #3 from wpaulino/conn-timeout-block-read
zmq: prevent unnecessary timeouts and return EOF on connection termination
2 parents 462a8a7 + 6e9a863 commit cea2a03

File tree

2 files changed

+73
-16
lines changed

2 files changed

+73
-16
lines changed

go.mod

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,3 @@
1+
module github.com/lightninglabs/gozmq
2+
3+
go 1.12

zmq.go

Lines changed: 70 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ import (
1212
"os"
1313
"regexp"
1414
"strings"
15+
"sync"
1516
"time"
1617
)
1718

@@ -64,6 +65,9 @@ type Conn struct {
6465
conn net.Conn
6566
topics []string
6667
timeout time.Duration
68+
69+
closeConn sync.Once
70+
quit chan struct{}
6771
}
6872

6973
func (c *Conn) writeAll(buf []byte) error {
@@ -193,7 +197,7 @@ func (c *Conn) subscribe(prefix string) error {
193197
}
194198

195199
func (c *Conn) readCommand() (string, []byte, error) {
196-
flag, buf, err := c.readFrame()
200+
flag, buf, err := c.readFrame(true)
197201
if err != nil {
198202
return "", nil, err
199203
}
@@ -256,10 +260,20 @@ func (c *Conn) readReady() error {
256260
}
257261

258262
// Read a frame from the socket, setting deadline before each read to prevent
259-
// timeouts during or between frames.
260-
func (c *Conn) readFrame() (byte, []byte, error) {
263+
// timeouts during or between frames. The initialFrame argument denotes
264+
// whether this is the first frame we'll read for a _new_ message.
265+
//
266+
// NOTE: This is a blocking call if there is nothing to read from the
267+
// connection.
268+
func (c *Conn) readFrame(initialFrame bool) (byte, []byte, error) {
269+
// We'll only set a read deadline if this is not the first frame of a
270+
// message. We do this to ensure we receive complete messages in a
271+
// timely manner.
272+
if !initialFrame {
273+
c.conn.SetReadDeadline(time.Now().Add(c.timeout))
274+
}
275+
261276
var flagBuf [1]byte
262-
c.conn.SetReadDeadline(time.Now().Add(c.timeout))
263277
if _, err := io.ReadFull(c.conn, flagBuf[:1]); err != nil {
264278
return 0, nil, err
265279
}
@@ -305,11 +319,22 @@ func (c *Conn) readFrame() (byte, []byte, error) {
305319
return flag, buf, nil
306320
}
307321

308-
// Read a message from the socket.
322+
// readMessage reads a new message from the connection.
323+
//
324+
// NOTE: This is a blocking call if there is nothing to read from the
325+
// connection.
309326
func (c *Conn) readMessage() ([][]byte, error) {
327+
// We'll only set read deadlines on the underlying connection when
328+
// reading messages of multiple frames after the first frame has been
329+
// read. This is done to ensure we receive all of the frames of a
330+
// message within a reasonable time frame. When reading the first frame,
331+
// we want to avoid setting them as we don't know when a new message
332+
// will be available for us to read.
333+
initialFrame := true
334+
310335
var parts [][]byte
311336
for {
312-
flag, buf, err := c.readFrame()
337+
flag, buf, err := c.readFrame(initialFrame)
313338
if err != nil {
314339
return nil, err
315340
}
@@ -326,6 +351,8 @@ func (c *Conn) readMessage() ([][]byte, error) {
326351
if len(parts) > 16 {
327352
return nil, errors.New("message has too many parts")
328353
}
354+
355+
initialFrame = false
329356
}
330357
return parts, nil
331358
}
@@ -339,7 +366,12 @@ func Subscribe(addr string, topics []string, timeout time.Duration) (*Conn, erro
339366

340367
conn.SetDeadline(time.Now().Add(10 * time.Second))
341368

342-
c := &Conn{conn, topics, timeout}
369+
c := &Conn{
370+
conn: conn,
371+
topics: topics,
372+
timeout: timeout,
373+
quit: make(chan struct{}),
374+
}
343375

344376
if err := c.writeGreeting(); err != nil {
345377
conn.Close()
@@ -373,36 +405,58 @@ func Subscribe(addr string, topics []string, timeout time.Duration) (*Conn, erro
373405
}
374406

375407
// Receive a message from the publisher. It blocks until a new message is
376-
// received.
408+
// received. If the connection is lost without having been explicitly closed, a
409+
// reconnection is attempted and a timeout error is returned. If it was closed
410+
// explicitly via Close, io.EOF is returned.
377411
func (c *Conn) Receive() ([][]byte, error) {
378412
messages, err := c.readMessage()
379413
// If the error is either nil or a non-EOF error, we return it as-is.
380414
if err != io.EOF {
381415
return messages, err
382416
}
383-
// We got an EOF, so our socket is disconnected. We attempt to
384-
// reconnect. If successful, replace the existing connection with the
385-
// new one. Either way, return a timeout error.
417+
418+
// We got an EOF, so our socket is disconnected. If the connection was
419+
// explicitly terminated, we'll return the EOF error.
420+
select {
421+
case <-c.quit:
422+
return nil, io.EOF
423+
default:
424+
}
425+
426+
// Otherwise, we'll attempt to reconnect. If successful, we'll replace
427+
// the existing connection with the new one. Either way, return a
428+
// timeout error.
386429
errTimeout := &net.OpError{
387430
Op: "read",
388431
Net: c.conn.LocalAddr().Network(),
389432
Source: c.conn.LocalAddr(),
390433
Addr: c.conn.RemoteAddr(),
391434
Err: &reconnectError{err},
392435
}
393-
newConn, err := Subscribe(c.conn.RemoteAddr().String(), c.topics,
394-
c.timeout)
436+
newConn, err := Subscribe(
437+
c.conn.RemoteAddr().String(), c.topics, c.timeout,
438+
)
395439
if err != nil {
396440
// Prevent CPU overuse by refused reconnection attempts.
397441
time.Sleep(c.timeout)
398442
} else {
399-
c.Close()
400-
*c = *newConn
443+
c.conn.Close()
444+
c.conn = newConn.conn
401445
}
402446
return nil, errTimeout
403447
}
404448

405449
// Close the underlying connection. Any further operations will fail.
406450
func (c *Conn) Close() error {
407-
return c.conn.Close()
451+
var err error
452+
c.closeConn.Do(func() {
453+
close(c.quit)
454+
err = c.conn.Close()
455+
})
456+
return err
457+
}
458+
459+
// RemoteAddr returns the remote network address.
460+
func (c *Conn) RemoteAddr() net.Addr {
461+
return c.conn.RemoteAddr()
408462
}

0 commit comments

Comments
 (0)