diff --git a/CHANGELOG.md b/CHANGELOG.md index 7930b34d1..dab110777 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. - Decimal package uses a test variable DecimalPrecision instead of a package-level variable decimalPrecision (#233) +- Flaky tests TestClientRequestObjectsWithContext and + TestClientIdRequestObjectWithContext (#244) ## [1.9.0] - 2022-11-02 diff --git a/connection.go b/connection.go index b657dc2ae..5537a0adb 100644 --- a/connection.go +++ b/connection.go @@ -1062,22 +1062,18 @@ func (conn *Connection) newFuture(ctx context.Context) (fut *Future) { } // This method removes a future from the internal queue if the context -// is "done" before the response is come. Such select logic is inspired -// from this thread: https://groups.google.com/g/golang-dev/c/jX4oQEls3uk +// is "done" before the response is come. func (conn *Connection) contextWatchdog(fut *Future, ctx context.Context) { select { case <-fut.done: + case <-ctx.Done(): + } + + select { + case <-fut.done: + return default: - select { - case <-ctx.Done(): - conn.cancelFuture(fut, fmt.Errorf("context is done")) - default: - select { - case <-fut.done: - case <-ctx.Done(): - conn.cancelFuture(fut, fmt.Errorf("context is done")) - } - } + conn.cancelFuture(fut, fmt.Errorf("context is done")) } } @@ -1093,11 +1089,9 @@ func (conn *Connection) send(req Request, streamId uint64) *Future { return fut default: } - } - conn.putFuture(fut, req, streamId) - if req.Ctx() != nil { go conn.contextWatchdog(fut, req.Ctx()) } + conn.putFuture(fut, req, streamId) return fut } @@ -1310,15 +1304,6 @@ func (conn *Connection) Do(req Request) *Future { return fut } } - if req.Ctx() != nil { - select { - case <-req.Ctx().Done(): - fut := NewFuture() - fut.SetError(fmt.Errorf("context is done")) - return fut - default: - } - } return conn.send(req, ignoreStreamId) } diff --git a/request_test.go b/request_test.go index a078f6514..70aa29f45 100644 --- a/request_test.go +++ b/request_test.go @@ -2,6 +2,7 @@ package tarantool_test import ( "bytes" + "context" "errors" "testing" "time" @@ -30,6 +31,11 @@ const validTimeout = 500 * time.Millisecond var validStmt *Prepared = &Prepared{StatementID: 1, Conn: &Connection{}} +var validProtocolInfo ProtocolInfo = ProtocolInfo{ + Version: ProtocolVersion(3), + Features: []ProtocolFeature{StreamsFeature}, +} + type ValidSchemeResolver struct { } @@ -184,6 +190,7 @@ func TestRequestsCodes(t *testing.T) { {req: NewBeginRequest(), code: BeginRequestCode}, {req: NewCommitRequest(), code: CommitRequestCode}, {req: NewRollbackRequest(), code: RollbackRequestCode}, + {req: NewIdRequest(validProtocolInfo), code: IdRequestCode}, {req: NewBroadcastRequest(validKey), code: CallRequestCode}, } @@ -216,6 +223,7 @@ func TestRequestsAsync(t *testing.T) { {req: NewBeginRequest(), async: false}, {req: NewCommitRequest(), async: false}, {req: NewRollbackRequest(), async: false}, + {req: NewIdRequest(validProtocolInfo), async: false}, {req: NewBroadcastRequest(validKey), async: false}, } @@ -226,6 +234,73 @@ func TestRequestsAsync(t *testing.T) { } } +func TestRequestsCtx_default(t *testing.T) { + tests := []struct { + req Request + expected context.Context + }{ + {req: NewSelectRequest(validSpace), expected: nil}, + {req: NewUpdateRequest(validSpace), expected: nil}, + {req: NewUpsertRequest(validSpace), expected: nil}, + {req: NewInsertRequest(validSpace), expected: nil}, + {req: NewReplaceRequest(validSpace), expected: nil}, + {req: NewDeleteRequest(validSpace), expected: nil}, + {req: NewCall16Request(validExpr), expected: nil}, + {req: NewCall17Request(validExpr), expected: nil}, + {req: NewEvalRequest(validExpr), expected: nil}, + {req: NewExecuteRequest(validExpr), expected: nil}, + {req: NewPingRequest(), expected: nil}, + {req: NewPrepareRequest(validExpr), expected: nil}, + {req: NewUnprepareRequest(validStmt), expected: nil}, + {req: NewExecutePreparedRequest(validStmt), expected: nil}, + {req: NewBeginRequest(), expected: nil}, + {req: NewCommitRequest(), expected: nil}, + {req: NewRollbackRequest(), expected: nil}, + {req: NewIdRequest(validProtocolInfo), expected: nil}, + {req: NewBroadcastRequest(validKey), expected: nil}, + } + + for _, test := range tests { + if ctx := test.req.Ctx(); ctx != test.expected { + t.Errorf("An invalid ctx %t, expected %t", ctx, test.expected) + } + } +} + +func TestRequestsCtx_setter(t *testing.T) { + ctx := context.Background() + tests := []struct { + req Request + expected context.Context + }{ + {req: NewSelectRequest(validSpace).Context(ctx), expected: ctx}, + {req: NewUpdateRequest(validSpace).Context(ctx), expected: ctx}, + {req: NewUpsertRequest(validSpace).Context(ctx), expected: ctx}, + {req: NewInsertRequest(validSpace).Context(ctx), expected: ctx}, + {req: NewReplaceRequest(validSpace).Context(ctx), expected: ctx}, + {req: NewDeleteRequest(validSpace).Context(ctx), expected: ctx}, + {req: NewCall16Request(validExpr).Context(ctx), expected: ctx}, + {req: NewCall17Request(validExpr).Context(ctx), expected: ctx}, + {req: NewEvalRequest(validExpr).Context(ctx), expected: ctx}, + {req: NewExecuteRequest(validExpr).Context(ctx), expected: ctx}, + {req: NewPingRequest().Context(ctx), expected: ctx}, + {req: NewPrepareRequest(validExpr).Context(ctx), expected: ctx}, + {req: NewUnprepareRequest(validStmt).Context(ctx), expected: ctx}, + {req: NewExecutePreparedRequest(validStmt).Context(ctx), expected: ctx}, + {req: NewBeginRequest().Context(ctx), expected: ctx}, + {req: NewCommitRequest().Context(ctx), expected: ctx}, + {req: NewRollbackRequest().Context(ctx), expected: ctx}, + {req: NewIdRequest(validProtocolInfo).Context(ctx), expected: ctx}, + {req: NewBroadcastRequest(validKey).Context(ctx), expected: ctx}, + } + + for _, test := range tests { + if ctx := test.req.Ctx(); ctx != test.expected { + t.Errorf("An invalid ctx %t, expected %t", ctx, test.expected) + } + } +} + func TestPingRequestDefaultValues(t *testing.T) { var refBuf bytes.Buffer diff --git a/tarantool_test.go b/tarantool_test.go index 96bf02254..82e531d37 100644 --- a/tarantool_test.go +++ b/tarantool_test.go @@ -2397,15 +2397,57 @@ func TestClientRequestObjectsWithPassedCanceledContext(t *testing.T) { } } +// waitCtxRequest waits for the WaitGroup in Body() call and returns +// the context from Ctx() call. The request helps us to make sure that +// the context's cancel() call is called before a response received. +type waitCtxRequest struct { + ctx context.Context + wg sync.WaitGroup +} + +func (req *waitCtxRequest) Code() int32 { + return NewPingRequest().Code() +} + +func (req *waitCtxRequest) Body(res SchemaResolver, enc *encoder) error { + req.wg.Wait() + return NewPingRequest().Body(res, enc) +} + +func (req *waitCtxRequest) Ctx() context.Context { + return req.ctx +} + +func (req *waitCtxRequest) Async() bool { + return NewPingRequest().Async() +} + func TestClientRequestObjectsWithContext(t *testing.T) { var err error conn := test_helpers.ConnectWithValidation(t, server, opts) defer conn.Close() ctx, cancel := context.WithCancel(context.Background()) - req := NewPingRequest().Context(ctx) - fut := conn.Do(req) + req := &waitCtxRequest{ctx: ctx} + req.wg.Add(1) + + var futWg sync.WaitGroup + var fut *Future + + futWg.Add(1) + go func() { + defer futWg.Done() + fut = conn.Do(req) + }() + cancel() + req.wg.Done() + + futWg.Wait() + if fut == nil { + t.Fatalf("fut must be not nil") + } + resp, err := fut.Get() if resp != nil { t.Fatalf("response must be nil") @@ -2973,24 +3015,6 @@ func TestClientIdRequestObjectWithPassedCanceledContext(t *testing.T) { require.Equal(t, err.Error(), "context is done") } -func TestClientIdRequestObjectWithContext(t *testing.T) { - var err error - conn := test_helpers.ConnectWithValidation(t, server, opts) - defer conn.Close() - - ctx, cancel := context.WithCancel(context.Background()) - req := NewIdRequest(ProtocolInfo{ - Version: ProtocolVersion(1), - Features: []ProtocolFeature{StreamsFeature}, - }).Context(ctx) //nolint - fut := conn.Do(req) - cancel() - resp, err := fut.Get() - require.Nilf(t, resp, "Response is empty") - require.NotNilf(t, err, "Error is not empty") - require.Equal(t, err.Error(), "context is done") -} - func TestConnectionProtocolInfoUnsupported(t *testing.T) { test_helpers.SkipIfIdSupported(t)