@@ -74,6 +74,19 @@


def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry) -> dict[str, Any]:
"""Convert a ManagedEntry to an Elasticsearch document.

This function creates an Elasticsearch document containing the collection, key,
JSON-serialized value, and optional timestamp metadata.

Args:
collection: The collection name.
key: The entry key.
managed_entry: The ManagedEntry to convert.

Returns:
An Elasticsearch document dictionary.
"""
document: dict[str, Any] = {
"collection": collection,
"key": key,
@@ -89,6 +102,20 @@ def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry) -> dict[str, Any]:


def source_to_managed_entry(source: dict[str, Any]) -> ManagedEntry:
"""Convert an Elasticsearch document source to a ManagedEntry.

This function deserializes an Elasticsearch document back to a ManagedEntry,
parsing the JSON-encoded value and timestamp metadata.

Args:
source: The Elasticsearch document _source dictionary.

Returns:
A ManagedEntry reconstructed from the document.

Raises:
DeserializationError: If the value field is missing or not a valid string.
"""
if not (value_str := source.get("value")) or not isinstance(value_str, str):
msg = "Value is not a string"
raise DeserializationError(msg)
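The two converters above are intended to round-trip. A minimal sketch, assuming a ManagedEntry(value=...) constructor and a .value attribute (neither is shown in this diff):

from key_value.shared.utils.managed_entry import ManagedEntry

entry = ManagedEntry(value={"user": "alice"})  # assumed constructor
doc = managed_entry_to_document(collection="users", key="alice", managed_entry=entry)
# doc resembles: {"collection": "users", "key": "alice", "value": '{"user": "alice"}', ...timestamps}
restored = source_to_managed_entry(source=doc)
assert restored.value == entry.value  # assumed attribute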
@@ -5,6 +5,14 @@


def get_body_from_response(response: ObjectApiResponse[Any]) -> dict[str, Any]:
"""Extract and validate the body from an Elasticsearch response.

Args:
response: The Elasticsearch API response object.

Returns:
The response body as a dictionary, or an empty dict if the body is missing or invalid.
"""
if not (body := response.body): # pyright: ignore[reportAny]
return {}

@@ -15,6 +23,14 @@ def get_body_from_response(response: ObjectApiResponse[Any]) -> dict[str, Any]:


def get_source_from_body(body: dict[str, Any]) -> dict[str, Any]:
"""Extract and validate the _source field from an Elasticsearch response body.

Args:
body: The response body dictionary from Elasticsearch.

Returns:
The _source field as a dictionary, or an empty dict if missing or invalid.
"""
if not (source := body.get("_source")):
return {}

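These guards compose into a tolerant read path. A sketch of one plausible usage; the client type, index name, and document-ID scheme are assumptions, not taken from this diff:

from typing import Any

from elasticsearch import AsyncElasticsearch


async def fetch_source(client: AsyncElasticsearch, doc_id: str) -> dict[str, Any]:
    # Missing bodies or _source fields fall through as {} instead of raising.
    response = await client.get(index="kv-store", id=doc_id)  # hypothetical index and id
    body = get_body_from_response(response=response)
    return get_source_from_body(body=body)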
@@ -25,6 +41,14 @@ def get_source_from_body(body: dict[str, Any]) -> dict[str, Any]:


def get_aggregations_from_body(body: dict[str, Any]) -> dict[str, Any]:
"""Extract and validate the aggregations field from an Elasticsearch response body.

Args:
body: The response body dictionary from Elasticsearch.

Returns:
The aggregations field as a dictionary, or an empty dict if missing or invalid.
"""
if not (aggregations := body.get("aggregations")):
return {}

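The aggregations helper follows the same defensive pattern. The nesting below is the standard Elasticsearch terms-aggregation response shape; the aggregation name and buckets are invented for illustration:

body = {
    "took": 3,
    "aggregations": {"collections": {"buckets": [{"key": "users", "doc_count": 2}]}},
}
buckets = get_aggregations_from_body(body=body).get("collections", {}).get("buckets", [])
assert [bucket["key"] for bucket in buckets] == ["users"]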
@@ -35,6 +59,17 @@ def get_aggregations_from_body(body: dict[str, Any]) -> dict[str, Any]:


def get_hits_from_response(response: ObjectApiResponse[Any]) -> list[dict[str, Any]]:
"""Extract and validate the hits array from an Elasticsearch response.

This function navigates the nested structure of Elasticsearch responses to extract
the hits.hits array which contains the actual search results.

Args:
response: The Elasticsearch API response object.

Returns:
A list of hit dictionaries from the response, or an empty list if the hits are missing or invalid.
"""
if not (body := response.body): # pyright: ignore[reportAny]
return []

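For reference, the layout this function unwraps is the standard search response, where results live under hits.hits. A hedged sketch (client, index, and query are assumptions):

from typing import Any

from elasticsearch import AsyncElasticsearch


async def list_hits(client: AsyncElasticsearch) -> list[dict[str, Any]]:
    # A search response nests results as {"hits": {"hits": [...]}}; this helper unwraps that.
    response = await client.search(index="kv-store", query={"match_all": {}})  # hypothetical
    return get_hits_from_response(response=response)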
@@ -63,6 +98,21 @@ def get_hits_from_response(response: ObjectApiResponse[Any]) -> list[dict[str, Any]]:


def get_fields_from_hit(hit: dict[str, Any]) -> dict[str, list[Any]]:
"""Extract and validate the fields from an Elasticsearch hit.

Elasticsearch can return stored fields via the "fields" key in each hit.
This function validates that the fields object exists and conforms to the
expected structure (a dict mapping field names to lists of values).

Args:
hit: A single hit dictionary from an Elasticsearch response.

Returns:
The fields dictionary from the hit, or an empty dict if missing.

Raises:
TypeError: If the fields structure is invalid (not a dict or contains non-list values).
"""
if not (fields := hit.get("fields")):
return {}
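Elasticsearch returns every stored field as a list of values, even for single-valued fields, which is why the validation insists on a dict of lists. A small self-contained example (hit contents invented for illustration):

hit = {"_id": "users::alice", "fields": {"collection": ["users"], "key": ["alice"]}}
assert get_fields_from_hit(hit=hit) == {"collection": ["users"], "key": ["alice"]}
assert get_fields_from_hit(hit={"_id": "no-fields"}) == {}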

Expand All @@ -78,6 +128,18 @@ def get_fields_from_hit(hit: dict[str, Any]) -> dict[str, list[Any]]:


def get_field_from_hit(hit: dict[str, Any], field: str) -> list[Any]:
"""Extract a specific field value from an Elasticsearch hit.

Args:
hit: A single hit dictionary from an Elasticsearch response.
field: The name of the field to extract.

Returns:
The field value as a list, or an empty list if the fields object is missing.

Raises:
TypeError: If the specified field is not present in the hit.
"""
if not (fields := get_fields_from_hit(hit=hit)):
return []

@@ -89,6 +151,19 @@ def get_field_from_hit(hit: dict[str, Any], field: str) -> list[Any]:


def get_values_from_field_in_hit(hit: dict[str, Any], field: str, value_type: type[T]) -> list[T]:
"""Extract and type-check field values from an Elasticsearch hit.

Args:
hit: A single hit dictionary from an Elasticsearch response.
field: The name of the field to extract.
value_type: The expected type of values in the field list.

Returns:
A list of values of the specified type.

Raises:
TypeError: If the field is missing or contains values of the wrong type.
"""
if not (value := get_field_from_hit(hit=hit, field=field)):
msg = f"Field {field} is not in hit {hit}"
raise TypeError(msg)
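The typed accessors build on get_fields_from_hit and get_field_from_hit. An illustrative example, with behavior as described by the docstrings above and below (field names and values invented):

hit = {"fields": {"collection": ["users"], "doc_count": [3]}}
collection = get_first_value_from_field_in_hit(hit=hit, field="collection", value_type=str)
counts = get_values_from_field_in_hit(hit=hit, field="doc_count", value_type=int)
assert collection == "users"
assert counts == [3]
# A missing field, a wrong value type, or (for the single-value helper) multiple values raises TypeError.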
@@ -101,6 +176,19 @@ def get_values_from_field_in_hit(hit: dict[str, Any], field: str, value_type: type[T]) -> list[T]:


def get_first_value_from_field_in_hit(hit: dict[str, Any], field: str, value_type: type[T]) -> T:
"""Extract and validate a single-value field from an Elasticsearch hit.

Args:
hit: A single hit dictionary from an Elasticsearch response.
field: The name of the field to extract.
value_type: The expected type of the value.

Returns:
The single value from the field.

Raises:
TypeError: If the field doesn't contain exactly one value, or if the value type is incorrect.
"""
values: list[T] = get_values_from_field_in_hit(hit=hit, field=field, value_type=value_type)
if len(values) != 1:
msg: str = f"Field {field} in hit {hit} is not a single value"
@@ -109,6 +197,19 @@ def get_first_value_from_field_in_hit(hit: dict[str, Any], field: str, value_type: type[T]) -> T:


def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry) -> dict[str, Any]:
"""Convert a ManagedEntry to an Elasticsearch document format.

This function serializes a ManagedEntry to a document suitable for storage in Elasticsearch,
including collection, key, and value fields, plus optional timestamp metadata.

Args:
collection: The collection name to include in the document.
key: The key to include in the document.
managed_entry: The ManagedEntry to serialize.

Returns:
An Elasticsearch document dictionary ready for indexing.
"""
document: dict[str, Any] = {
"collection": collection,
"key": key,
@@ -124,4 +225,14 @@ def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry) -> dict[str, Any]:


def new_bulk_action(action: str, index: str, document_id: str) -> dict[str, Any]:
"""Create a bulk action descriptor for Elasticsearch bulk operations.

Args:
action: The bulk action type (e.g., "index", "delete", "update").
index: The Elasticsearch index name.
document_id: The document ID.

Returns:
A bulk action dictionary formatted for Elasticsearch's bulk API.
"""
return {action: {"_index": index, "_id": document_id}}
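Elasticsearch's bulk payload alternates action lines with document bodies, and delete actions carry no body. A hedged sketch of how these helpers might combine; the client call, index name, and ID scheme are assumptions:

from typing import Any

from elasticsearch import AsyncElasticsearch
from key_value.shared.utils.managed_entry import ManagedEntry


async def index_one_delete_one(client: AsyncElasticsearch, entry: ManagedEntry) -> None:
    operations: list[dict[str, Any]] = [
        new_bulk_action(action="index", index="kv-store", document_id="users::alice"),
        managed_entry_to_document(collection="users", key="alice", managed_entry=entry),
        new_bulk_action(action="delete", index="kv-store", document_id="users::bob"),  # no body follows
    ]
    await client.bulk(operations=operations)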
@@ -49,6 +49,17 @@ def __init__(
super().__init__(default_collection=default_collection)

def sanitize_key(self, key: str) -> str:
"""Sanitize a key to fit Memcached's key length constraints.

Memcached enforces a maximum key length. Keys exceeding this limit are
replaced with the hex digest of their SHA-256 hash so they always fit within the constraint.

Args:
key: The key to sanitize.

Returns:
The original key if it is within the length limit, otherwise the 64-character SHA-256 hex digest of the key.
"""
if len(key) > MAX_KEY_LENGTH:
sha256_hash: str = hashlib.sha256(key.encode()).hexdigest()
return sha256_hash[:64]
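A quick sanity check of the sanitization behavior. The store class name and constructor below are assumptions; the hashing logic matches the code above:

import hashlib

store = MemcachedStore()  # hypothetical construction
assert store.sanitize_key("session:alice") == "session:alice"  # short keys pass through unchanged
long_key = "session:" + "x" * 300
assert store.sanitize_key(long_key) == hashlib.sha256(long_key.encode()).hexdigest()[:64]
# A SHA-256 hex digest is exactly 64 characters, so the [:64] slice keeps the full digest.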
100 changes: 83 additions & 17 deletions key-value/key-value-aio/src/key_value/aio/stores/memory/store.py
@@ -1,10 +1,9 @@
import sys
-from dataclasses import dataclass, field
from dataclasses import dataclass
from datetime import datetime
-from typing import Any, SupportsFloat
from typing import Any

from key_value.shared.utils.managed_entry import ManagedEntry
-from key_value.shared.utils.time_to_live import epoch_to_datetime
from typing_extensions import Self, override

from key_value.aio.stores.base import (
@@ -24,40 +23,71 @@

@dataclass
class MemoryCacheEntry:
"""A cache entry for the memory store."""
"""A cache entry for the memory store.

This dataclass represents an entry in the MemoryStore cache, storing the JSON-serialized
value along with its expiration timestamp.
"""

json_str: str

expires_at: datetime | None

-ttl_at_insert: SupportsFloat | None = field(default=None)

@classmethod
-def from_managed_entry(cls, managed_entry: ManagedEntry, ttl: SupportsFloat | None = None) -> Self:
def from_managed_entry(cls, managed_entry: ManagedEntry) -> Self:
"""Create a cache entry from a ManagedEntry.

Args:
managed_entry: The ManagedEntry to convert.

Returns:
A new MemoryCacheEntry.
"""
return cls(
json_str=managed_entry.to_json(),
expires_at=managed_entry.expires_at,
-ttl_at_insert=ttl,
)

def to_managed_entry(self) -> ManagedEntry:
"""Convert this cache entry to a ManagedEntry.

Returns:
A ManagedEntry reconstructed from the stored JSON.
"""
return ManagedEntry.from_json(json_str=self.json_str)


-def _memory_cache_ttu(_key: Any, value: MemoryCacheEntry, now: float) -> float:
-"""Calculate time-to-use for cache entries based on their TTL."""
-if value.ttl_at_insert is None:
-return float(sys.maxsize)
-
-expiration_epoch: float = now + float(value.ttl_at_insert)
-
-value.expires_at = epoch_to_datetime(epoch=expiration_epoch)
-
-return float(expiration_epoch)

def _memory_cache_ttu(_key: Any, value: MemoryCacheEntry, _now: float) -> float:
"""Calculate time-to-use for cache entries based on their expiration time.

This function is used by TLRUCache to determine when entries should expire.

Args:
_key: The cache key (unused).
value: The cache entry.
_now: The current time as an epoch timestamp (unused).

Returns:
The expiration time as an epoch timestamp, or sys.maxsize if the entry has no TTL.
"""
if value.expires_at is None:
return float(sys.maxsize)

return value.expires_at.timestamp()
Comment on lines +60 to +76

Contributor

🧹 Nitpick | 🔵 Trivial

Consider adding a noqa comment for consistency.

The function logic correctly uses expires_at.timestamp() for expiration. Both _key and _now parameters are properly prefixed with underscores to indicate they're unused.

For consistency with line 79 (which has # noqa: ARG001), consider adding # noqa: ARG001, ARG003 to suppress linting warnings for both unused parameters, though this is a minor style point.

Based on past review comments.

🤖 Prompt for AI Agents
In key-value/key-value-aio/src/key_value/aio/stores/memory/store.py around lines 60 to 76, the function _memory_cache_ttu currently prefixes unused parameters with underscores but lacks the noqa lint suppression present elsewhere; add a trailing comment "# noqa: ARG001, ARG003" on the function definition line to explicitly suppress unused-argument warnings for _key and _now, keeping the function logic unchanged.



def _memory_cache_getsizeof(value: MemoryCacheEntry) -> int: # noqa: ARG001
"""Return size of cache entry (always 1 for entry counting)."""
"""Return size of cache entry (always 1 for entry counting).

This function is used by TLRUCache to determine entry sizes. Since we want to count
entries rather than measure memory usage, this always returns 1.

Args:
value: The cache entry (unused).

Returns:
Always returns 1 to enable entry-based size limiting.
"""
return 1


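For context, this is roughly how the two hooks plug into cachetools' TLRUCache. The construction below is a sketch, not this store's actual wiring; in particular, because the ttu callback returns epoch timestamps, the cache timer is assumed to be epoch-based (time.time) rather than the cachetools default (time.monotonic):

import time

from cachetools import TLRUCache

cache: TLRUCache[str, MemoryCacheEntry] = TLRUCache(
    maxsize=10_000,  # counted in entries, because getsizeof always returns 1
    ttu=_memory_cache_ttu,  # per-entry expiry taken from expires_at
    getsizeof=_memory_cache_getsizeof,
    timer=time.time,  # assumed: ttu returns epoch timestamps, so the timer must be epoch-based
)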
@@ -68,6 +98,12 @@ def _memory_cache_getsizeof(value: MemoryCacheEntry) -> int: # noqa: ARG001


class MemoryCollection:
"""A fixed-size in-memory collection using TLRUCache.

This class wraps a time-aware LRU cache to provide TTL-based expiration
and automatic eviction of old entries when the cache reaches its size limit.
"""

_cache: TLRUCache[str, MemoryCacheEntry]

def __init__(self, max_entries: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION):
@@ -83,6 +119,14 @@ def __init__(self, max_entries: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION):
)

def get(self, key: str) -> ManagedEntry | None:
"""Retrieve an entry from the collection.

Args:
key: The key to retrieve.

Returns:
The ManagedEntry if found and not expired, None otherwise.
"""
managed_entry_str: MemoryCacheEntry | None = self._cache.get(key)

if managed_entry_str is None:
@@ -93,12 +137,34 @@ def get(self, key: str) -> ManagedEntry | None:
return managed_entry

def put(self, key: str, value: ManagedEntry) -> None:
-self._cache[key] = MemoryCacheEntry.from_managed_entry(managed_entry=value, ttl=value.ttl)
"""Store an entry in the collection.

Args:
key: The key to store under.
value: The ManagedEntry to store.
"""
self._cache[key] = MemoryCacheEntry.from_managed_entry(managed_entry=value)

def delete(self, key: str) -> bool:
"""Delete an entry from the collection.

Args:
key: The key to delete.

Returns:
True if the key was deleted, False if it didn't exist.
"""
return self._cache.pop(key, None) is not None

def keys(self, *, limit: int | None = None) -> list[str]:
"""Retrieve all keys in the collection.

Args:
limit: The maximum number of keys to return. Defaults to 10,000.

Returns:
A list of keys in the collection, limited to the specified maximum.
"""
limit = min(limit or DEFAULT_PAGE_SIZE, PAGE_LIMIT)
return list(self._cache.keys())[:limit]
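Putting the collection API together; a sketch assuming, as above, a ManagedEntry(value=...) constructor and a .value attribute:

from key_value.shared.utils.managed_entry import ManagedEntry

collection = MemoryCollection(max_entries=100)
collection.put(key="alice", value=ManagedEntry(value={"visits": 1}))  # assumed constructor
entry = collection.get(key="alice")  # None once expired or evicted
assert entry is not None
assert entry.value == {"visits": 1}  # assumed attribute
assert collection.keys(limit=10) == ["alice"]
assert collection.delete(key="alice") is True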

Expand Down
Loading