Skip to content

Add context support #173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release.
- Support decimal type in msgpack (#96)
- Support datetime type in msgpack (#118)
- Prepared SQL statements (#117)
- Context support for request objects (#48)

### Changed

Expand Down
206 changes: 150 additions & 56 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package tarantool
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -125,8 +126,11 @@ type Connection struct {
c net.Conn
mutex sync.Mutex
// Schema contains schema loaded on connection.
Schema *Schema
Schema *Schema
// requestId contains the last request ID for requests with nil context.
requestId uint32
// contextRequestId contains the last request ID for requests with context.
contextRequestId uint32
// Greeting contains first message sent by Tarantool.
Greeting *Greeting

Expand All @@ -143,16 +147,56 @@ type Connection struct {

var _ = Connector(&Connection{}) // Check compatibility with connector interface.

type futureList struct {
first *Future
last **Future
}

func (list *futureList) findFuture(reqid uint32, fetch bool) *Future {
root := &list.first
for {
fut := *root
if fut == nil {
return nil
}
if fut.requestId == reqid {
if fetch {
*root = fut.next
if fut.next == nil {
list.last = root
} else {
fut.next = nil
}
}
return fut
}
root = &fut.next
}
}

func (list *futureList) addFuture(fut *Future) {
*list.last = fut
list.last = &fut.next
}

func (list *futureList) clear(err error, conn *Connection) {
fut := list.first
list.first = nil
list.last = &list.first
for fut != nil {
fut.SetError(err)
conn.markDone(fut)
fut, fut.next = fut.next, nil
}
}

type connShard struct {
rmut sync.Mutex
requests [requestsMap]struct {
first *Future
last **Future
}
bufmut sync.Mutex
buf smallWBuf
enc *msgpack.Encoder
_pad [16]uint64 //nolint: unused,structcheck
rmut sync.Mutex
requests [requestsMap]futureList
requestsWithCtx [requestsMap]futureList
bufmut sync.Mutex
buf smallWBuf
enc *msgpack.Encoder
}

// Greeting is a message sent by Tarantool on connect.
Expand All @@ -167,6 +211,11 @@ type Opts struct {
// push messages are received. If Timeout is zero, any request can be
// blocked infinitely.
// Also used to setup net.TCPConn.Set(Read|Write)Deadline.
//
// Pay attention, when using contexts with request objects,
// the timeout option for Connection does not affect the lifetime
// of the request. For those purposes use context.WithTimeout() as
// the root context.
Timeout time.Duration
// Timeout between reconnect attempts. If Reconnect is zero, no
// reconnect attempts will be made.
Expand Down Expand Up @@ -262,12 +311,13 @@ type SslOpts struct {
// and will not finish to make attempts on authorization failures.
func Connect(addr string, opts Opts) (conn *Connection, err error) {
conn = &Connection{
addr: addr,
requestId: 0,
Greeting: &Greeting{},
control: make(chan struct{}),
opts: opts,
dec: msgpack.NewDecoder(&smallBuf{}),
addr: addr,
requestId: 0,
contextRequestId: 1,
Greeting: &Greeting{},
control: make(chan struct{}),
opts: opts,
dec: msgpack.NewDecoder(&smallBuf{}),
}
maxprocs := uint32(runtime.GOMAXPROCS(-1))
if conn.opts.Concurrency == 0 || conn.opts.Concurrency > maxprocs*128 {
Expand All @@ -283,8 +333,11 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) {
conn.shard = make([]connShard, conn.opts.Concurrency)
for i := range conn.shard {
shard := &conn.shard[i]
for j := range shard.requests {
shard.requests[j].last = &shard.requests[j].first
requestsLists := []*[requestsMap]futureList{&shard.requests, &shard.requestsWithCtx}
for _, requests := range requestsLists {
for j := range requests {
requests[j].last = &requests[j].first
}
}
}

Expand Down Expand Up @@ -387,6 +440,13 @@ func (conn *Connection) Handle() interface{} {
return conn.opts.Handle
}

func (conn *Connection) cancelFuture(fut *Future, err error) {
if fut = conn.fetchFuture(fut.requestId); fut != nil {
fut.SetError(err)
conn.markDone(fut)
}
}

func (conn *Connection) dial() (err error) {
var connection net.Conn
network := "tcp"
Expand Down Expand Up @@ -580,15 +640,10 @@ func (conn *Connection) closeConnection(neterr error, forever bool) (err error)
}
for i := range conn.shard {
conn.shard[i].buf.Reset()
requests := &conn.shard[i].requests
for pos := range requests {
fut := requests[pos].first
requests[pos].first = nil
requests[pos].last = &requests[pos].first
for fut != nil {
fut.SetError(neterr)
conn.markDone(fut)
fut, fut.next = fut.next, nil
requestsLists := []*[requestsMap]futureList{&conn.shard[i].requests, &conn.shard[i].requestsWithCtx}
for _, requests := range requestsLists {
for pos := range requests {
requests[pos].clear(neterr, conn)
}
}
}
Expand Down Expand Up @@ -721,7 +776,7 @@ func (conn *Connection) reader(r *bufio.Reader, c net.Conn) {
}
}

func (conn *Connection) newFuture() (fut *Future) {
func (conn *Connection) newFuture(ctx context.Context) (fut *Future) {
fut = NewFuture()
if conn.rlimit != nil && conn.opts.RLimitAction == RLimitDrop {
select {
Expand All @@ -736,7 +791,7 @@ func (conn *Connection) newFuture() (fut *Future) {
return
}
}
fut.requestId = conn.nextRequestId()
fut.requestId = conn.nextRequestId(ctx != nil)
shardn := fut.requestId & (conn.opts.Concurrency - 1)
shard := &conn.shard[shardn]
shard.rmut.Lock()
Expand All @@ -761,11 +816,20 @@ func (conn *Connection) newFuture() (fut *Future) {
return
}
pos := (fut.requestId / conn.opts.Concurrency) & (requestsMap - 1)
pair := &shard.requests[pos]
*pair.last = fut
pair.last = &fut.next
if conn.opts.Timeout > 0 {
fut.timeout = time.Since(epoch) + conn.opts.Timeout
if ctx != nil {
select {
case <-ctx.Done():
fut.SetError(fmt.Errorf("context is done"))
shard.rmut.Unlock()
return
default:
}
shard.requestsWithCtx[pos].addFuture(fut)
} else {
shard.requests[pos].addFuture(fut)
if conn.opts.Timeout > 0 {
fut.timeout = time.Since(epoch) + conn.opts.Timeout
}
}
shard.rmut.Unlock()
if conn.rlimit != nil && conn.opts.RLimitAction == RLimitWait {
Expand All @@ -785,12 +849,43 @@ func (conn *Connection) newFuture() (fut *Future) {
return
}

// This method removes a future from the internal queue if the context
// is "done" before the response is come. Such select logic is inspired
// from this thread: https://groups.google.com/g/golang-dev/c/jX4oQEls3uk
func (conn *Connection) contextWatchdog(fut *Future, ctx context.Context) {
select {
case <-fut.done:
default:
select {
case <-ctx.Done():
conn.cancelFuture(fut, fmt.Errorf("context is done"))
default:
select {
case <-fut.done:
case <-ctx.Done():
conn.cancelFuture(fut, fmt.Errorf("context is done"))
}
}
}
}

func (conn *Connection) send(req Request) *Future {
fut := conn.newFuture()
fut := conn.newFuture(req.Ctx())
if fut.ready == nil {
return fut
}
if req.Ctx() != nil {
select {
case <-req.Ctx().Done():
conn.cancelFuture(fut, fmt.Errorf("context is done"))
return fut
default:
}
}
conn.putFuture(fut, req)
if req.Ctx() != nil {
go conn.contextWatchdog(fut, req.Ctx())
}
return fut
}

Expand Down Expand Up @@ -877,25 +972,11 @@ func (conn *Connection) fetchFuture(reqid uint32) (fut *Future) {
func (conn *Connection) getFutureImp(reqid uint32, fetch bool) *Future {
shard := &conn.shard[reqid&(conn.opts.Concurrency-1)]
pos := (reqid / conn.opts.Concurrency) & (requestsMap - 1)
pair := &shard.requests[pos]
root := &pair.first
for {
fut := *root
if fut == nil {
return nil
}
if fut.requestId == reqid {
if fetch {
*root = fut.next
if fut.next == nil {
pair.last = root
} else {
fut.next = nil
}
}
return fut
}
root = &fut.next
// futures with even requests id belong to requests list with nil context
if reqid%2 == 0 {
return shard.requests[pos].findFuture(reqid, fetch)
} else {
return shard.requestsWithCtx[pos].findFuture(reqid, fetch)
}
}

Expand Down Expand Up @@ -984,8 +1065,12 @@ func (conn *Connection) read(r io.Reader) (response []byte, err error) {
return
}

func (conn *Connection) nextRequestId() (requestId uint32) {
return atomic.AddUint32(&conn.requestId, 1)
func (conn *Connection) nextRequestId(context bool) (requestId uint32) {
if context {
return atomic.AddUint32(&conn.contextRequestId, 2)
} else {
return atomic.AddUint32(&conn.requestId, 2)
}
}

// Do performs a request asynchronously on the connection.
Expand All @@ -1000,6 +1085,15 @@ func (conn *Connection) Do(req Request) *Future {
return fut
}
}
if req.Ctx() != nil {
select {
case <-req.Ctx().Done():
fut := NewFuture()
fut.SetError(fmt.Errorf("context is done"))
return fut
default:
}
}
return conn.send(req)
}

Expand Down
31 changes: 31 additions & 0 deletions example_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tarantool_test

import (
"context"
"fmt"
"time"

Expand Down Expand Up @@ -691,3 +692,33 @@ func ExampleConnection_NewPrepared() {
fmt.Printf("Failed to prepare")
}
}

// To pass contexts to request objects, use the Context() method.
// Pay attention that when using context with request objects,
// the timeout option for Connection will not affect the lifetime
// of the request. For those purposes use context.WithTimeout() as
// the root context.
func ExamplePingRequest_Context() {
conn := example_connect()
defer conn.Close()

timeout := time.Nanosecond

// this way you may set the common timeout for requests with context
rootCtx, cancelRoot := context.WithTimeout(context.Background(), timeout)
defer cancelRoot()

// this context will be canceled with the root after commonTimeout
ctx, cancel := context.WithCancel(rootCtx)
defer cancel()

req := tarantool.NewPingRequest().Context(ctx)

// Ping a Tarantool instance to check connection.
resp, err := conn.Do(req).Get()
fmt.Println("Ping Resp", resp)
fmt.Println("Ping Error", err)
// Output:
// Ping Resp <nil>
// Ping Error context is done
}
Loading