Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ async def _put_managed_entry(
}

# Add TTL if present
if managed_entry.ttl is not None and managed_entry.created_at:
if managed_entry.ttl is not None and managed_entry.created_at is not None:
# DynamoDB TTL expects a Unix timestamp
ttl_timestamp = int(managed_entry.created_at.timestamp() + managed_entry.ttl)
item["ttl"] = {"N": str(ttl_timestamp)}
Expand Down
30 changes: 22 additions & 8 deletions key-value/key-value-aio/src/key_value/aio/stores/memory/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,27 @@ async def _setup_collection(self, *, collection: str) -> None:
collection_cache = MemoryCollection(max_entries=self.max_entries_per_collection)
self._cache[collection] = collection_cache

def _get_collection_or_raise(self, collection: str) -> MemoryCollection:
    """Look up a previously set-up collection, failing loudly if it is absent.

    Args:
        collection: The collection name.

    Returns:
        The MemoryCollection instance.

    Raises:
        KeyError: If the collection has not been setup via setup_collection().
    """
    if (cached := self._cache.get(collection)) is not None:
        return cached
    msg = f"Collection '{collection}' has not been setup. Call setup_collection() first."
    raise KeyError(msg)

@override
async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:
    """Return the managed entry stored under *key*, or None if the key is absent.

    Raises:
        KeyError: If *collection* has not been setup (raised by the lookup helper).
    """
    # Resolve the collection via the shared helper so a missing collection
    # fails with a descriptive KeyError. The previous direct
    # `self._cache[collection]` lookup was dead code once the helper call was
    # added, and it raised a bare KeyError first, making the helper's
    # descriptive message unreachable.
    collection_cache = self._get_collection_or_raise(collection)
    return collection_cache.get(key=key)

@override
Expand All @@ -167,20 +184,17 @@ async def _put_managed_entry(
collection: str,
managed_entry: ManagedEntry,
) -> None:
collection_cache: MemoryCollection = self._cache[collection]

collection_cache = self._get_collection_or_raise(collection)
collection_cache.put(key=key, value=managed_entry)

@override
async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
    """Delete the entry stored under *key*.

    Returns:
        The bool reported by MemoryCollection.delete — presumably whether an
        entry existed and was removed (confirm against MemoryCollection).

    Raises:
        KeyError: If *collection* has not been setup (raised by the lookup helper).
    """
    # The redundant `self._cache[collection]` lookup that preceded the helper
    # call was dead code and raised a bare, message-less KeyError; resolve the
    # collection once through the shared helper instead.
    collection_cache = self._get_collection_or_raise(collection)
    return collection_cache.delete(key=key)

@override
async def _get_collection_keys(self, *, collection: str, limit: int | None = None) -> list[str]:
    """List the keys currently stored in *collection*.

    Args:
        collection: The collection name.
        limit: Maximum number of keys to return; None means no limit
            (limit handling is delegated to MemoryCollection.keys).

    Raises:
        KeyError: If *collection* has not been setup (raised by the lookup helper).
    """
    # Use the shared helper for the collection lookup; the leftover direct
    # `self._cache[collection]` access was dead code that bypassed the
    # helper's descriptive error message.
    collection_cache = self._get_collection_or_raise(collection)
    return collection_cache.keys(limit=limit)

@override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def __init__(
self._default_value_json = dump_to_json(obj=dict(default_value))
self._default_ttl = None if default_ttl is None else float(default_ttl)

super().__init__()

def _new_default_value(self) -> dict[str, Any]:
    """Return a fresh copy of the configured default value.

    Re-parses the JSON snapshot captured in __init__ on every call, so each
    caller presumably receives an independent dict that can be mutated without
    affecting other callers — confirm load_from_json returns a new object.
    """
    return load_from_json(json_str=self._default_value_json)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _encrypt_value(self, value: dict[str, Any]) -> dict[str, Any]:
json_str: str = json.dumps(value, separators=(",", ":"))

json_bytes: bytes = json_str.encode(encoding="utf-8")
except (json.JSONDecodeError, TypeError) as e:
except TypeError as e:
msg: str = f"Failed to serialize object to JSON: {e}"
raise SerializationError(msg) from e

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,5 @@ async def put_many(
filtered_keys.append(k)
filtered_values.append(v)

await self.key_value.put_many(keys=filtered_keys, values=filtered_values, collection=collection, ttl=ttl)
if filtered_keys:
await self.key_value.put_many(keys=filtered_keys, values=filtered_values, collection=collection, ttl=ttl)
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def __init__(
self._routing_function = routing_function
self._default_store = default_store

super().__init__()

def _get_store(self, collection: str | None) -> AsyncKeyValue:
"""Get the appropriate store for the given collection.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ def statistics(self) -> KVStoreStatistics:
async def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None:
collection = collection or DEFAULT_COLLECTION_NAME

if value := await self.key_value.get(collection=collection, key=key):
value = await self.key_value.get(collection=collection, key=key)

if value is not None:
self.statistics.get_collection(collection=collection).get.increment_hit()
return value

Expand All @@ -127,7 +129,7 @@ async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[st

value, ttl = await self.key_value.ttl(collection=collection, key=key)

if value:
if value is not None:
self.statistics.get_collection(collection=collection).ttl.increment_hit()
return value, ttl

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@


class TTLClampWrapper(BaseWrapper):
"""Wrapper that enforces a maximum TTL for puts into the store."""
"""Wrapper that enforces a maximum TTL for puts into the store.

This wrapper only modifies write operations (put, put_many). All read operations
(get, get_many, ttl, ttl_many, delete, delete_many) pass through unchanged to
the underlying store.
"""

def __init__(
self, key_value: AsyncKeyValue, min_ttl: SupportsFloat, max_ttl: SupportsFloat, missing_ttl: SupportsFloat | None = None
Expand Down