Skip to content

Commit bd53fdc

Browse files
committed
wip better handling of reauth
1 parent 7bb678f commit bd53fdc

File tree

7 files changed

+119
-54
lines changed

7 files changed

+119
-54
lines changed

pymongo/auth.py

-2
Original file line numberDiff line numberDiff line change
@@ -589,8 +589,6 @@ def authenticate(credentials, sock_info, reauthenticate=False):
589589
"""Authenticate sock_info."""
590590
mechanism = credentials.mechanism
591591
auth_func = _AUTH_MAP[mechanism]
592-
if reauthenticate:
593-
sock_info.handle_reauthenticate()
594592
if mechanism == "MONGODB-OIDC":
595593
_authenticate_oidc(credentials, sock_info, reauthenticate)
596594
else:

pymongo/bulk.py

-1
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,6 @@ def _execute_command(
305305
run.op_type,
306306
self.collection.codec_options,
307307
)
308-
309308
while run.idx_offset < len(run.ops):
310309
# If this is the last possible operation, use the
311310
# final write concern.

pymongo/message.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
ProtocolError,
5555
)
5656
from pymongo.hello import HelloCompat
57+
from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE
5758
from pymongo.read_preferences import ReadPreference
5859
from pymongo.write_concern import WriteConcern
5960

@@ -839,7 +840,13 @@ def _batch_command(self, cmd, docs):
839840

840841
def execute(self, cmd, docs, client):
841842
request_id, msg, to_send = self._batch_command(cmd, docs)
842-
result = self.write_command(cmd, request_id, msg, to_send)
843+
try:
844+
result = self.write_command(cmd, request_id, msg, to_send)
845+
except OperationFailure as exc:
846+
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
847+
self.sock_info.authenticate(True)
848+
result = self.write_command(cmd, request_id, msg, to_send)
849+
raise
843850
client._process_response(result, self.session)
844851
return result, to_send
845852

pymongo/mongo_client.py

+2-16
Original file line numberDiff line numberDiff line change
@@ -1397,14 +1397,7 @@ def is_retrying():
13971397
assert last_error is not None
13981398
raise last_error
13991399
retryable = False
1400-
# Handle re-authentication.
1401-
try:
1402-
return func(session, sock_info, retryable)
1403-
except OperationFailure as exc:
1404-
if exc.code == helpers._REAUTHENTICATION_REQUIRED_CODE:
1405-
sock_info.authenticate(reauthenticate=True)
1406-
return func(session, sock_info, retryable)
1407-
raise
1400+
return func(session, sock_info, retryable)
14081401
except ServerSelectionTimeoutError:
14091402
if is_retrying():
14101403
# The application may think the write was never attempted
@@ -1468,14 +1461,7 @@ def _retryable_read(self, func, read_pref, session, address=None, retryable=True
14681461
# not support retryable reads, raise the last error.
14691462
assert last_error is not None
14701463
raise last_error
1471-
# Handle re-authentication.
1472-
try:
1473-
return func(session, server, sock_info, read_pref)
1474-
except OperationFailure as exc:
1475-
if exc.code == helpers._REAUTHENTICATION_REQUIRED_CODE:
1476-
sock_info.authenticate(reauthenticate=True)
1477-
return func(session, server, sock_info, read_pref)
1478-
raise
1464+
return func(session, server, sock_info, read_pref)
14791465
except ServerSelectionTimeoutError:
14801466
if retrying:
14811467
# The application may think the write was never attempted

pymongo/pool.py

+39-32
Original file line numberDiff line numberDiff line change
@@ -763,32 +763,40 @@ def command(
763763
unacknowledged = write_concern and not write_concern.acknowledged
764764
if self.op_msg_enabled:
765765
self._raise_if_not_writable(unacknowledged)
766+
args = (
767+
self,
768+
dbname,
769+
spec,
770+
self.is_mongos,
771+
read_preference,
772+
codec_options,
773+
session,
774+
client,
775+
check,
776+
allowable_errors,
777+
self.address,
778+
listeners,
779+
self.max_bson_size,
780+
read_concern,
781+
)
782+
kwargs = dict(
783+
parse_write_concern_error=parse_write_concern_error,
784+
collation=collation,
785+
compression_ctx=self.compression_context,
786+
use_op_msg=self.op_msg_enabled,
787+
unacknowledged=unacknowledged,
788+
user_fields=user_fields,
789+
exhaust_allowed=exhaust_allowed,
790+
write_concern=write_concern,
791+
)
766792
try:
767-
return command(
768-
self,
769-
dbname,
770-
spec,
771-
self.is_mongos,
772-
read_preference,
773-
codec_options,
774-
session,
775-
client,
776-
check,
777-
allowable_errors,
778-
self.address,
779-
listeners,
780-
self.max_bson_size,
781-
read_concern,
782-
parse_write_concern_error=parse_write_concern_error,
783-
collation=collation,
784-
compression_ctx=self.compression_context,
785-
use_op_msg=self.op_msg_enabled,
786-
unacknowledged=unacknowledged,
787-
user_fields=user_fields,
788-
exhaust_allowed=exhaust_allowed,
789-
write_concern=write_concern,
790-
)
791-
except (OperationFailure, NotPrimaryError):
793+
return command(*args, **kwargs)
794+
except OperationFailure as exc:
795+
if exc.code == helpers._REAUTHENTICATION_REQUIRED_CODE:
796+
self.authenticate(True)
797+
return command(*args, **kwargs)
798+
raise
799+
except NotPrimaryError:
792800
raise
793801
# Catch socket.error, KeyboardInterrupt, etc. and close ourselves.
794802
except BaseException as error:
@@ -864,7 +872,12 @@ def authenticate(self, reauthenticate=False):
864872
"""
865873
# CMAP spec says to publish the ready event only after authenticating
866874
# the connection.
867-
if not self.ready or reauthenticate:
875+
if reauthenticate:
876+
self.ready = False
877+
if self.performed_handshake:
878+
# Existing auth_ctx is stale, remove it.
879+
self.auth_ctx = None
880+
if not self.ready:
868881
creds = self.opts._credentials
869882
if creds:
870883
auth.authenticate(creds, self, reauthenticate=reauthenticate)
@@ -927,12 +940,6 @@ def idle_time_seconds(self):
927940
"""Seconds since this socket was last checked into its pool."""
928941
return time.monotonic() - self.last_checkin_time
929942

930-
def handle_reauthenticate(self):
931-
"""Handle a reauthentication."""
932-
if self.performed_handshake:
933-
# Existing auth_ctx is stale, remove it.
934-
self.auth_ctx = None
935-
936943
def _raise_connection_failure(self, error):
937944
# Catch *all* exceptions from socket methods and close the socket. In
938945
# regular Python, socket operations only raise socket.error, even if

pymongo/server.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from bson import _decode_all_selective
2020
from pymongo.errors import NotPrimaryError, OperationFailure
21-
from pymongo.helpers import _check_command_response
21+
from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE, _check_command_response
2222
from pymongo.message import _convert_exception, _OpMsg
2323
from pymongo.response import PinnedResponse, Response
2424

@@ -87,6 +87,17 @@ def run_operation(self, sock_info, operation, read_preference, listeners, unpack
8787
- `listeners`: Instance of _EventListeners or None.
8888
- `unpack_res`: A callable that decodes the wire protocol response.
8989
"""
90+
try:
91+
return self._run_operation(sock_info, operation, read_preference, listeners, unpack_res)
92+
except OperationFailure as exc:
93+
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
94+
sock_info.authenticate(True)
95+
return self._run_operation(
96+
sock_info, operation, read_preference, listeners, unpack_res
97+
)
98+
raise
99+
100+
def _run_operation(self, sock_info, operation, read_preference, listeners, unpack_res):
90101
duration = None
91102
publish = listeners.enabled_for_commands
92103
if publish:

test/auth_aws/test_auth_oidc.py

+58-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from pymongo.auth import MongoCredential
3232
from pymongo.auth_oidc import _CACHE as _oidc_cache
3333
from pymongo.errors import ConfigurationError, OperationFailure
34+
from pymongo.operations import InsertOne
3435

3536

3637
class TestAuthOIDC(unittest.TestCase):
@@ -529,7 +530,63 @@ def test_reauthenticate_succeeds(self):
529530
self.assertEqual(self.refresh_called, 1)
530531
client.close()
531532

532-
def test_reauthenticate_retries_and_succees_with_cache(self):
533+
def test_reauthenticate_succeeds_bulk_write(self):
534+
request_cb = self.create_request_cb()
535+
refresh_cb = self.create_refresh_cb()
536+
537+
# Create a client with the callbacks.
538+
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
539+
client = MongoClient(self.uri_single, authmechanismproperties=props)
540+
541+
# Perform a find operation.
542+
client.test.test.find_one()
543+
544+
# Assert that the refresh callback has not been called.
545+
self.assertEqual(self.refresh_called, 0)
546+
547+
with self.fail_point(
548+
{
549+
"mode": {"times": 2},
550+
"data": {"failCommands": ["insert", "saslStart"], "errorCode": 391},
551+
}
552+
):
553+
# Perform a bulk write operation.
554+
client.test.test.bulk_write([InsertOne({})])
555+
556+
# Assert that the refresh callback has been called.
557+
self.assertEqual(self.refresh_called, 1)
558+
client.close()
559+
560+
def test_reauthenticate_succeeds_cursor(self):
561+
request_cb = self.create_request_cb()
562+
refresh_cb = self.create_refresh_cb()
563+
564+
# Create a client with the callbacks.
565+
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
566+
client = MongoClient(self.uri_single, authmechanismproperties=props)
567+
568+
# Perform an insert operation.
569+
client.test.test.insert_one({"a": 1})
570+
571+
# Assert that the refresh callback has not been called.
572+
self.assertEqual(self.refresh_called, 0)
573+
574+
with self.fail_point(
575+
{
576+
"mode": {"times": 2},
577+
"data": {"failCommands": ["find", "saslStart"], "errorCode": 391},
578+
}
579+
):
580+
# Perform a find operation.
581+
cursor = client.test.test.find({"a": 1})
582+
583+
self.assertGreaterEqual(len(list(cursor)), 1)
584+
585+
# Assert that the refresh callback has been called.
586+
self.assertEqual(self.refresh_called, 1)
587+
client.close()
588+
589+
def test_reauthenticate_retries_and_succeeds_with_cache(self):
533590
listener = EventListener()
534591

535592
# Create request and refresh callbacks that return valid credentials

0 commit comments

Comments
 (0)