Skip to content

Commit b7d0949

Browse files
refactor: stores register cleanup callbacks directly with exit stack
Simplified the BaseContextManagerStore lifecycle management pattern by: - Removing abstract _register_cleanup_callbacks method - Initializing AsyncExitStack in __init__ instead of lazily - Allowing stores to directly register cleanup callbacks via _exit_stack - Stores now call _exit_stack.push_async_callback() or _exit_stack.callback() in their _setup() method to register cleanup operations This gives stores direct control over exit stack management while keeping the base class responsible for entering and exiting the stack. Updated all stores that inherit from BaseContextManagerStore: - MongoDB, DynamoDB: register client context manager - Redis, Valkey, Memcached, Elasticsearch: register async close callbacks - RocksDB, Disk, MultiDisk, DuckDB: register sync close callbacks 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 3b465db commit b7d0949

File tree

20 files changed

+147
-172
lines changed

20 files changed

+147
-172
lines changed

key-value/key-value-aio/src/key_value/aio/stores/base.py

Lines changed: 24 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -434,13 +434,14 @@ class BaseContextManagerStore(BaseStore, ABC):
434434
the constructor. This ensures the store does not manage the lifecycle of user-provided
435435
clients (i.e., does not close them).
436436
437-
Stores that have clients requiring context manager entry should override
438-
`_register_cleanup_callbacks(stack)` to register their cleanup with the provided exit stack.
439-
This method is only called when the store owns the client (client_provided_by_user=False).
437+
The base class provides an AsyncExitStack that stores can use to register cleanup
438+
callbacks. Stores should add their cleanup operations to the exit stack as needed.
439+
The base class handles entering and exiting the exit stack.
440440
"""
441441

442442
_client_provided_by_user: bool
443-
_exit_stack: AsyncExitStack | None
443+
_exit_stack: AsyncExitStack
444+
_exit_stack_entered: bool
444445

445446
def __init__(self, *, client_provided_by_user: bool = False, **kwargs: Any) -> None:
446447
"""Initialize the context manager store with client ownership configuration.
@@ -452,73 +453,44 @@ def __init__(self, *, client_provided_by_user: bool = False, **kwargs: Any) -> N
452453
**kwargs: Additional arguments to pass to the base store constructor.
453454
"""
454455
self._client_provided_by_user = client_provided_by_user
455-
self._exit_stack = None
456-
super().__init__(**kwargs)
457-
458-
async def _register_cleanup_callbacks(self, stack: AsyncExitStack) -> None:
459-
"""Register cleanup callbacks with the exit stack.
460-
461-
Stores should override this method to register their cleanup callbacks with the exit stack.
462-
This method is only called when the store owns the client (client_provided_by_user=False).
463-
464-
Examples:
465-
# For context manager clients:
466-
await stack.enter_async_context(self._client)
467-
468-
# For clients with close() methods:
469-
stack.push_async_callback(self._client.aclose)
470-
471-
Args:
472-
stack: The AsyncExitStack to register cleanup callbacks with.
473-
"""
474-
475-
async def _ensure_exit_stack(self) -> AsyncExitStack:
476-
"""Ensure the exit stack exists and register cleanup callbacks if needed.
477-
478-
Returns:
479-
The exit stack instance.
480-
"""
481-
if self._exit_stack is not None:
482-
return self._exit_stack
483-
484456
self._exit_stack = AsyncExitStack()
485-
await self._exit_stack.__aenter__()
486-
487-
# Register cleanup callbacks if we own the client
488-
if not self._client_provided_by_user:
489-
await self._register_cleanup_callbacks(self._exit_stack)
457+
self._exit_stack_entered = False
458+
super().__init__(**kwargs)
490459

491-
return self._exit_stack
460+
async def _ensure_exit_stack_entered(self) -> None:
461+
"""Ensure the exit stack has been entered."""
462+
if not self._exit_stack_entered:
463+
await self._exit_stack.__aenter__()
464+
self._exit_stack_entered = True
492465

493466
async def __aenter__(self) -> Self:
494-
# Create exit stack and enter client context
495-
await self._ensure_exit_stack()
467+
# Enter the exit stack
468+
await self._ensure_exit_stack_entered()
496469
await self.setup()
497470
return self
498471

499472
async def __aexit__(
500473
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
501474
) -> None:
502475
# Close the exit stack, which handles all cleanup
503-
if self._exit_stack is not None:
504-
await self._exit_stack.aclose()
505-
self._exit_stack = None
476+
if self._exit_stack_entered:
477+
await self._exit_stack.__aexit__(exc_type, exc_value, traceback)
478+
self._exit_stack_entered = False
506479

507480
async def close(self) -> None:
508-
# Close the exit stack if it exists
509-
if self._exit_stack is not None:
481+
# Close the exit stack if it has been entered
482+
if self._exit_stack_entered:
510483
await self._exit_stack.aclose()
511-
self._exit_stack = None
484+
self._exit_stack_entered = False
512485

513486
async def setup(self) -> None:
514487
"""Initialize the store if not already initialized.
515488
516-
This override ensures that if a client needs context manager entry, the exit stack
517-
is created and the client is entered before the store's _setup() method is called.
518-
This allows stores to work correctly even when not used with `async with`.
489+
This override ensures the exit stack is entered before the store's _setup()
490+
method is called, allowing stores to register cleanup callbacks during setup.
519491
"""
520-
# Ensure exit stack exists and client context is entered (if needed)
521-
await self._ensure_exit_stack()
492+
# Ensure exit stack is entered
493+
await self._ensure_exit_stack_entered()
522494
# Call parent setup
523495
await super().setup()
524496

key-value/key-value-aio/src/key_value/aio/stores/disk/multi_store.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ def default_disk_cache_factory(collection: str) -> Cache:
106106
stable_api=True,
107107
)
108108

109+
@override
110+
async def _setup(self) -> None:
111+
"""Register cache cleanup."""
112+
self._exit_stack.callback(self._sync_close)
113+
109114
@override
110115
async def _setup_collection(self, *, collection: str) -> None:
111116
self._cache[collection] = self._disk_cache_factory(collection)
@@ -148,8 +153,5 @@ def _sync_close(self) -> None:
148153
for cache in self._cache.values():
149154
cache.close()
150155

151-
async def _close(self) -> None:
152-
self._sync_close()
153-
154156
def __del__(self) -> None:
155157
self._sync_close()

key-value/key-value-aio/src/key_value/aio/stores/disk/store.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ def __init__(
8686
stable_api=True,
8787
)
8888

89+
@override
90+
async def _setup(self) -> None:
91+
"""Register cache cleanup if we own the cache."""
92+
if not self._client_provided_by_user:
93+
self._exit_stack.callback(self._cache.close)
94+
8995
@override
9096
async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:
9197
combo_key: str = compound_key(collection=collection, key=key)
@@ -126,9 +132,6 @@ async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
126132

127133
return self._cache.delete(key=combo_key, retry=True)
128134

129-
async def _close(self) -> None:
130-
self._cache.close()
131-
132135
def __del__(self) -> None:
133136
if not getattr(self, "_client_provided_by_user", False) and hasattr(self, "_cache"):
134137
self._cache.close()

key-value/key-value-aio/src/key_value/aio/stores/duckdb/store.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,16 @@ async def _setup(self) -> None:
263263
- Metadata queries without JSON deserialization
264264
- Native JSON column support for rich querying capabilities
265265
"""
266+
# Register connection cleanup if we own the connection
267+
if not self._client_provided_by_user:
268+
269+
def close_connection() -> None:
270+
if not self._is_closed:
271+
self._connection.close()
272+
self._is_closed = True
273+
274+
self._exit_stack.callback(close_connection)
275+
266276
# Create the main table for storing key-value entries
267277
self._connection.execute(self._get_create_table_sql())
268278

@@ -362,12 +372,6 @@ async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
362372
deleted_rows = result.fetchall()
363373
return len(deleted_rows) > 0
364374

365-
async def _close(self) -> None:
366-
"""Close the DuckDB connection."""
367-
if not self._is_closed:
368-
self._connection.close()
369-
self._is_closed = True
370-
371375
def __del__(self) -> None:
372376
"""Clean up the DuckDB connection on deletion."""
373377
try:

key-value/key-value-aio/src/key_value/aio/stores/dynamodb/store.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from contextlib import AsyncExitStack
21
from datetime import datetime, timezone
32
from typing import TYPE_CHECKING, Any, overload
43

@@ -123,12 +122,6 @@ def __init__(
123122
client_provided_by_user=client_provided,
124123
)
125124

126-
@override
127-
async def _register_cleanup_callbacks(self, stack: AsyncExitStack) -> None:
128-
"""Register DynamoDB client cleanup with the exit stack."""
129-
if hasattr(self, "_raw_client"):
130-
self._client = await stack.enter_async_context(self._raw_client)
131-
132125
@property
133126
def _connected_client(self) -> DynamoDBClient:
134127
if not self._client:
@@ -139,6 +132,9 @@ def _connected_client(self) -> DynamoDBClient:
139132
@override
140133
async def _setup(self) -> None:
141134
"""Setup the DynamoDB client and ensure table exists."""
135+
# Register client cleanup if we own the client
136+
if not self._client_provided_by_user and hasattr(self, "_raw_client"):
137+
self._client = await self._exit_stack.enter_async_context(self._raw_client)
142138
try:
143139
await self._connected_client.describe_table(TableName=self._table_name) # pyright: ignore[reportUnknownMemberType]
144140
except self._connected_client.exceptions.ResourceNotFoundException: # pyright: ignore[reportUnknownMemberType]

key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,10 @@ def __init__(
259259

260260
@override
261261
async def _setup(self) -> None:
262+
# Register client cleanup if we own the client
263+
if not self._client_provided_by_user:
264+
self._exit_stack.push_async_callback(self._client.close)
265+
262266
cluster_info = await self._client.options(ignore_status=404).info()
263267

264268
self._is_serverless = cluster_info.get("version", {}).get("build_flavor") == "serverless"
@@ -551,6 +555,3 @@ async def _cull(self) -> None:
551555
},
552556
},
553557
)
554-
555-
async def _close(self) -> None:
556-
await self._client.close()

key-value/key-value-aio/src/key_value/aio/stores/memcached/store.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ def __init__(
8181
client_provided_by_user=client_provided,
8282
)
8383

84+
@override
85+
async def _setup(self) -> None:
86+
"""Register client cleanup if we own the client."""
87+
if not self._client_provided_by_user:
88+
self._exit_stack.push_async_callback(self._client.close)
89+
8490
@override
8591
async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:
8692
combo_key: str = self._sanitize_key(compound_key(collection=collection, key=key))
@@ -151,6 +157,3 @@ async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
151157
async def _delete_store(self) -> bool:
152158
_ = await self._client.flush_all()
153159
return True
154-
155-
async def _close(self) -> None:
156-
await self._client.close()

key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from collections.abc import Sequence
2-
from contextlib import AsyncExitStack
32
from datetime import datetime, timezone
43
from typing import Any, overload
54

@@ -194,9 +193,10 @@ def __init__(
194193
)
195194

196195
@override
197-
async def _register_cleanup_callbacks(self, stack: AsyncExitStack) -> None:
198-
"""Register MongoDB client cleanup with the exit stack."""
199-
await stack.enter_async_context(self._client)
196+
async def _setup(self) -> None:
197+
"""Register client cleanup if we own the client."""
198+
if not self._client_provided_by_user:
199+
await self._exit_stack.enter_async_context(self._client)
200200

201201
@override
202202
async def _setup_collection(self, *, collection: str) -> None:

key-value/key-value-aio/src/key_value/aio/stores/redis/store.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,9 +218,12 @@ async def _get_collection_keys(self, *, collection: str, limit: int | None = Non
218218

219219
return get_keys_from_compound_keys(compound_keys=keys, collection=collection)
220220

221+
@override
222+
async def _setup(self) -> None:
223+
"""Register client cleanup if we own the client."""
224+
if not self._client_provided_by_user:
225+
self._exit_stack.push_async_callback(self._client.aclose)
226+
221227
@override
222228
async def _delete_store(self) -> bool:
223229
return await self._client.flushdb() # pyright: ignore[reportUnknownMemberType, reportAny]
224-
225-
async def _close(self) -> None:
226-
await self._client.aclose()

key-value/key-value-aio/src/key_value/aio/stores/rocksdb/store.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,12 @@ def __init__(
8585
client_provided_by_user=client_provided,
8686
)
8787

88-
async def _close(self) -> None:
89-
self._close_and_flush()
88+
@override
89+
async def _setup(self) -> None:
90+
"""Register database cleanup if we own the database."""
91+
if not self._client_provided_by_user:
92+
# Register a callback to close and flush the database
93+
self._exit_stack.callback(self._close_and_flush)
9094

9195
def _close_and_flush(self) -> None:
9296
if not self._is_closed:

0 commit comments

Comments
 (0)