Skip to content

Commit 6235206

Browse files
authored
GODRIVER-2672 Return ServerError with "TransientTransactionError" label for server selection timeouts. (#1198)
1 parent 119a351 commit 6235206

File tree

3 files changed

+52
-7
lines changed

3 files changed

+52
-7
lines changed

mongo/database_test.go

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"context"
1111
"errors"
1212
"testing"
13+
"time"
1314

1415
"go.mongodb.org/mongo-driver/bson"
1516
"go.mongodb.org/mongo-driver/bson/bsoncodec"
@@ -18,6 +19,7 @@ import (
1819
"go.mongodb.org/mongo-driver/mongo/readconcern"
1920
"go.mongodb.org/mongo-driver/mongo/readpref"
2021
"go.mongodb.org/mongo-driver/mongo/writeconcern"
22+
"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
2123
)
2224

2325
func setupDb(name string, opts ...*options.DatabaseOptions) *Database {
@@ -84,7 +86,6 @@ func TestDatabase(t *testing.T) {
8486
})
8587
t.Run("replace topology error", func(t *testing.T) {
8688
db := setupDb("foo")
87-
8889
err := db.RunCommand(bgCtx, bson.D{{"x", 1}}).Err()
8990
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
9091

@@ -94,6 +95,43 @@ func TestDatabase(t *testing.T) {
9495
_, err = db.ListCollections(bgCtx, bson.D{})
9596
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
9697
})
98+
t.Run("TransientTransactionError label", func(t *testing.T) {
99+
client := setupClient(options.Client().ApplyURI("mongodb://nonexistent").SetServerSelectionTimeout(3 * time.Second))
100+
err := client.Connect(bgCtx)
101+
defer client.Disconnect(bgCtx)
102+
assert.Nil(t, err, "expected nil, got %v", err)
103+
104+
t.Run("negative case of non-transaction", func(t *testing.T) {
105+
var sse topology.ServerSelectionError
106+
var le LabeledError
107+
108+
err := client.Ping(bgCtx, nil)
109+
assert.NotNil(t, err, "expected error, got nil")
110+
assert.True(t, errors.As(err, &sse), `expected error to be a "topology.ServerSelectionError"`)
111+
if errors.As(err, &le) {
112+
assert.False(t, le.HasErrorLabel("TransientTransactionError"), `expected error not to include the "TransientTransactionError" label`)
113+
}
114+
})
115+
116+
t.Run("positive case of transaction", func(t *testing.T) {
117+
var sse topology.ServerSelectionError
118+
var le LabeledError
119+
120+
sess, err := client.StartSession()
121+
assert.Nil(t, err, "expected nil, got %v", err)
122+
defer sess.EndSession(bgCtx)
123+
124+
sessCtx := NewSessionContext(bgCtx, sess)
125+
err = sess.StartTransaction()
126+
assert.Nil(t, err, "expected nil, got %v", err)
127+
128+
err = client.Ping(sessCtx, nil)
129+
assert.NotNil(t, err, "expected error, got nil")
130+
assert.True(t, errors.As(err, &sse), `expected error to be a "topology.ServerSelectionError"`)
131+
assert.True(t, errors.As(err, &le), `expected error to implement the "LabeledError" interface`)
132+
assert.True(t, le.HasErrorLabel("TransientTransactionError"), `expected error to include the "TransientTransactionError" label`)
133+
})
134+
})
97135
t.Run("nil document error", func(t *testing.T) {
98136
db := setupDb("foo")
99137

mongo/errors.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ func IsTimeout(err error) bool {
129129
return ne.Timeout()
130130
}
131131
//timeout error labels
132-
if le, ok := err.(labeledError); ok {
132+
if le, ok := err.(LabeledError); ok {
133133
if le.HasErrorLabel("NetworkTimeoutError") || le.HasErrorLabel("ExceededTimeLimitError") {
134134
return true
135135
}
@@ -153,7 +153,7 @@ func unwrap(err error) error {
153153
// errorHasLabel returns true if err contains the specified label
154154
func errorHasLabel(err error, label string) bool {
155155
for ; err != nil; err = unwrap(err) {
156-
if le, ok := err.(labeledError); ok && le.HasErrorLabel(label) {
156+
if le, ok := err.(LabeledError); ok && le.HasErrorLabel(label) {
157157
return true
158158
}
159159
}
@@ -207,7 +207,8 @@ func (e MongocryptdError) Unwrap() error {
207207
return e.Wrapped
208208
}
209209

210-
type labeledError interface {
210+
// LabeledError is an interface for errors with labels.
211+
type LabeledError interface {
211212
error
212213
// HasErrorLabel returns true if the error contains the specified label.
213214
HasErrorLabel(string) bool
@@ -216,11 +217,9 @@ type labeledError interface {
216217
// ServerError is the interface implemented by errors returned from the server. Custom implementations of this
217218
// interface should not be used in production.
218219
type ServerError interface {
219-
error
220+
LabeledError
220221
// HasErrorCode returns true if the error has the specified code.
221222
HasErrorCode(int) bool
222-
// HasErrorLabel returns true if the error contains the specified label.
223-
HasErrorLabel(string) bool
224223
// HasErrorMessage returns true if the error contains the specified message.
225224
HasErrorMessage(string) bool
226225
// HasErrorCodeWithMessage returns true if any of the contained errors have the specified code and message.

x/mongo/driver/operation.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,14 @@ func (op Operation) selectServer(ctx context.Context) (Server, error) {
319319
func (op Operation) getServerAndConnection(ctx context.Context) (Server, Connection, error) {
320320
server, err := op.selectServer(ctx)
321321
if err != nil {
322+
if op.Client != nil &&
323+
!(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() {
324+
err = Error{
325+
Message: err.Error(),
326+
Labels: []string{TransientTransactionError},
327+
Wrapped: err,
328+
}
329+
}
322330
return nil, nil, err
323331
}
324332

0 commit comments

Comments
 (0)