Skip to content

Commit 85add85

Browse files
committed
Add context manager
1 parent de89b23 commit 85add85

File tree

4 files changed

+44
-6
lines changed

4 files changed

+44
-6
lines changed

pymongo/asynchronous/client_session.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@
139139
import time
140140
import uuid
141141
from collections.abc import Mapping as _Mapping
142-
from contextvars import ContextVar
142+
from contextlib import AbstractAsyncContextManager
143+
from contextvars import ContextVar, Token
143144
from typing import (
144145
TYPE_CHECKING,
145146
Any,
@@ -1071,6 +1072,24 @@ def __copy__(self) -> NoReturn:
10711072
_SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None)
10721073

10731074

1075+
class _BindSession(AbstractAsyncContextManager):
1076+
def __init__(self, session: AsyncClientSession) -> None:
1077+
self.session = session
1078+
self.token: Optional[Token[Optional[AsyncClientSession]]] = None
1079+
1080+
async def __aenter__(self) -> None:
1081+
self.token = _SESSION.set(self.session)
1082+
1083+
async def __aexit__(
1084+
self,
1085+
exc_type: Optional[Type[BaseException]],
1086+
exc_val: Optional[BaseException],
1087+
exc_tb: Optional[TracebackType],
1088+
) -> Optional[bool]:
1089+
if self.token is not None:
1090+
_SESSION.reset(self.token)
1091+
1092+
10741093
class _EmptyServerSession:
10751094
__slots__ = "dirty", "started_retryable_write"
10761095

pymongo/asynchronous/mongo_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
from pymongo.asynchronous import client_session, database, uri_parser
6666
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
6767
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
68-
from pymongo.asynchronous.client_session import _SESSION, _EmptyServerSession
68+
from pymongo.asynchronous.client_session import _BindSession, _EmptyServerSession
6969
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
7070
from pymongo.asynchronous.settings import TopologySettings
7171
from pymongo.asynchronous.topology import Topology, _ErrorContext
@@ -1358,7 +1358,7 @@ def _start_session(self, implicit: bool, **kwargs: Any) -> AsyncClientSession:
13581358
bind = opts._bind
13591359
session = client_session.AsyncClientSession(self, server_session, opts, implicit)
13601360
if bind:
1361-
_SESSION.set(session)
1361+
session = _BindSession(session)
13621362
return session
13631363

13641364
def start_session(

pymongo/synchronous/client_session.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@
139139
import time
140140
import uuid
141141
from collections.abc import Mapping as _Mapping
142-
from contextvars import ContextVar
142+
from contextlib import AbstractContextManager
143+
from contextvars import ContextVar, Token
143144
from typing import (
144145
TYPE_CHECKING,
145146
Any,
@@ -1066,6 +1067,24 @@ def __copy__(self) -> NoReturn:
10661067
_SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None)
10671068

10681069

1070+
class _BindSession(AbstractContextManager):
1071+
def __init__(self, session: ClientSession) -> None:
1072+
self.session = session
1073+
self.token: Optional[Token[Optional[ClientSession]]] = None
1074+
1075+
def __enter__(self) -> None:
1076+
self.token = _SESSION.set(self.session)
1077+
1078+
def __exit__(
1079+
self,
1080+
exc_type: Optional[Type[BaseException]],
1081+
exc_val: Optional[BaseException],
1082+
exc_tb: Optional[TracebackType],
1083+
) -> Optional[bool]:
1084+
if self.token is not None:
1085+
_SESSION.reset(self.token)
1086+
1087+
10691088
class _EmptyServerSession:
10701089
__slots__ = "dirty", "started_retryable_write"
10711090

pymongo/synchronous/mongo_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@
107107
from pymongo.synchronous import client_session, database, uri_parser
108108
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
109109
from pymongo.synchronous.client_bulk import _ClientBulk
110-
from pymongo.synchronous.client_session import _SESSION, _EmptyServerSession
110+
from pymongo.synchronous.client_session import _BindSession, _EmptyServerSession
111111
from pymongo.synchronous.command_cursor import CommandCursor
112112
from pymongo.synchronous.settings import TopologySettings
113113
from pymongo.synchronous.topology import Topology, _ErrorContext
@@ -1356,7 +1356,7 @@ def _start_session(self, implicit: bool, **kwargs: Any) -> ClientSession:
13561356
bind = opts._bind
13571357
session = client_session.ClientSession(self, server_session, opts, implicit)
13581358
if bind:
1359-
_SESSION.set(session)
1359+
session = _BindSession(session)
13601360
return session
13611361

13621362
def start_session(

0 commit comments

Comments
 (0)