Skip to content

Commit ac651f4

Browse files
Refactor ManagedEntry serialization with adapter pattern (#184)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: William Easton <[email protected]>
1 parent 17d4dfe commit ac651f4

File tree

62 files changed

+1160
-780
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+1160
-780
lines changed

key-value/key-value-aio/src/key_value/aio/stores/base.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from key_value.shared.errors import StoreSetupError
1515
from key_value.shared.type_checking.bear_spray import bear_enforce
1616
from key_value.shared.utils.managed_entry import ManagedEntry
17+
from key_value.shared.utils.serialization import BasicSerializationAdapter, SerializationAdapter
1718
from key_value.shared.utils.time_to_live import prepare_entry_timestamps
1819
from typing_extensions import Self, override
1920

@@ -67,14 +68,23 @@ class BaseStore(AsyncKeyValueProtocol, ABC):
6768
_setup_collection_locks: defaultdict[str, Lock]
6869
_setup_collection_complete: defaultdict[str, bool]
6970

71+
_serialization_adapter: SerializationAdapter
72+
7073
_seed: FROZEN_SEED_DATA_TYPE
7174

7275
default_collection: str
7376

74-
def __init__(self, *, default_collection: str | None = None, seed: SEED_DATA_TYPE | None = None) -> None:
77+
def __init__(
78+
self,
79+
*,
80+
serialization_adapter: SerializationAdapter | None = None,
81+
default_collection: str | None = None,
82+
seed: SEED_DATA_TYPE | None = None,
83+
) -> None:
7584
"""Initialize the managed key-value store.
7685
7786
Args:
87+
serialization_adapter: The serialization adapter to use for the store.
7888
default_collection: The default collection to use if no collection is provided.
7989
Defaults to "default_collection".
8090
seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}.
@@ -91,6 +101,8 @@ def __init__(self, *, default_collection: str | None = None, seed: SEED_DATA_TYP
91101

92102
self.default_collection = default_collection or DEFAULT_COLLECTION_NAME
93103

104+
self._serialization_adapter = serialization_adapter or BasicSerializationAdapter()
105+
94106
if not hasattr(self, "_stable_api"):
95107
self._stable_api = False
96108

@@ -286,9 +298,9 @@ async def put(self, key: str, value: Mapping[str, Any], *, collection: str | Non
286298
collection = collection or self.default_collection
287299
await self.setup_collection(collection=collection)
288300

289-
created_at, ttl_seconds, expires_at = prepare_entry_timestamps(ttl=ttl)
301+
created_at, _, expires_at = prepare_entry_timestamps(ttl=ttl)
290302

291-
managed_entry: ManagedEntry = ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at)
303+
managed_entry: ManagedEntry = ManagedEntry(value=value, created_at=created_at, expires_at=expires_at)
292304

293305
await self._put_managed_entry(
294306
collection=collection,
@@ -316,9 +328,7 @@ async def put_many(
316328

317329
created_at, ttl_seconds, expires_at = prepare_entry_timestamps(ttl=ttl)
318330

319-
managed_entries: list[ManagedEntry] = [
320-
ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at) for value in values
321-
]
331+
managed_entries: list[ManagedEntry] = [ManagedEntry(value=value, created_at=created_at, expires_at=expires_at) for value in values]
322332

323333
await self._put_managed_entries(
324334
collection=collection,

key-value/key-value-aio/src/key_value/aio/stores/disk/multi_store.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
import time
21
from collections.abc import Callable
2+
from datetime import timezone
33
from pathlib import Path
44
from typing import overload
55

6-
from key_value.shared.utils.compound import compound_key
7-
from key_value.shared.utils.managed_entry import ManagedEntry
6+
from key_value.shared.utils.managed_entry import ManagedEntry, datetime
7+
from key_value.shared.utils.serialization import BasicSerializationAdapter
88
from typing_extensions import override
99

1010
from key_value.aio.stores.base import BaseContextManagerStore, BaseStore
@@ -100,6 +100,7 @@ def default_disk_cache_factory(collection: str) -> Cache:
100100
self._cache = {}
101101

102102
self._stable_api = True
103+
self._serialization_adapter = BasicSerializationAdapter()
103104

104105
super().__init__(default_collection=default_collection)
105106

@@ -109,18 +110,17 @@ async def _setup_collection(self, *, collection: str) -> None:
109110

110111
@override
111112
async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:
112-
combo_key: str = compound_key(collection=collection, key=key)
113-
114113
expire_epoch: float
115114

116-
managed_entry_str, expire_epoch = self._cache[collection].get(key=combo_key, expire_time=True) # pyright: ignore[reportAny]
115+
managed_entry_str, expire_epoch = self._cache[collection].get(key=key, expire_time=True) # pyright: ignore[reportAny]
117116

118117
if not isinstance(managed_entry_str, str):
119118
return None
120119

121-
ttl = (expire_epoch - time.time()) if expire_epoch else None
120+
managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str)
122121

123-
managed_entry: ManagedEntry = ManagedEntry.from_json(json_str=managed_entry_str, ttl=ttl)
122+
if expire_epoch:
123+
managed_entry.expires_at = datetime.fromtimestamp(expire_epoch, tz=timezone.utc)
124124

125125
return managed_entry
126126

@@ -132,15 +132,11 @@ async def _put_managed_entry(
132132
collection: str,
133133
managed_entry: ManagedEntry,
134134
) -> None:
135-
combo_key: str = compound_key(collection=collection, key=key)
136-
137-
_ = self._cache[collection].set(key=combo_key, value=managed_entry.to_json(include_expiration=False), expire=managed_entry.ttl)
135+
_ = self._cache[collection].set(key=key, value=self._serialization_adapter.dump_json(entry=managed_entry), expire=managed_entry.ttl)
138136

139137
@override
140138
async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
141-
combo_key: str = compound_key(collection=collection, key=key)
142-
143-
return self._cache[collection].delete(key=combo_key, retry=True)
139+
return self._cache[collection].delete(key=key, retry=True)
144140

145141
def _sync_close(self) -> None:
146142
for cache in self._cache.values():

key-value/key-value-aio/src/key_value/aio/stores/disk/store.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import time
1+
from datetime import datetime, timezone
22
from pathlib import Path
33
from typing import overload
44

@@ -90,9 +90,10 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
9090
if not isinstance(managed_entry_str, str):
9191
return None
9292

93-
ttl = (expire_epoch - time.time()) if expire_epoch else None
93+
managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str)
9494

95-
managed_entry: ManagedEntry = ManagedEntry.from_json(json_str=managed_entry_str, ttl=ttl)
95+
if expire_epoch:
96+
managed_entry.expires_at = datetime.fromtimestamp(expire_epoch, tz=timezone.utc)
9697

9798
return managed_entry
9899

@@ -106,7 +107,7 @@ async def _put_managed_entry(
106107
) -> None:
107108
combo_key: str = compound_key(collection=collection, key=key)
108109

109-
_ = self._cache.set(key=combo_key, value=managed_entry.to_json(include_expiration=False), expire=managed_entry.ttl)
110+
_ = self._cache.set(key=combo_key, value=self._serialization_adapter.dump_json(entry=managed_entry), expire=managed_entry.ttl)
110111

111112
@override
112113
async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:

key-value/key-value-aio/src/key_value/aio/stores/dynamodb/store.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import datetime, timezone
12
from types import TracebackType
23
from typing import TYPE_CHECKING, Any, overload
34

@@ -183,23 +184,31 @@ async def _setup(self) -> None:
183184
@override
184185
async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:
185186
"""Retrieve a managed entry from DynamoDB."""
186-
response = await self._connected_client.get_item( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
187+
response = await self._connected_client.get_item(
187188
TableName=self._table_name,
188189
Key={
189190
"collection": {"S": collection},
190191
"key": {"S": key},
191192
},
192193
)
193194

194-
item = response.get("Item") # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
195+
item = response.get("Item")
195196
if not item:
196197
return None
197198

198-
json_value = item.get("value", {}).get("S") # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
199+
json_value = item.get("value", {}).get("S")
199200
if not json_value:
200201
return None
201202

202-
return ManagedEntry.from_json(json_str=json_value) # pyright: ignore[reportUnknownArgumentType]
203+
managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=json_value)
204+
205+
expires_at_epoch = item.get("ttl", {}).get("N")
206+
207+
# Our managed entry may carry a TTL, but the TTL in DynamoDB takes precedence.
208+
if expires_at_epoch:
209+
managed_entry.expires_at = datetime.fromtimestamp(int(expires_at_epoch), tz=timezone.utc)
210+
211+
return managed_entry
203212

204213
@override
205214
async def _put_managed_entry(
@@ -210,7 +219,7 @@ async def _put_managed_entry(
210219
managed_entry: ManagedEntry,
211220
) -> None:
212221
"""Store a managed entry in DynamoDB."""
213-
json_value = managed_entry.to_json()
222+
json_value = self._serialization_adapter.dump_json(entry=managed_entry)
214223

215224
item: dict[str, Any] = {
216225
"collection": {"S": collection},
@@ -219,9 +228,9 @@ async def _put_managed_entry(
219228
}
220229

221230
# Add TTL if present
222-
if managed_entry.ttl is not None and managed_entry.created_at is not None:
231+
if managed_entry.expires_at is not None:
223232
# DynamoDB TTL expects a Unix timestamp
224-
ttl_timestamp = int(managed_entry.created_at.timestamp() + managed_entry.ttl)
233+
ttl_timestamp = int(managed_entry.expires_at.timestamp())
225234
item["ttl"] = {"N": str(ttl_timestamp)}
226235

227236
await self._connected_client.put_item( # pyright: ignore[reportUnknownMemberType]

key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py

Lines changed: 45 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
from elastic_transport import ObjectApiResponse
77
from elastic_transport import SerializationError as ElasticsearchSerializationError
88
from key_value.shared.errors import DeserializationError, SerializationError
9-
from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json, verify_dict
9+
from key_value.shared.utils.managed_entry import ManagedEntry
1010
from key_value.shared.utils.sanitize import (
1111
ALPHANUMERIC_CHARACTERS,
1212
LOWERCASE_ALPHABET,
1313
NUMBERS,
1414
sanitize_string,
1515
)
16-
from key_value.shared.utils.time_to_live import now_as_epoch, try_parse_datetime_str
16+
from key_value.shared.utils.serialization import SerializationAdapter
17+
from key_value.shared.utils.time_to_live import now_as_epoch
1718
from typing_extensions import override
1819

1920
from key_value.aio.stores.base import (
@@ -84,52 +85,50 @@
8485
ALLOWED_INDEX_CHARACTERS: str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "."
8586

8687

87-
def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry, *, native_storage: bool = False) -> dict[str, Any]:
88-
document: dict[str, Any] = {"collection": collection, "key": key, "value": {}}
88+
class ElasticsearchSerializationAdapter(SerializationAdapter):
89+
"""Adapter for Elasticsearch with support for native and string storage modes."""
8990

90-
# Store in appropriate field based on mode
91-
if native_storage:
92-
document["value"]["flattened"] = managed_entry.value_as_dict
93-
else:
94-
document["value"]["string"] = managed_entry.value_as_json
91+
_native_storage: bool
9592

96-
if managed_entry.created_at:
97-
document["created_at"] = managed_entry.created_at.isoformat()
98-
if managed_entry.expires_at:
99-
document["expires_at"] = managed_entry.expires_at.isoformat()
93+
def __init__(self, *, native_storage: bool = True) -> None:
94+
"""Initialize the Elasticsearch adapter.
10095
101-
return document
96+
Args:
97+
native_storage: If True (default), store values as flattened dicts.
98+
If False, store values as JSON strings.
99+
"""
100+
super().__init__()
102101

102+
self._native_storage = native_storage
103+
self._date_format = "isoformat"
104+
self._value_format = "dict" if native_storage else "string"
103105

104-
def source_to_managed_entry(source: dict[str, Any]) -> ManagedEntry:
105-
value: dict[str, Any] = {}
106+
@override
107+
def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]:
108+
value = data.pop("value")
109+
110+
data["value"] = {}
106111

107-
raw_value = source.get("value")
112+
if self._native_storage:
113+
data["value"]["flattened"] = value
114+
else:
115+
data["value"]["string"] = value
108116

109-
# Try flattened field first, fall back to string field
110-
if not raw_value or not isinstance(raw_value, dict):
111-
msg = "Value field not found or invalid type"
112-
raise DeserializationError(msg)
117+
return data
113118

114-
if value_flattened := raw_value.get("flattened"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
115-
value = verify_dict(obj=value_flattened)
116-
elif value_str := raw_value.get("string"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
117-
if not isinstance(value_str, str):
118-
msg = "Value in `value` field is not a string"
119-
raise DeserializationError(msg)
120-
value = load_from_json(value_str)
121-
else:
122-
msg = "Value field not found or invalid type"
123-
raise DeserializationError(msg)
119+
@override
120+
def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]:
121+
value = data.pop("value")
124122

125-
created_at: datetime | None = try_parse_datetime_str(value=source.get("created_at"))
126-
expires_at: datetime | None = try_parse_datetime_str(value=source.get("expires_at"))
123+
if "flattened" in value:
124+
data["value"] = value["flattened"]
125+
elif "string" in value:
126+
data["value"] = value["string"]
127+
else:
128+
msg = "Value field not found in Elasticsearch document"
129+
raise DeserializationError(message=msg)
127130

128-
return ManagedEntry(
129-
value=value,
130-
created_at=created_at,
131-
expires_at=expires_at,
132-
)
131+
return data
133132

134133

135134
class ElasticsearchStore(
@@ -145,6 +144,8 @@ class ElasticsearchStore(
145144

146145
_native_storage: bool
147146

147+
_adapter: SerializationAdapter
148+
148149
@overload
149150
def __init__(
150151
self,
@@ -208,6 +209,7 @@ def __init__(
208209
self._index_prefix = index_prefix
209210
self._native_storage = native_storage
210211
self._is_serverless = False
212+
self._adapter = ElasticsearchSerializationAdapter(native_storage=native_storage)
211213

212214
super().__init__(default_collection=default_collection)
213215

@@ -260,7 +262,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
260262
return None
261263

262264
try:
263-
return source_to_managed_entry(source=source)
265+
return self._adapter.load_dict(data=source)
264266
except DeserializationError:
265267
return None
266268

@@ -293,7 +295,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
293295
continue
294296

295297
try:
296-
entries_by_id[doc_id] = source_to_managed_entry(source=source)
298+
entries_by_id[doc_id] = self._adapter.load_dict(data=source)
297299
except DeserializationError as e:
298300
logger.error(
299301
"Failed to deserialize Elasticsearch document in batch operation",
@@ -324,9 +326,7 @@ async def _put_managed_entry(
324326
index_name: str = self._sanitize_index_name(collection=collection)
325327
document_id: str = self._sanitize_document_id(key=key)
326328

327-
document: dict[str, Any] = managed_entry_to_document(
328-
collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage
329-
)
329+
document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry)
330330

331331
try:
332332
_ = await self._client.index(
@@ -364,11 +364,10 @@ async def _put_managed_entries(
364364

365365
index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id)
366366

367-
document: dict[str, Any] = managed_entry_to_document(
368-
collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage
369-
)
367+
document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry)
370368

371369
operations.extend([index_action, document])
370+
372371
try:
373372
_ = await self._client.bulk(operations=operations, refresh=self._should_refresh_on_put) # pyright: ignore[reportUnknownMemberType]
374373
except ElasticsearchSerializationError as e:

0 commit comments

Comments
 (0)