Skip to content

Commit b08040b

Browse files
refactor: simplify client ownership by initializing flag in base class
- Initialize _client_provided_by_user = False in BaseContextManagerStore.__init__() - Remove redundant = False assignments from all store implementations - Add ownership checks in _close() methods to prevent closing user-provided clients - Add ownership checks in __del__() methods where applicable (RocksDB, Disk) - Fixes test failures where stores were closing user-provided clients This simplifies the ownership pattern: stores only set the flag to True when users provide a client, relying on the base class default of False for internally-created clients. Co-authored-by: William Easton <[email protected]>
1 parent c6c6bea commit b08040b

File tree

18 files changed

+45
-42
lines changed

18 files changed

+45
-42
lines changed

key-value/key-value-aio/src/key_value/aio/stores/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,11 @@ class BaseContextManagerStore(BaseStore, ABC):
434434

435435
_client_provided_by_user: bool
436436

437+
def __init__(self, **kwargs: Any) -> None:
438+
"""Initialize the context manager store with default client ownership."""
439+
self._client_provided_by_user = False
440+
super().__init__(**kwargs)
441+
437442
async def __aenter__(self) -> Self:
438443
await self.setup()
439444
return self

key-value/key-value-aio/src/key_value/aio/stores/disk/store.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ def __init__(
7777
self._cache = Cache(directory=directory, size_limit=max_size)
7878
else:
7979
self._cache = Cache(directory=directory, eviction_policy="none")
80-
self._client_provided_by_user = False
8180

8281
self._stable_api = True
8382

@@ -125,7 +124,9 @@ async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
125124

126125
@override
127126
async def _close(self) -> None:
128-
self._cache.close()
127+
if not self._client_provided_by_user:
128+
self._cache.close()
129129

130130
def __del__(self) -> None:
131-
self._cache.close()
131+
if not self._client_provided_by_user:
132+
self._cache.close()

key-value/key-value-aio/src/key_value/aio/stores/duckdb/store.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,6 @@ def __init__(
154154
self._connection = duckdb.connect(":memory:")
155155
else:
156156
self._connection = duckdb.connect(database=database_path)
157-
self._client_provided_by_user = False
158157

159158
self._is_closed = False
160159
self._adapter = DuckDBSerializationAdapter()

key-value/key-value-aio/src/key_value/aio/stores/dynamodb/store.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ def __init__(
116116
self._raw_client = session.client(service_name="dynamodb", endpoint_url=endpoint_url) # pyright: ignore[reportUnknownMemberType]
117117

118118
self._client = None
119-
self._client_provided_by_user = False
120119

121120
super().__init__(default_collection=default_collection)
122121

@@ -263,5 +262,5 @@ async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
263262
@override
264263
async def _close(self) -> None:
265264
"""Close the DynamoDB client."""
266-
if self._client:
265+
if self._client and not self._client_provided_by_user:
267266
await self._client.__aexit__(None, None, None) # pyright: ignore[reportUnknownMemberType]

key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,6 @@ def __init__(
236236
self._client = AsyncElasticsearch(
237237
hosts=[url], api_key=api_key, http_compress=True, request_timeout=10, retry_on_timeout=True, max_retries=3
238238
)
239-
self._client_provided_by_user = False
240239
else:
241240
msg = "Either elasticsearch_client or url must be provided"
242241
raise ValueError(msg)
@@ -553,4 +552,5 @@ async def _cull(self) -> None:
553552

554553
@override
555554
async def _close(self) -> None:
556-
await self._client.close()
555+
if not self._client_provided_by_user:
556+
await self._client.close()

key-value/key-value-aio/src/key_value/aio/stores/memcached/store.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ def __init__(
7373
self._client_provided_by_user = True
7474
else:
7575
self._client = Client(host=host, port=port)
76-
self._client_provided_by_user = False
7776

7877
super().__init__(
7978
default_collection=default_collection,
@@ -153,4 +152,5 @@ async def _delete_store(self) -> bool:
153152

154153
@override
155154
async def _close(self) -> None:
156-
await self._client.close()
155+
if not self._client_provided_by_user:
156+
await self._client.close()

key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,9 @@ def __init__(
174174
self._client_provided_by_user = True
175175
elif url:
176176
self._client = AsyncMongoClient(url)
177-
self._client_provided_by_user = False
178177
else:
179178
# Defaults to localhost
180179
self._client = AsyncMongoClient()
181-
self._client_provided_by_user = False
182180

183181
db_name = db_name or DEFAULT_DB
184182
coll_name = coll_name or DEFAULT_COLLECTION
@@ -345,4 +343,5 @@ async def _delete_collection(self, *, collection: str) -> bool:
345343

346344
@override
347345
async def _close(self) -> None:
348-
await self._client.close()
346+
if not self._client_provided_by_user:
347+
await self._client.close()

key-value/key-value-aio/src/key_value/aio/stores/redis/store.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ def __init__(
7676
password=parsed_url.password or password,
7777
decode_responses=True,
7878
)
79-
self._client_provided_by_user = False
8079
else:
8180
self._client = Redis(
8281
host=host,
@@ -85,7 +84,6 @@ def __init__(
8584
password=password,
8685
decode_responses=True,
8786
)
88-
self._client_provided_by_user = False
8987

9088
self._stable_api = True
9189
self._adapter = BasicSerializationAdapter(date_format="isoformat", value_format="dict")
@@ -222,4 +220,5 @@ async def _delete_store(self) -> bool:
222220

223221
@override
224222
async def _close(self) -> None:
225-
await self._client.aclose()
223+
if not self._client_provided_by_user:
224+
await self._client.aclose()

key-value/key-value-aio/src/key_value/aio/stores/rocksdb/store.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,15 @@ def __init__(
7676
opts.create_if_missing(True)
7777

7878
self._db = Rdict(str(path), options=opts)
79-
self._client_provided_by_user = False
8079

8180
self._is_closed = False
8281

8382
super().__init__(default_collection=default_collection)
8483

8584
@override
8685
async def _close(self) -> None:
87-
self._close_and_flush()
86+
if not self._client_provided_by_user:
87+
self._close_and_flush()
8888

8989
def _close_and_flush(self) -> None:
9090
if not self._is_closed:
@@ -190,4 +190,5 @@ async def _delete_managed_entries(self, *, keys: Sequence[str], collection: str)
190190
return deleted_count
191191

192192
def __del__(self) -> None:
193-
self._close_and_flush()
193+
if not self._client_provided_by_user:
194+
self._close_and_flush()

key-value/key-value-aio/src/key_value/aio/stores/valkey/store.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ def __init__(
7474
credentials: ServerCredentials | None = ServerCredentials(password=password, username=username) if password else None
7575
self._client_config = GlideClientConfiguration(addresses=addresses, database_id=db, credentials=credentials)
7676
self._connected_client = None
77-
self._client_provided_by_user = False
7877

7978
self._stable_api = True
8079

@@ -165,6 +164,5 @@ async def _delete_managed_entries(self, *, keys: Sequence[str], collection: str)
165164

166165
@override
167166
async def _close(self) -> None:
168-
if self._connected_client is None:
169-
return
170-
await self._client.close()
167+
if self._connected_client is not None and not self._client_provided_by_user:
168+
await self._client.close()

0 commit comments

Comments
 (0)