Skip to content

Commit 3bf52d2

Browse files
committed
Fix OIDC reauthentication when a session is involved (mongodb#1719)
JAVA-5880
1 parent c59ab52 commit 3bf52d2

File tree

2 files changed

+30
-10
lines changed

2 files changed

+30
-10
lines changed

driver-core/src/main/com/mongodb/internal/connection/Authenticator.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,18 @@ abstract void authenticateAsync(InternalConnection connection, ConnectionDescrip
103103
OperationContext operationContext, SingleResultCallback<Void> callback);
104104

105105
public void reauthenticate(final InternalConnection connection, final OperationContext operationContext) {
106-
authenticate(connection, connection.getDescription(), operationContext);
106+
authenticate(connection, connection.getDescription(), operationContextWithoutSession(operationContext));
107107
}
108108

109109
public void reauthenticateAsync(final InternalConnection connection, final OperationContext operationContext,
110110
final SingleResultCallback<Void> callback) {
111111
beginAsync().thenRun((c) -> {
112-
authenticateAsync(connection, connection.getDescription(), operationContext, c);
112+
authenticateAsync(connection, connection.getDescription(), operationContextWithoutSession(operationContext), c);
113113
}).finish(callback);
114114
}
115+
116+
private static OperationContext operationContextWithoutSession(final OperationContext operationContext) {
117+
return operationContext.withSessionContext(
118+
new ReadConcernAwareNoOpSessionContext(operationContext.getSessionContext().getReadConcern()));
119+
}
115120
}

driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,12 @@
2424
import com.mongodb.MongoSecurityException;
2525
import com.mongodb.MongoSocketException;
2626
import com.mongodb.assertions.Assertions;
27+
import com.mongodb.client.ClientSession;
28+
import com.mongodb.client.FindIterable;
2729
import com.mongodb.client.Fixture;
2830
import com.mongodb.client.MongoClient;
2931
import com.mongodb.client.MongoClients;
32+
import com.mongodb.client.MongoCollection;
3033
import com.mongodb.client.TestListener;
3134
import com.mongodb.event.CommandListener;
3235
import com.mongodb.lang.Nullable;
@@ -334,12 +337,17 @@ public void test3p3UnexpectedErrorDoesNotClearCache() {
334337

335338
@Test
336339
public void test4p1Reauthentication() {
340+
testReauthentication(false);
341+
}
342+
343+
private void testReauthentication(final boolean inSession) {
337344
TestCallback callback = createCallback();
338345
MongoClientSettings clientSettings = createSettings(callback);
339-
try (MongoClient mongoClient = createMongoClient(clientSettings)) {
346+
try (MongoClient mongoClient = createMongoClient(clientSettings);
347+
ClientSession session = inSession ? mongoClient.startSession() : null) {
340348
failCommand(391, 1, "find");
341349
// #. Perform a find operation that succeeds.
342-
performFind(mongoClient);
350+
performFind(mongoClient, session);
343351
}
344352
assertEquals(2, callback.invocations.get());
345353
}
@@ -392,6 +400,11 @@ private static void performInsert(final MongoClient mongoClient) {
392400
.insertOne(Document.parse("{ x: 1 }"));
393401
}
394402

403+
@Test
404+
public void test4p5ReauthenticationInSession() {
405+
testReauthentication(true);
406+
}
407+
395408
@Test
396409
public void test5p1AzureSucceedsWithNoUsername() {
397410
assumeAzure();
@@ -914,12 +927,14 @@ private <T extends Throwable> void assertFindFails(
914927
}
915928
}
916929

917-
private void performFind(final MongoClient mongoClient) {
918-
mongoClient
919-
.getDatabase("test")
920-
.getCollection("test")
921-
.find()
922-
.first();
930+
private static void performFind(final MongoClient mongoClient) {
931+
performFind(mongoClient, null);
932+
}
933+
934+
private static void performFind(final MongoClient mongoClient, @Nullable final ClientSession session) {
935+
MongoCollection<Document> collection = mongoClient.getDatabase("test").getCollection("test");
936+
FindIterable<Document> findIterable = session == null ? collection.find() : collection.find(session);
937+
findIterable.first();
923938
}
924939

925940
protected void delayNextFind() {

0 commit comments

Comments
 (0)