Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions pymongo/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,10 +965,12 @@ def _txn_read_preference(self) -> Optional[_ServerMode]:
return self._transaction.opts.read_preference
return None

def _materialize(self) -> None:
def _materialize(self, logical_session_timeout_minutes: float) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

float -> int.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original code has session_timeout_minutes as a float, is that an unfixed mistake?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that was a mistake. It's int everywhere else.

if isinstance(self._server_session, _EmptyServerSession):
old = self._server_session
self._server_session = self._client._topology.get_server_session()
self._server_session = self._client._topology.get_server_session(
logical_session_timeout_minutes, old
)
if old.started_retryable_write:
self._server_session.inc_transaction_id()

Expand All @@ -979,8 +981,12 @@ def _apply_to(
read_preference: _ServerMode,
conn: Connection,
) -> None:
if not conn.supports_sessions:
if not self._implicit:
raise ConfigurationError("Sessions are not supported by this MongoDB deployment")
return
self._check_ended()
self._materialize()
self._materialize(conn.logical_session_timeout_minutes)
if self.options.snapshot:
self._update_read_concern(command, conn)

Expand Down Expand Up @@ -1032,11 +1038,12 @@ def __copy__(self) -> NoReturn:


class _EmptyServerSession:
__slots__ = "dirty", "started_retryable_write"
__slots__ = "dirty", "started_retryable_write", "session_id"

def __init__(self) -> None:
self.dirty = False
self.started_retryable_write = False
self.session_id = {"id": Binary(uuid.uuid4().bytes, 4)}

def mark_dirty(self) -> None:
self.dirty = True
Expand All @@ -1046,9 +1053,9 @@ def inc_transaction_id(self) -> None:


class _ServerSession:
def __init__(self, generation: int):
def __init__(self, generation: int, session_id: Optional[dict[str, Binary]] = None):
# Ensure id is type 4, regardless of CodecOptions.uuid_representation.
self.session_id = {"id": Binary(uuid.uuid4().bytes, 4)}
self.session_id = session_id or {"id": Binary(uuid.uuid4().bytes, 4)}
self.last_use = time.monotonic()
self._transaction_id = 0
self.dirty = False
Expand Down Expand Up @@ -1097,7 +1104,9 @@ def pop_all(self) -> list[_ServerSession]:
ids.append(self.pop().session_id)
return ids

def get_server_session(self, session_timeout_minutes: float) -> _ServerSession:
def get_server_session(
self, session_timeout_minutes: float, old: _EmptyServerSession
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

float -> int.

) -> _ServerSession:
# Although the Driver Sessions Spec says we only clear stale sessions
# in return_server_session, PyMongo can't take a lock when returning
# sessions from a __del__ method (like in Cursor.__die), so it can't
Expand All @@ -1111,7 +1120,7 @@ def get_server_session(self, session_timeout_minutes: float) -> _ServerSession:
if not s.timed_out(session_timeout_minutes):
return s

return _ServerSession(self.generation)
return _ServerSession(self.generation, old.session_id)
Copy link
Member

@ShaneHarvey ShaneHarvey Jan 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the cache is empty we'll create a new session with the same id but if the cache isn't empty we'll return an existing session with a different id. What problem are we trying to solve here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, good point if the cache is not empty. Our unified test runner expects that sessions have an id on creation, with several tests relying on that behavior. Adding a simple ping command after creation to populate the session id emits unexpected events that cause those same tests to fail.

Copy link
Member

@ShaneHarvey ShaneHarvey Jan 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see. I think the only approach is that accessing the session_id property needs to materialize the session too either without checking the lstm field or using the topologies lstm (either way without doing server selection).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I don't quite understand what you mean. You're saying we need to have the session materialize when session_id is accessed, but without actually checking if sessions are supported?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. The session support check would still happen later in _apply_to.


def return_server_session(
self, server_session: _ServerSession, session_timeout_minutes: Optional[float]
Expand Down
22 changes: 8 additions & 14 deletions pymongo/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1750,12 +1750,7 @@ def _process_periodic_tasks(self) -> None:
helpers._handle_exception()

def __start_session(self, implicit: bool, **kwargs: Any) -> ClientSession:
# Raises ConfigurationError if sessions are not supported.
if implicit:
self._topology._check_implicit_session_support()
server_session: Union[_EmptyServerSession, _ServerSession] = _EmptyServerSession()
else:
server_session = self._get_server_session()
server_session = _EmptyServerSession()
opts = client_session.SessionOptions(**kwargs)
return client_session.ClientSession(self, server_session, opts, implicit)

Expand Down Expand Up @@ -1788,10 +1783,6 @@ def start_session(
snapshot=snapshot,
)

def _get_server_session(self) -> _ServerSession:
"""Internal: start or resume a _ServerSession."""
return self._topology.get_server_session()

def _return_server_session(
self, server_session: Union[_ServerSession, _EmptyServerSession], lock: bool
) -> None:
Expand Down Expand Up @@ -2393,12 +2384,15 @@ def _write(self) -> T:
try:
max_wire_version = 0
self._server = self._get_server()
supports_session = (
self._session is not None and self._server.description.retryable_writes_supported
)
with self._client._checkout(self._server, self._session) as conn:
max_wire_version = conn.max_wire_version
if self._retryable and not supports_session:
sessions_supported = (
self._retryable
and self._server.description.retryable_writes_supported
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove this line?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And just check conn.supports_sessions instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove two of the variables in the check, the remaining two are required for tests to pass.

and self._session
and conn.supports_sessions
)
if not sessions_supported:
# A retry is not possible because this server does
# not support sessions raise the last error.
self._check_last_error()
Expand Down
2 changes: 2 additions & 0 deletions pymongo/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,7 @@ def __init__(
self.active = False
self.last_timeout = self.opts.socket_timeout
self.connect_rtt = 0.0
self.logical_session_timeout_minutes: float = 0.0

def set_conn_timeout(self, timeout: Optional[float]) -> None:
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
Expand Down Expand Up @@ -870,6 +871,7 @@ def _hello(
self.max_message_size = hello.max_message_size
self.max_write_batch_size = hello.max_write_batch_size
self.supports_sessions = hello.logical_session_timeout_minutes is not None
self.logical_session_timeout_minutes = hello.logical_session_timeout_minutes or 0.0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or 0 since logical_session_timeout_minutes is an int.

self.hello_ok = hello.hello_ok
self.is_repl = hello.server_type in (
SERVER_TYPE.RSPrimary,
Expand Down
3 changes: 1 addition & 2 deletions pymongo/server_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,7 @@ def is_server_type_known(self) -> bool:
def retryable_writes_supported(self) -> bool:
"""Checks if this server supports retryable writes."""
return (
self._ls_timeout_minutes is not None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we remove _ls_timeout_minutes and all code for it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe so, there's no more use for it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, according to the spec https://github.com/mongodb/specifications/blob/master/source/sessions/driver-sessions.rst, we still need to store the lowest logicalSessionTimeoutMinutes value in the ToplogyDescription.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, so we still need to track it so we can use the lowest value for the server session cache expiration.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bingo.

and self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary)
self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary)
) or self._server_type == SERVER_TYPE.LoadBalancer

@property
Expand Down
38 changes: 5 additions & 33 deletions pymongo/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast

from pymongo import _csot, common, helpers, periodic_executor
from pymongo.client_session import _ServerSession, _ServerSessionPool
from pymongo.client_session import _EmptyServerSession, _ServerSession, _ServerSessionPool
from pymongo.errors import (
ConfigurationError,
ConnectionFailure,
InvalidOperation,
NetworkTimeout,
Expand All @@ -47,7 +46,6 @@
Selection,
any_server_selector,
arbiter_server_selector,
readable_server_selector,
secondary_server_selector,
writable_server_selector,
)
Expand Down Expand Up @@ -579,38 +577,12 @@ def pop_all_sessions(self) -> list[_ServerSession]:
with self._lock:
return self._session_pool.pop_all()

def _check_implicit_session_support(self) -> None:
with self._lock:
self._check_session_support()

def _check_session_support(self) -> float:
"""Internal check for session support on clusters."""
if self._settings.load_balanced:
# Sessions never time out in load balanced mode.
return float("inf")
session_timeout = self._description.logical_session_timeout_minutes
if session_timeout is None:
# Maybe we need an initial scan? Can raise ServerSelectionError.
if self._description.topology_type == TOPOLOGY_TYPE.Single:
if not self._description.has_known_servers:
self._select_servers_loop(
any_server_selector, self.get_server_selection_timeout(), None
)
elif not self._description.readable_servers:
self._select_servers_loop(
readable_server_selector, self.get_server_selection_timeout(), None
)

session_timeout = self._description.logical_session_timeout_minutes
if session_timeout is None:
raise ConfigurationError("Sessions are not supported by this MongoDB deployment")
return session_timeout

def get_server_session(self) -> _ServerSession:
def get_server_session(
self, session_timeout_minutes: float, old: _EmptyServerSession
) -> _ServerSession:
"""Start or resume a server session, or raise ConfigurationError."""
with self._lock:
session_timeout = self._check_session_support()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove _check_session_support?

return self._session_pool.get_server_session(session_timeout)
return self._session_pool.get_server_session(session_timeout_minutes, old)

def return_server_session(self, server_session: _ServerSession, lock: bool) -> None:
if lock:
Expand Down
55 changes: 55 additions & 0 deletions test/test_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from typing import Any, Dict, Mapping

from pymongo.collection import Collection
from pymongo.daemon import _spawn_daemon

sys.path[0:0] = [""]

Expand Down Expand Up @@ -2997,5 +2998,59 @@ def test_collection_name_collision(self):
self.assertIsInstance(exc.exception.encrypted_fields["fields"][0]["keyId"], Binary)


def start_mongocryptd(port) -> None:
args = ["mongocryptd", f"--port={port}", "--idleShutdownTimeoutSecs=60"]
_spawn_daemon(args)


class TestNoSessionsSupport(EncryptionIntegrationTest):
mongocryptd_client: MongoClient
MONGOCRYPTD_PORT = 27020

@classmethod
@unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed")
def setUpClass(cls):
super().setUpClass()
start_mongocryptd(cls.MONGOCRYPTD_PORT)

@classmethod
def tearDownClass(cls):
super().tearDownClass()

def setUp(self) -> None:
self.listener = OvertCommandListener()
self.mongocryptd_client = MongoClient(
f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener]
)
self.addCleanup(self.mongocryptd_client.close)

hello = self.mongocryptd_client.db.command("hello")
self.assertNotIn("logicalSessionTimeoutMinutes", hello)

def test_implicit_session_ignored_when_unsupported(self):
self.listener.reset()
with self.assertRaises(OperationFailure):
self.mongocryptd_client.db.test.find_one()

self.assertNotIn("lsid", self.listener.started_events[0].command)

with self.assertRaises(OperationFailure):
self.mongocryptd_client.db.test.insert_one({"x": 1})

self.assertNotIn("lsid", self.listener.started_events[1].command)

def test_explicit_session_errors_when_unsupported(self):
self.listener.reset()
with self.mongocryptd_client.start_session() as s:
with self.assertRaisesRegex(
ConfigurationError, r"Sessions are not supported by this MongoDB deployment"
):
self.mongocryptd_client.db.test.find_one(session=s)
with self.assertRaisesRegex(
ConfigurationError, r"Sessions are not supported by this MongoDB deployment"
):
self.mongocryptd_client.db.test.insert_one({"x": 1}, session=s)


if __name__ == "__main__":
unittest.main()
10 changes: 6 additions & 4 deletions test/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,12 @@ def _test_ops(self, client, *ops):

for f, args, kw in ops:
with client.start_session() as s:
last_use = s._server_session.last_use
start = time.monotonic()
self.assertLessEqual(last_use, start)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have any other tests for last_use? Are we loosing coverage here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we can fix the issue by materializing the session before accessing last_use.

listener.reset()
# In case "f" modifies its inputs.
args = copy.copy(args)
kw = copy.copy(kw)
kw["session"] = s
f(*args, **kw)
self.assertGreaterEqual(s._server_session.last_use, start)
self.assertGreaterEqual(len(listener.started_events), 1)
for event in listener.started_events:
self.assertTrue(
Expand Down Expand Up @@ -239,16 +235,20 @@ def test_pool_lifo(self):
# "Pool is LIFO" test from Driver Sessions Spec.
a = self.client.start_session()
b = self.client.start_session()
self.client.admin.command("ping", session=a)
self.client.admin.command("ping", session=b)
a_id = a.session_id
b_id = b.session_id
a.end_session()
b.end_session()

s = self.client.start_session()
self.client.admin.command("ping", session=s)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove these pings now that s.session_id materializes the session?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this test requires the sessions actually enter the pool, which materializing itself does not do. The test fails without these pings, sadly.

Copy link
Member

@ShaneHarvey ShaneHarvey Feb 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain? It appears this test is only checking session_id.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test checks that session a and b are both in the session pool after they are ended, with b at the top due to LIFO ordering. We expect the next session created to be b, but since session creation no longer interacts with the pool unless a timeout is present, materializing the new session instead of using a ping results in a new session entirely.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh that's a bug in get_server_session. After it's fixed we should remove the pings here.

self.assertEqual(b_id, s.session_id)
self.assertNotEqual(a_id, s.session_id)

s2 = self.client.start_session()
self.client.admin.command("ping", session=s2)
self.assertEqual(a_id, s2.session_id)
self.assertNotEqual(b_id, s2.session_id)

Expand All @@ -274,6 +274,8 @@ def test_end_sessions(self):
client = rs_or_single_client(event_listeners=[listener])
# Start many sessions.
sessions = [client.start_session() for _ in range(_MAX_END_SESSIONS + 1)]
for s in sessions:
client.admin.command("ping", session=s)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is gonna make the test take a long time. Can we instead just do s.session_id?

for s in sessions:
s.end_session()

Expand Down