Skip to content

Commit e3e28fa

Browse files
committed
api: add context to connection create
`connection.Connect` and `pool.Connect` no longer return non-working connection objects. Those functions now accept context as their first arguments, which user may cancel in process. `connection.Connect` will block until either the working connection created (and returned), `opts.MaxReconnects` creation attempts were made (returns error) or the context is canceled by user (returns error too). Closes #136
1 parent d8df65d commit e3e28fa

29 files changed

+406
-176
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release.
2828
decoded to a varbinary object (#313).
2929
- Use objects of the Decimal type instead of pointers (#238)
3030
- Use objects of the Datetime type instead of pointers (#238)
31+
- `connection.Connect` and `pool.Connect` no longer return non-working
32+
connection objects (#136). Those functions now accept context as their first
33+
arguments, which user may cancel in process.
3134

3235
### Deprecated
3336

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,15 @@ about what it does.
105105
package tarantool
106106

107107
import (
108+
"context"
108109
"fmt"
109110
"github.com/tarantool/go-tarantool/v2"
110111
)
111112

112113
func main() {
113114
opts := tarantool.Opts{User: "guest"}
114-
conn, err := tarantool.Connect("127.0.0.1:3301", opts)
115+
ctx := context.Background()
116+
conn, err := tarantool.Connect(ctx, "127.0.0.1:3301", opts)
115117
if err != nil {
116118
fmt.Println("Connection refused:", err)
117119
}

connection.go

Lines changed: 32 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -381,10 +381,11 @@ func (opts Opts) Clone() Opts {
381381
// - If opts.Reconnect is zero (default), then connection either already connected
382382
// or error is returned.
383383
//
384-
// - If opts.Reconnect is non-zero, then error will be returned only if authorization
385-
// fails. But if Tarantool is not reachable, then it will make an attempt to reconnect later
386-
// and will not finish to make attempts on authorization failures.
387-
func Connect(addr string, opts Opts) (conn *Connection, err error) {
384+
// - If opts.Reconnect is non-zero, then error will be returned if authorization
385+
// fails, or user has canceled context. If Tarantool is not reachable, then it
386+
// will make attempts to reconnect and will not finish to make attempts on
387+
// authorization failures.
388+
func Connect(ctx context.Context, addr string, opts Opts) (conn *Connection, err error) {
388389
conn = &Connection{
389390
addr: addr,
390391
requestId: 0,
@@ -432,25 +433,8 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) {
432433

433434
conn.cond = sync.NewCond(&conn.mutex)
434435

435-
if err = conn.createConnection(false); err != nil {
436-
ter, ok := err.(Error)
437-
if conn.opts.Reconnect <= 0 {
438-
return nil, err
439-
} else if ok && (ter.Code == iproto.ER_NO_SUCH_USER ||
440-
ter.Code == iproto.ER_CREDS_MISMATCH) {
441-
// Reported auth errors immediately.
442-
return nil, err
443-
} else {
444-
// Without SkipSchema it is useless.
445-
go func(conn *Connection) {
446-
conn.mutex.Lock()
447-
defer conn.mutex.Unlock()
448-
if err := conn.createConnection(true); err != nil {
449-
conn.closeConnection(err, true)
450-
}
451-
}(conn)
452-
err = nil
453-
}
436+
if err = conn.createConnection(ctx, false); err != nil {
437+
return nil, err
454438
}
455439

456440
go conn.pinger()
@@ -534,18 +518,19 @@ func (conn *Connection) cancelFuture(fut *Future, err error) {
534518
}
535519
}
536520

537-
func (conn *Connection) dial() (err error) {
521+
func (conn *Connection) dial(ctx context.Context) (err error) {
538522
opts := conn.opts
539523
dialTimeout := opts.Reconnect / 2
540524
if dialTimeout == 0 {
541525
dialTimeout = 500 * time.Millisecond
542526
} else if dialTimeout > 5*time.Second {
543527
dialTimeout = 5 * time.Second
544528
}
529+
nestedCtx, cancel := context.WithTimeout(ctx, dialTimeout)
530+
defer cancel()
545531

546532
var c Conn
547-
c, err = conn.opts.Dialer.Dial(conn.addr, DialOpts{
548-
DialTimeout: dialTimeout,
533+
c, err = conn.opts.Dialer.Dial(nestedCtx, conn.addr, DialOpts{
549534
IoTimeout: opts.Timeout,
550535
Transport: opts.Transport,
551536
Ssl: opts.Ssl,
@@ -658,34 +643,46 @@ func pack(h *smallWBuf, enc *msgpack.Encoder, reqid uint32,
658643
return
659644
}
660645

661-
func (conn *Connection) createConnection(reconnect bool) (err error) {
646+
func (conn *Connection) createConnection(ctx context.Context,
647+
reconnect bool) error {
662648
var reconnects uint
663649
for conn.c == nil && conn.state == connDisconnected {
664650
now := time.Now()
665-
err = conn.dial()
651+
err := conn.dial(ctx)
666652
if err == nil || !reconnect {
667653
if err == nil {
668654
conn.notify(Connected)
669655
}
670-
return
656+
return err
671657
}
672658
if conn.opts.MaxReconnects > 0 && reconnects > conn.opts.MaxReconnects {
673659
conn.opts.Logger.Report(LogLastReconnectFailed, conn, err)
674-
err = ClientError{ErrConnectionClosed, "last reconnect failed"}
675660
// mark connection as closed to avoid reopening by another goroutine
676-
return
661+
return ClientError{ErrConnectionClosed, "last reconnect failed"}
677662
}
678663
conn.opts.Logger.Report(LogReconnectFailed, conn, reconnects, err)
679664
conn.notify(ReconnectFailed)
680665
reconnects++
681666
conn.mutex.Unlock()
682-
time.Sleep(time.Until(now.Add(conn.opts.Reconnect)))
667+
668+
timer := time.NewTimer(time.Until(now.Add(conn.opts.Reconnect)))
669+
waitLoop:
670+
for {
671+
select {
672+
case <-ctx.Done():
673+
conn.mutex.Lock()
674+
return ClientError{ErrConnectionClosed, "operation was canceled"}
675+
case <-timer.C:
676+
break waitLoop
677+
}
678+
}
679+
683680
conn.mutex.Lock()
684681
}
685682
if conn.state == connClosed {
686-
err = ClientError{ErrConnectionClosed, "using closed connection"}
683+
return ClientError{ErrConnectionClosed, "using closed connection"}
687684
}
688-
return
685+
return nil
689686
}
690687

691688
func (conn *Connection) closeConnection(neterr error, forever bool) (err error) {
@@ -731,7 +728,7 @@ func (conn *Connection) reconnectImpl(neterr error, c Conn) {
731728
if conn.opts.Reconnect > 0 {
732729
if c == conn.c {
733730
conn.closeConnection(neterr, false)
734-
if err := conn.createConnection(true); err != nil {
731+
if err := conn.createConnection(context.Background(), true); err != nil {
735732
conn.closeConnection(err, true)
736733
}
737734
}

crud/example_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package crud_test
22

33
import (
4+
"context"
45
"fmt"
56
"reflect"
67
"time"
@@ -21,7 +22,7 @@ var exampleOpts = tarantool.Opts{
2122
}
2223

2324
func exampleConnect() *tarantool.Connection {
24-
conn, err := tarantool.Connect(exampleServer, exampleOpts)
25+
conn, err := tarantool.Connect(context.Background(), exampleServer, exampleOpts)
2526
if err != nil {
2627
panic("Connection is not established: " + err.Error())
2728
}

crud/tarantool_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package crud_test
22

33
import (
4+
"context"
45
"fmt"
56
"log"
67
"os"
@@ -108,7 +109,7 @@ var object = crud.MapObject{
108109

109110
func connect(t testing.TB) *tarantool.Connection {
110111
for i := 0; i < 10; i++ {
111-
conn, err := tarantool.Connect(server, opts)
112+
conn, err := tarantool.Connect(context.Background(), server, opts)
112113
if err != nil {
113114
t.Fatalf("Failed to connect: %s", err)
114115
}

datetime/example_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
package datetime_test
1010

1111
import (
12+
"context"
1213
"fmt"
1314
"time"
1415

@@ -23,7 +24,7 @@ func Example() {
2324
User: "test",
2425
Pass: "test",
2526
}
26-
conn, err := tarantool.Connect("127.0.0.1:3013", opts)
27+
conn, err := tarantool.Connect(context.Background(), "127.0.0.1:3013", opts)
2728
if err != nil {
2829
fmt.Printf("Error in connect is %v", err)
2930
return

decimal/example_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
package decimal_test
1010

1111
import (
12+
"context"
1213
"log"
1314
"time"
1415

@@ -28,7 +29,7 @@ func Example() {
2829
User: "test",
2930
Pass: "test",
3031
}
31-
client, err := tarantool.Connect(server, opts)
32+
client, err := tarantool.Connect(context.Background(), server, opts)
3233
if err != nil {
3334
log.Fatalf("Failed to connect: %s", err.Error())
3435
}

dial.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package tarantool
33
import (
44
"bufio"
55
"bytes"
6+
"context"
67
"errors"
78
"fmt"
89
"io"
@@ -56,8 +57,6 @@ type Conn interface {
5657

5758
// DialOpts is a way to configure a Dial method to create a new Conn.
5859
type DialOpts struct {
59-
// DialTimeout is a timeout for an initial network dial.
60-
DialTimeout time.Duration
6160
// IoTimeout is a timeout per a network read/write.
6261
IoTimeout time.Duration
6362
// Transport is a connect transport type.
@@ -86,7 +85,7 @@ type DialOpts struct {
8685
type Dialer interface {
8786
// Dial connects to a Tarantool instance to the address with specified
8887
// options.
89-
Dial(address string, opts DialOpts) (Conn, error)
88+
Dial(ctx context.Context, address string, opts DialOpts) (Conn, error)
9089
}
9190

9291
type tntConn struct {
@@ -104,11 +103,11 @@ type TtDialer struct {
104103

105104
// Dial connects to a Tarantool instance to the address with specified
106105
// options.
107-
func (t TtDialer) Dial(address string, opts DialOpts) (Conn, error) {
106+
func (t TtDialer) Dial(ctx context.Context, address string, opts DialOpts) (Conn, error) {
108107
var err error
109108
conn := new(tntConn)
110109

111-
if conn.net, err = dial(address, opts); err != nil {
110+
if conn.net, err = dial(ctx, address, opts); err != nil {
112111
return nil, fmt.Errorf("failed to dial: %w", err)
113112
}
114113

@@ -199,13 +198,14 @@ func (c *tntConn) ProtocolInfo() ProtocolInfo {
199198
}
200199

201200
// dial connects to a Tarantool instance.
202-
func dial(address string, opts DialOpts) (net.Conn, error) {
201+
func dial(ctx context.Context, address string, opts DialOpts) (net.Conn, error) {
203202
network, address := parseAddress(address)
204203
switch opts.Transport {
205204
case dialTransportNone:
206-
return net.DialTimeout(network, address, opts.DialTimeout)
205+
dialer := net.Dialer{}
206+
return dialer.DialContext(ctx, network, address)
207207
case dialTransportSsl:
208-
return sslDialTimeout(network, address, opts.DialTimeout, opts.Ssl)
208+
return sslDialContext(ctx, network, address, opts.Ssl)
209209
default:
210210
return nil, fmt.Errorf("unsupported transport type: %s", opts.Transport)
211211
}

dial_test.go

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package tarantool_test
22

33
import (
44
"bytes"
5+
"context"
56
"errors"
67
"net"
78
"sync"
@@ -18,7 +19,7 @@ type mockErrorDialer struct {
1819
err error
1920
}
2021

21-
func (m mockErrorDialer) Dial(address string,
22+
func (m mockErrorDialer) Dial(ctx context.Context, address string,
2223
opts tarantool.DialOpts) (tarantool.Conn, error) {
2324
return nil, m.err
2425
}
@@ -29,9 +30,10 @@ func TestDialer_Dial_error(t *testing.T) {
2930
err: errors.New(errMsg),
3031
}
3132

32-
conn, err := tarantool.Connect("any", tarantool.Opts{
33-
Dialer: dialer,
34-
})
33+
conn, err := tarantool.Connect(context.Background(), "any",
34+
tarantool.Opts{
35+
Dialer: dialer,
36+
})
3537
assert.Nil(t, conn)
3638
assert.ErrorContains(t, err, errMsg)
3739
}
@@ -41,7 +43,7 @@ type mockPassedDialer struct {
4143
opts tarantool.DialOpts
4244
}
4345

44-
func (m *mockPassedDialer) Dial(address string,
46+
func (m *mockPassedDialer) Dial(ctx context.Context, address string,
4547
opts tarantool.DialOpts) (tarantool.Conn, error) {
4648
m.address = address
4749
m.opts = opts
@@ -51,9 +53,8 @@ func (m *mockPassedDialer) Dial(address string,
5153
func TestDialer_Dial_passedOpts(t *testing.T) {
5254
const addr = "127.0.0.1:8080"
5355
opts := tarantool.DialOpts{
54-
DialTimeout: 500 * time.Millisecond,
55-
IoTimeout: 2,
56-
Transport: "any",
56+
IoTimeout: 2,
57+
Transport: "any",
5758
Ssl: tarantool.SslOpts{
5859
KeyFile: "a",
5960
CertFile: "b",
@@ -73,7 +74,9 @@ func TestDialer_Dial_passedOpts(t *testing.T) {
7374
}
7475

7576
dialer := &mockPassedDialer{}
76-
conn, err := tarantool.Connect(addr, tarantool.Opts{
77+
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
78+
defer cancel()
79+
conn, err := tarantool.Connect(ctx, addr, tarantool.Opts{
7780
Dialer: dialer,
7881
Timeout: opts.IoTimeout,
7982
Transport: opts.Transport,
@@ -187,7 +190,7 @@ func newMockIoConn() *mockIoConn {
187190
return conn
188191
}
189192

190-
func (m *mockIoDialer) Dial(address string,
193+
func (m *mockIoDialer) Dial(ctx context.Context, address string,
191194
opts tarantool.DialOpts) (tarantool.Conn, error) {
192195
m.conn = newMockIoConn()
193196
if m.init != nil {
@@ -203,11 +206,12 @@ func dialIo(t *testing.T,
203206
dialer := mockIoDialer{
204207
init: init,
205208
}
206-
conn, err := tarantool.Connect("any", tarantool.Opts{
207-
Dialer: &dialer,
208-
Timeout: 1000 * time.Second, // Avoid pings.
209-
SkipSchema: true,
210-
})
209+
conn, err := tarantool.Connect(context.Background(), "any",
210+
tarantool.Opts{
211+
Dialer: &dialer,
212+
Timeout: 1000 * time.Second, // Avoid pings.
213+
SkipSchema: true,
214+
})
211215
require.Nil(t, err)
212216
require.NotNil(t, conn)
213217

example_custom_unpacking_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package tarantool_test
22

33
import (
4+
"context"
45
"fmt"
56
"log"
67
"time"
@@ -84,7 +85,7 @@ func Example_customUnpacking() {
8485
User: "test",
8586
Pass: "test",
8687
}
87-
conn, err := tarantool.Connect(server, opts)
88+
conn, err := tarantool.Connect(context.Background(), server, opts)
8889
if err != nil {
8990
log.Fatalf("Failed to connect: %s", err.Error())
9091
}

0 commit comments

Comments
 (0)