diff --git a/pymongo/client_session.py b/pymongo/client_session.py index e23ab4ad13..e2626d843c 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -554,8 +554,15 @@ def options(self) -> SessionOptions: def session_id(self) -> Mapping[str, Any]: """A BSON document, the opaque server session identifier.""" self._check_ended() + self._materialize(self._client.topology_description.logical_session_timeout_minutes) return self._server_session.session_id + @property + def _transaction_id(self) -> Int64: + """The current transaction id for the underlying server session.""" + self._materialize(self._client.topology_description.logical_session_timeout_minutes) + return self._server_session.transaction_id + @property def cluster_time(self) -> Optional[ClusterTime]: """The cluster time returned by the last operation executed @@ -965,10 +972,12 @@ def _txn_read_preference(self) -> Optional[_ServerMode]: return self._transaction.opts.read_preference return None - def _materialize(self) -> None: + def _materialize(self, logical_session_timeout_minutes: Optional[int] = None) -> None: if isinstance(self._server_session, _EmptyServerSession): old = self._server_session - self._server_session = self._client._topology.get_server_session() + self._server_session = self._client._topology.get_server_session( + logical_session_timeout_minutes + ) if old.started_retryable_write: self._server_session.inc_transaction_id() @@ -979,8 +988,12 @@ def _apply_to( read_preference: _ServerMode, conn: Connection, ) -> None: + if not conn.supports_sessions: + if not self._implicit: + raise ConfigurationError("Sessions are not supported by this MongoDB deployment") + return self._check_ended() - self._materialize() + self._materialize(conn.logical_session_timeout_minutes) if self.options.snapshot: self._update_read_concern(command, conn) @@ -1062,7 +1075,10 @@ def mark_dirty(self) -> None: """ self.dirty = True - def timed_out(self, session_timeout_minutes: float) -> bool: + def timed_out(self, session_timeout_minutes: Optional[int]) -> bool: + if session_timeout_minutes is None: + return False + idle_seconds = time.monotonic() - self.last_use # Timed out if we have less than a minute to live. @@ -1097,7 +1113,7 @@ def pop_all(self) -> list[_ServerSession]: ids.append(self.pop().session_id) return ids - def get_server_session(self, session_timeout_minutes: float) -> _ServerSession: + def get_server_session(self, session_timeout_minutes: Optional[int]) -> _ServerSession: # Although the Driver Sessions Spec says we only clear stale sessions # in return_server_session, PyMongo can't take a lock when returning # sessions from a __del__ method (like in Cursor.__die), so it can't @@ -1114,7 +1130,7 @@ def get_server_session(self, session_timeout_minutes: float) -> _ServerSession: return _ServerSession(self.generation) def return_server_session( - self, server_session: _ServerSession, session_timeout_minutes: Optional[float] + self, server_session: _ServerSession, session_timeout_minutes: Optional[int] ) -> None: if session_timeout_minutes is not None: self._clear_stale(session_timeout_minutes) @@ -1128,7 +1144,7 @@ def return_server_session_no_lock(self, server_session: _ServerSession) -> None: if server_session.generation == self.generation and not server_session.dirty: self.appendleft(server_session) - def _clear_stale(self, session_timeout_minutes: float) -> None: + def _clear_stale(self, session_timeout_minutes: Optional[int]) -> None: # Clear stale sessions. The least recently used are on the right. while self: if self[-1].timed_out(session_timeout_minutes): diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 4991d59f10..83b64af853 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -1750,12 +1750,7 @@ def _process_periodic_tasks(self) -> None: helpers._handle_exception() def __start_session(self, implicit: bool, **kwargs: Any) -> ClientSession: - # Raises ConfigurationError if sessions are not supported. - if implicit: - self._topology._check_implicit_session_support() - server_session: Union[_EmptyServerSession, _ServerSession] = _EmptyServerSession() - else: - server_session = self._get_server_session() + server_session = _EmptyServerSession() opts = client_session.SessionOptions(**kwargs) return client_session.ClientSession(self, server_session, opts, implicit) @@ -1788,10 +1783,6 @@ def start_session( snapshot=snapshot, ) - def _get_server_session(self) -> _ServerSession: - """Internal: start or resume a _ServerSession.""" - return self._topology.get_server_session() - def _return_server_session( self, server_session: Union[_ServerSession, _EmptyServerSession], lock: bool ) -> None: @@ -2393,12 +2384,14 @@ def _write(self) -> T: try: max_wire_version = 0 self._server = self._get_server() - supports_session = ( - self._session is not None and self._server.description.retryable_writes_supported - ) with self._client._checkout(self._server, self._session) as conn: max_wire_version = conn.max_wire_version - if self._retryable and not supports_session: + sessions_supported = ( + self._session + and self._server.description.retryable_writes_supported + and conn.supports_sessions + ) + if not sessions_supported: # A retry is not possible because this server does # not support sessions raise the last error. self._check_last_error() diff --git a/pymongo/pool.py b/pymongo/pool.py index fb7b45bc5f..528fa7f50a 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -869,7 +869,10 @@ def _hello( self.max_bson_size = hello.max_bson_size self.max_message_size = hello.max_message_size self.max_write_batch_size = hello.max_write_batch_size - self.supports_sessions = hello.logical_session_timeout_minutes is not None + self.supports_sessions = ( + hello.logical_session_timeout_minutes is not None and hello.is_readable + ) + self.logical_session_timeout_minutes: Optional[int] = hello.logical_session_timeout_minutes self.hello_ok = hello.hello_ok self.is_repl = hello.server_type in ( SERVER_TYPE.RSPrimary, diff --git a/pymongo/server_description.py b/pymongo/server_description.py index 7943f4f5c8..6393fce0a1 100644 --- a/pymongo/server_description.py +++ b/pymongo/server_description.py @@ -243,8 +243,7 @@ def is_server_type_known(self) -> bool: def retryable_writes_supported(self) -> bool: """Checks if this server supports retryable writes.""" return ( - self._ls_timeout_minutes is not None - and self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary) + self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary) ) or self._server_type == SERVER_TYPE.LoadBalancer @property diff --git a/pymongo/topology.py b/pymongo/topology.py index b5afc31b2b..7f1085d9f3 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -27,7 +27,6 @@ from pymongo import _csot, common, helpers, periodic_executor from pymongo.client_session import _ServerSession, _ServerSessionPool from pymongo.errors import ( - ConfigurationError, ConnectionFailure, InvalidOperation, NetworkTimeout, @@ -47,7 +46,6 @@ Selection, any_server_selector, arbiter_server_selector, - readable_server_selector, secondary_server_selector, writable_server_selector, ) @@ -579,38 +577,10 @@ def pop_all_sessions(self) -> list[_ServerSession]: with self._lock: return self._session_pool.pop_all() - def _check_implicit_session_support(self) -> None: - with self._lock: - self._check_session_support() - - def _check_session_support(self) -> float: - """Internal check for session support on clusters.""" - if self._settings.load_balanced: - # Sessions never time out in load balanced mode. - return float("inf") - session_timeout = self._description.logical_session_timeout_minutes - if session_timeout is None: - # Maybe we need an initial scan? Can raise ServerSelectionError. - if self._description.topology_type == TOPOLOGY_TYPE.Single: - if not self._description.has_known_servers: - self._select_servers_loop( - any_server_selector, self.get_server_selection_timeout(), None - ) - elif not self._description.readable_servers: - self._select_servers_loop( - readable_server_selector, self.get_server_selection_timeout(), None - ) - - session_timeout = self._description.logical_session_timeout_minutes - if session_timeout is None: - raise ConfigurationError("Sessions are not supported by this MongoDB deployment") - return session_timeout - - def get_server_session(self) -> _ServerSession: + def get_server_session(self, session_timeout_minutes: Optional[int]) -> _ServerSession: """Start or resume a server session, or raise ConfigurationError.""" with self._lock: - session_timeout = self._check_session_support() - return self._session_pool.get_server_session(session_timeout) + return self._session_pool.get_server_session(session_timeout_minutes) def return_server_session(self, server_session: _ServerSession, lock: bool) -> None: if lock: diff --git a/test/test_encryption.py b/test/test_encryption.py index 04982fa9cf..0afaebabda 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -30,6 +30,7 @@ from typing import Any, Dict, Mapping from pymongo.collection import Collection +from pymongo.daemon import _spawn_daemon sys.path[0:0] = [""] @@ -2997,5 +2998,59 @@ def test_collection_name_collision(self): self.assertIsInstance(exc.exception.encrypted_fields["fields"][0]["keyId"], Binary) +def start_mongocryptd(port) -> None: + args = ["mongocryptd", f"--port={port}", "--idleShutdownTimeoutSecs=60"] + _spawn_daemon(args) + + +class TestNoSessionsSupport(EncryptionIntegrationTest): + mongocryptd_client: MongoClient + MONGOCRYPTD_PORT = 27020 + + @classmethod + @unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed") + def setUpClass(cls): + super().setUpClass() + start_mongocryptd(cls.MONGOCRYPTD_PORT) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + + def setUp(self) -> None: + self.listener = OvertCommandListener() + self.mongocryptd_client = MongoClient( + f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener] + ) + self.addCleanup(self.mongocryptd_client.close) + + hello = self.mongocryptd_client.db.command("hello") + self.assertNotIn("logicalSessionTimeoutMinutes", hello) + + def test_implicit_session_ignored_when_unsupported(self): + self.listener.reset() + with self.assertRaises(OperationFailure): + self.mongocryptd_client.db.test.find_one() + + self.assertNotIn("lsid", self.listener.started_events[0].command) + + with self.assertRaises(OperationFailure): + self.mongocryptd_client.db.test.insert_one({"x": 1}) + + self.assertNotIn("lsid", self.listener.started_events[1].command) + + def test_explicit_session_errors_when_unsupported(self): + self.listener.reset() + with self.mongocryptd_client.start_session() as s: + with self.assertRaisesRegex( + ConfigurationError, r"Sessions are not supported by this MongoDB deployment" + ): + self.mongocryptd_client.db.test.find_one(session=s) + with self.assertRaisesRegex( + ConfigurationError, r"Sessions are not supported by this MongoDB deployment" + ): + self.mongocryptd_client.db.test.insert_one({"x": 1}, session=s) + + if __name__ == "__main__": unittest.main() diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 98bf0e5c94..ccc6b12e01 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -439,7 +439,7 @@ def test_batch_splitting_retry_fails(self): ) self.listener.reset() with self.client.start_session() as session: - initial_txn = session._server_session._transaction_id + initial_txn = session._transaction_id try: coll.bulk_write( [ @@ -467,7 +467,7 @@ def test_batch_splitting_retry_fails(self): started[1].command.pop("$clusterTime") started[2].command.pop("$clusterTime") self.assertEqual(started[1].command, started[2].command) - final_txn = session._server_session._transaction_id + final_txn = session._transaction_id self.assertEqual(final_txn, expected_txn) self.assertEqual(coll.find_one(projection={"_id": True}), {"_id": 1}) @@ -561,7 +561,7 @@ def test_RetryableWriteError_error_label_RawBSONDocument(self): "insert", "testcoll", documents=[{"_id": 1}], - txnNumber=s._server_session.transaction_id, + txnNumber=s._transaction_id, session=s, codec_options=DEFAULT_CODEC_OPTIONS.with_options( document_class=RawBSONDocument @@ -712,7 +712,7 @@ def raise_connection_err_select_server(*args, **kwargs): kwargs = copy.deepcopy(kwargs) kwargs["session"] = session msg = f"{method.__name__}(*{args!r}, **{kwargs!r})" - initial_txn_id = session._server_session.transaction_id + initial_txn_id = session._transaction_id # Each operation should fail on the first attempt and succeed # on the second. @@ -720,7 +720,7 @@ def raise_connection_err_select_server(*args, **kwargs): self.assertEqual(len(listener.started_events), 1, msg) retry_cmd = listener.started_events[0].command sent_txn_id = retry_cmd["txnNumber"] - final_txn_id = session._server_session.transaction_id + final_txn_id = session._transaction_id self.assertEqual(Int64(initial_txn_id + 1), sent_txn_id, msg) self.assertEqual(sent_txn_id, final_txn_id, msg) diff --git a/test/test_session.py b/test/test_session.py index c95691be15..c5cf77b754 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -116,16 +116,16 @@ def _test_ops(self, client, *ops): for f, args, kw in ops: with client.start_session() as s: + listener.reset() + s._materialize() last_use = s._server_session.last_use start = time.monotonic() self.assertLessEqual(last_use, start) - listener.reset() # In case "f" modifies its inputs. args = copy.copy(args) kw = copy.copy(kw) kw["session"] = s f(*args, **kw) - self.assertGreaterEqual(s._server_session.last_use, start) self.assertGreaterEqual(len(listener.started_events), 1) for event in listener.started_events: self.assertTrue( @@ -274,6 +274,8 @@ def test_end_sessions(self): client = rs_or_single_client(event_listeners=[listener]) # Start many sessions. sessions = [client.start_session() for _ in range(_MAX_END_SESSIONS + 1)] + for s in sessions: + s._materialize() for s in sessions: s.end_session()