Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
c07ed4b
Initial plan
Copilot Oct 24, 2025
a47353f
Implement bulk operations for Redis, Valkey, DynamoDB, MongoDB, Memca…
Copilot Oct 24, 2025
1b32bcc
Remove RocksDB _get_managed_entries implementation (multi_get not ava…
Copilot Oct 24, 2025
808a899
Add _batch_items helper to BaseStore and refactor DynamoDB to use it
Copilot Oct 25, 2025
aba884f
Remove DynamoDB bulk operations and _batch_items helper from BaseStore
Copilot Oct 25, 2025
8f7954c
Merge branch 'main' into copilot/optimize-many-methods-on-stores
strawgate Oct 25, 2025
9d467dd
Merge branch 'main' into copilot/optimize-many-methods-on-stores
strawgate Oct 26, 2025
2101910
refactor: improve bulk operations implementation based on code review
github-actions[bot] Oct 26, 2025
7f029ca
fix: resolve static type check errors in bulk operations
github-actions[bot] Oct 26, 2025
4523b80
Merge branch 'main' into copilot/optimize-many-methods-on-stores
strawgate Oct 26, 2025
2956ae3
Merge branch 'main' into copilot/optimize-many-methods-on-stores
strawgate Oct 26, 2025
ce58f72
Add delete_many tests for missing keys
strawgate Oct 26, 2025
742999d
Merge branch 'main' into copilot/optimize-many-methods-on-stores
strawgate Oct 26, 2025
c500cf4
refactor: simplify bulk put operations for single TTL interface
github-actions[bot] Oct 26, 2025
ce673cc
refactor: simplify bulk put operations for single TTL interface
github-actions[bot] Oct 26, 2025
e4e9b03
More test updates
strawgate Oct 26, 2025
4076872
refactor: pass ttl and timestamps as parameters to _put_managed_entries
github-actions[bot] Oct 26, 2025
d7da259
Improvements to bulk requests in stores
strawgate Oct 26, 2025
07e5d76
Merge remote-tracking branch 'origin/main' into copilot/optimize-many…
strawgate Oct 26, 2025
6604651
Lint and fix
strawgate Oct 26, 2025
17b4621
Additional cleanup
strawgate Oct 26, 2025
1dac47f
lint
strawgate Oct 26, 2025
3f0dabe
More document conversion tests
strawgate Oct 26, 2025
92936fa
Merge remote-tracking branch 'origin/main' into copilot/optimize-many…
strawgate Oct 26, 2025
275c8d9
Fixes for Windows Registry store tests
strawgate Oct 26, 2025
f5a0e41
More Windows registry fixes
strawgate Oct 26, 2025
52a8879
make ttl expire test more reliable
strawgate Oct 26, 2025
eeffb17
Fix windows registry utils
strawgate Oct 26, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 47 additions & 24 deletions key-value/key-value-aio/src/key_value/aio/stores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from asyncio.locks import Lock
from collections import defaultdict
from collections.abc import Mapping, Sequence
from datetime import datetime
from types import MappingProxyType, TracebackType
from typing import Any, SupportsFloat

from key_value.shared.constants import DEFAULT_COLLECTION_NAME
from key_value.shared.errors import StoreSetupError
from key_value.shared.type_checking.bear_spray import bear_enforce
from key_value.shared.utils.managed_entry import ManagedEntry
from key_value.shared.utils.time_to_live import now, prepare_ttl
from key_value.shared.utils.time_to_live import prepare_entry_timestamps
from typing_extensions import Self, override

from key_value.aio.protocols.key_value import (
Expand Down Expand Up @@ -207,13 +208,36 @@ async def ttl_many(
return [(dict(entry.value), entry.ttl) if entry and not entry.is_expired else (None, None) for entry in entries]

@abstractmethod
async def _put_managed_entry(self, *, collection: str, key: str, managed_entry: ManagedEntry) -> None:
async def _put_managed_entry(
self,
*,
collection: str,
key: str,
managed_entry: ManagedEntry,
) -> None:
"""Store a managed entry by key in the specified collection."""
...

async def _put_managed_entries(self, *, collection: str, keys: Sequence[str], managed_entries: Sequence[ManagedEntry]) -> None:
"""Store multiple managed entries by key in the specified collection."""
async def _put_managed_entries(
self,
*,
collection: str,
keys: Sequence[str],
managed_entries: Sequence[ManagedEntry],
ttl: float | None, # noqa: ARG002
created_at: datetime, # noqa: ARG002
expires_at: datetime | None, # noqa: ARG002
) -> None:
"""Store multiple managed entries by key in the specified collection.

Args:
collection: The collection to store entries in
keys: The keys for the entries
managed_entries: The managed entries to store
ttl: The TTL in seconds (None for no expiration)
created_at: The creation timestamp for all entries
expires_at: The expiration timestamp for all entries (None if no TTL)
"""
for key, managed_entry in zip(keys, managed_entries, strict=True):
await self._put_managed_entry(
collection=collection,
Expand All @@ -228,30 +252,16 @@ async def put(self, key: str, value: Mapping[str, Any], *, collection: str | Non
collection = collection or self.default_collection
await self.setup_collection(collection=collection)

managed_entry: ManagedEntry = ManagedEntry(value=value, ttl=prepare_ttl(t=ttl), created_at=now())
created_at, ttl_seconds, expires_at = prepare_entry_timestamps(ttl=ttl)

managed_entry: ManagedEntry = ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at)

await self._put_managed_entry(
collection=collection,
key=key,
managed_entry=managed_entry,
)

def _prepare_put_many(
self, *, keys: Sequence[str], values: Sequence[Mapping[str, Any]], ttl: SupportsFloat | None
) -> tuple[Sequence[str], Sequence[Mapping[str, Any]], float | None]:
"""Prepare multiple managed entries for a put_many operation.

Inheriting classes can use this method if they need to modify a put_many operation."""

if len(keys) != len(values):
msg = "put_many called but a different number of keys and values were provided"
raise ValueError(msg) from None

ttl_for_entries: float | None = prepare_ttl(t=ttl)

return (keys, values, ttl_for_entries)

@bear_enforce
@override
async def put_many(
self,
Expand All @@ -266,11 +276,24 @@ async def put_many(
collection = collection or self.default_collection
await self.setup_collection(collection=collection)

keys, values, ttl_for_entries = self._prepare_put_many(keys=keys, values=values, ttl=ttl)
if len(keys) != len(values):
msg = "put_many called but a different number of keys and values were provided"
raise ValueError(msg) from None

managed_entries: list[ManagedEntry] = [ManagedEntry(value=value, ttl=ttl_for_entries, created_at=now()) for value in values]
created_at, ttl_seconds, expires_at = prepare_entry_timestamps(ttl=ttl)

await self._put_managed_entries(collection=collection, keys=keys, managed_entries=managed_entries)
managed_entries: list[ManagedEntry] = [
ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at) for value in values
]

await self._put_managed_entries(
collection=collection,
keys=keys,
managed_entries=managed_entries,
ttl=ttl_seconds,
created_at=created_at,
expires_at=expires_at,
)

@abstractmethod
async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
Expand Down
161 changes: 140 additions & 21 deletions key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from datetime import datetime # noqa: TC003
from collections.abc import Sequence
from datetime import datetime
from typing import Any, overload

from elastic_transport import ObjectApiResponse # noqa: TC002
from key_value.shared.utils.compound import compound_key
from key_value.shared.errors import DeserializationError
from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json
from key_value.shared.utils.sanitize import (
ALPHANUMERIC_CHARACTERS,
Expand All @@ -21,6 +22,7 @@
BaseEnumerateKeysStore,
BaseStore,
)
from key_value.aio.stores.elasticsearch.utils import new_bulk_action

try:
from elasticsearch import AsyncElasticsearch
Expand Down Expand Up @@ -71,6 +73,36 @@
ALLOWED_INDEX_CHARACTERS: str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "."


def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry) -> dict[str, Any]:
document: dict[str, Any] = {
"collection": collection,
"key": key,
"value": managed_entry.to_json(include_metadata=False),
}

if managed_entry.created_at:
document["created_at"] = managed_entry.created_at.isoformat()
if managed_entry.expires_at:
document["expires_at"] = managed_entry.expires_at.isoformat()

return document


def source_to_managed_entry(source: dict[str, Any]) -> ManagedEntry:
if not (value_str := source.get("value")) or not isinstance(value_str, str):
msg = "Value is not a string"
raise DeserializationError(msg)

created_at: datetime | None = try_parse_datetime_str(value=source.get("created_at"))
expires_at: datetime | None = try_parse_datetime_str(value=source.get("expires_at"))

return ManagedEntry(
value=load_from_json(value_str),
created_at=created_at,
expires_at=expires_at,
)


class ElasticsearchStore(
BaseEnumerateCollectionsStore, BaseEnumerateKeysStore, BaseDestroyCollectionStore, BaseCullStore, BaseContextManagerStore, BaseStore
):
Expand Down Expand Up @@ -156,13 +188,17 @@ def _sanitize_document_id(self, key: str) -> str:
allowed_characters=ALLOWED_KEY_CHARACTERS,
)

def _get_destination(self, *, collection: str, key: str) -> tuple[str, str]:
index_name: str = self._sanitize_index_name(collection=collection)
document_id: str = self._sanitize_document_id(key=key)

return index_name, document_id

@override
async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:
combo_key: str = compound_key(collection=collection, key=key)
index_name, document_id = self._get_destination(collection=collection, key=key)

elasticsearch_response = await self._client.options(ignore_status=404).get(
index=self._sanitize_index_name(collection=collection), id=self._sanitize_document_id(key=combo_key)
)
elasticsearch_response = await self._client.options(ignore_status=404).get(index=index_name, id=document_id)

body: dict[str, Any] = get_body_from_response(response=elasticsearch_response)

Expand All @@ -181,6 +217,39 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
expires_at=expires_at,
)

@override
async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ManagedEntry | None]:
if not keys:
return []

# Use mget for efficient batch retrieval
index_name = self._sanitize_index_name(collection=collection)
document_ids = [self._sanitize_document_id(key=key) for key in keys]
docs = [{"_id": document_id} for document_id in document_ids]

elasticsearch_response = await self._client.options(ignore_status=404).mget(index=index_name, docs=docs)

body: dict[str, Any] = get_body_from_response(response=elasticsearch_response)
docs_result = body.get("docs", [])

entries_by_id: dict[str, ManagedEntry | None] = {}
for doc in docs_result:
if not (doc_id := doc.get("_id")):
continue

if "found" not in doc:
entries_by_id[doc_id] = None
continue

if not (source := doc.get("_source")):
entries_by_id[doc_id] = None
continue

entries_by_id[doc_id] = source_to_managed_entry(source=source)

# Return entries in the same order as input keys
return [entries_by_id.get(document_id) for document_id in document_ids]

@property
def _should_refresh_on_put(self) -> bool:
return not self._is_serverless
Expand All @@ -193,32 +262,54 @@ async def _put_managed_entry(
collection: str,
managed_entry: ManagedEntry,
) -> None:
combo_key: str = compound_key(collection=collection, key=key)
index_name: str = self._sanitize_index_name(collection=collection)
document_id: str = self._sanitize_document_id(key=key)

document: dict[str, Any] = {
"collection": collection,
"key": key,
"value": managed_entry.to_json(include_metadata=False),
}

if managed_entry.created_at:
document["created_at"] = managed_entry.created_at.isoformat()
if managed_entry.expires_at:
document["expires_at"] = managed_entry.expires_at.isoformat()
document: dict[str, Any] = managed_entry_to_document(collection=collection, key=key, managed_entry=managed_entry)

_ = await self._client.index(
index=self._sanitize_index_name(collection=collection),
id=self._sanitize_document_id(key=combo_key),
index=index_name,
id=document_id,
body=document,
refresh=self._should_refresh_on_put,
)

@override
async def _put_managed_entries(
self,
*,
collection: str,
keys: Sequence[str],
managed_entries: Sequence[ManagedEntry],
ttl: float | None,
created_at: datetime,
expires_at: datetime | None,
) -> None:
if not keys:
return

operations: list[dict[str, Any]] = []

index_name: str = self._sanitize_index_name(collection=collection)

for key, managed_entry in zip(keys, managed_entries, strict=True):
document_id: str = self._sanitize_document_id(key=key)

index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id)

document: dict[str, Any] = managed_entry_to_document(collection=collection, key=key, managed_entry=managed_entry)

operations.extend([index_action, document])

_ = await self._client.bulk(operations=operations, refresh=self._should_refresh_on_put) # pyright: ignore[reportUnknownMemberType]

@override
async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
combo_key: str = compound_key(collection=collection, key=key)
index_name: str = self._sanitize_index_name(collection=collection)
document_id: str = self._sanitize_document_id(key=key)

elasticsearch_response: ObjectApiResponse[Any] = await self._client.options(ignore_status=404).delete(
index=self._sanitize_index_name(collection=collection), id=self._sanitize_document_id(key=combo_key)
index=index_name, id=document_id
)

body: dict[str, Any] = get_body_from_response(response=elasticsearch_response)
Expand All @@ -228,6 +319,34 @@ async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:

return result == "deleted"

@override
async def _delete_managed_entries(self, *, keys: Sequence[str], collection: str) -> int:
if not keys:
return 0

operations: list[dict[str, Any]] = []

for key in keys:
index_name, document_id = self._get_destination(collection=collection, key=key)

delete_action: dict[str, Any] = new_bulk_action(action="delete", index=index_name, document_id=document_id)

operations.append(delete_action)

elasticsearch_response = await self._client.bulk(operations=operations) # pyright: ignore[reportUnknownMemberType]

body: dict[str, Any] = get_body_from_response(response=elasticsearch_response)

# Count successful deletions
deleted_count = 0
items = body.get("items", [])
for item in items:
delete_result = item.get("delete", {})
if delete_result.get("result") == "deleted":
deleted_count += 1

return deleted_count

@override
async def _get_collection_keys(self, *, collection: str, limit: int | None = None) -> list[str]:
"""Get up to 10,000 keys in the specified collection (eventually consistent)."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, TypeVar, cast

from elastic_transport import ObjectApiResponse
from key_value.shared.utils.managed_entry import ManagedEntry


def get_body_from_response(response: ObjectApiResponse[Any]) -> dict[str, Any]:
Expand Down Expand Up @@ -105,3 +106,22 @@ def get_first_value_from_field_in_hit(hit: dict[str, Any], field: str, value_typ
msg: str = f"Field {field} in hit {hit} is not a single value"
raise TypeError(msg)
return values[0]


def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry) -> dict[str, Any]:
document: dict[str, Any] = {
"collection": collection,
"key": key,
"value": managed_entry.to_json(include_metadata=False),
}

if managed_entry.created_at:
document["created_at"] = managed_entry.created_at.isoformat()
if managed_entry.expires_at:
document["expires_at"] = managed_entry.expires_at.isoformat()

return document


def new_bulk_action(action: str, index: str, document_id: str) -> dict[str, Any]:
return {action: {"_index": index, "_id": document_id}}
Loading
Loading