Skip to content

Commit 714b0cb

Browse files
committed
Refactor serialization
1 parent 49cf605 commit 714b0cb

File tree

17 files changed: +474 additions, -695 deletions (lines changed)

17 files changed: +474 additions, -695 deletions (lines changed)

key-value/key-value-aio/src/key_value/aio/stores/base.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
from key_value.shared.errors import StoreSetupError
1515
from key_value.shared.type_checking.bear_spray import bear_enforce
1616
from key_value.shared.utils.managed_entry import ManagedEntry
17+
from key_value.shared.utils.serialization import BasicSerializationAdapter, SerializationAdapter
1718
from key_value.shared.utils.time_to_live import prepare_entry_timestamps
1819
from typing_extensions import Self, override
1920

@@ -67,6 +68,8 @@ class BaseStore(AsyncKeyValueProtocol, ABC):
6768
_setup_collection_locks: defaultdict[str, Lock]
6869
_setup_collection_complete: defaultdict[str, bool]
6970

71+
_serialization_adapter: SerializationAdapter
72+
7073
_seed: FROZEN_SEED_DATA_TYPE
7174

7275
default_collection: str
@@ -91,6 +94,8 @@ def __init__(self, *, default_collection: str | None = None, seed: SEED_DATA_TYP
9194

9295
self.default_collection = default_collection or DEFAULT_COLLECTION_NAME
9396

97+
self._serialization_adapter = BasicSerializationAdapter()
98+
9499
if not hasattr(self, "_stable_api"):
95100
self._stable_api = False
96101

@@ -286,9 +291,9 @@ async def put(self, key: str, value: Mapping[str, Any], *, collection: str | Non
286291
collection = collection or self.default_collection
287292
await self.setup_collection(collection=collection)
288293

289-
created_at, ttl_seconds, expires_at = prepare_entry_timestamps(ttl=ttl)
294+
created_at, _, expires_at = prepare_entry_timestamps(ttl=ttl)
290295

291-
managed_entry: ManagedEntry = ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at)
296+
managed_entry: ManagedEntry = ManagedEntry(value=value, created_at=created_at, expires_at=expires_at)
292297

293298
await self._put_managed_entry(
294299
collection=collection,
@@ -316,9 +321,7 @@ async def put_many(
316321

317322
created_at, ttl_seconds, expires_at = prepare_entry_timestamps(ttl=ttl)
318323

319-
managed_entries: list[ManagedEntry] = [
320-
ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at) for value in values
321-
]
324+
managed_entries: list[ManagedEntry] = [ManagedEntry(value=value, created_at=created_at, expires_at=expires_at) for value in values]
322325

323326
await self._put_managed_entries(
324327
collection=collection,

key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py

Lines changed: 34 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
from elastic_transport import ObjectApiResponse
77
from elastic_transport import SerializationError as ElasticsearchSerializationError
88
from key_value.shared.errors import DeserializationError, SerializationError
9-
from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json, verify_dict
9+
from key_value.shared.utils.managed_entry import ManagedEntry
1010
from key_value.shared.utils.sanitize import (
1111
ALPHANUMERIC_CHARACTERS,
1212
LOWERCASE_ALPHABET,
1313
NUMBERS,
1414
sanitize_string,
1515
)
1616
from key_value.shared.utils.serialization import SerializationAdapter
17-
from key_value.shared.utils.time_to_live import now_as_epoch, try_parse_datetime_str
17+
from key_value.shared.utils.time_to_live import now_as_epoch
1818
from typing_extensions import override
1919

2020
from key_value.aio.stores.base import (
@@ -85,103 +85,50 @@
8585
ALLOWED_INDEX_CHARACTERS: str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "."
8686

8787

88-
class ElasticsearchAdapter(SerializationAdapter):
89-
"""Adapter for Elasticsearch with support for native and string storage modes.
88+
class ElasticsearchSerializationAdapter(SerializationAdapter):
89+
"""Adapter for Elasticsearch with support for native and string storage modes."""
9090

91-
This adapter supports two storage modes:
92-
- Native mode: Stores values as flattened dicts for efficient querying
93-
- String mode: Stores values as JSON strings for backward compatibility
94-
95-
Elasticsearch-specific features:
96-
- Stores collection name in the document for multi-tenancy
97-
- Uses ISO format for datetime fields
98-
- Supports migration between storage modes
99-
"""
91+
_native_storage: bool
10092

10193
def __init__(self, *, native_storage: bool = True) -> None:
10294
"""Initialize the Elasticsearch adapter.
10395
10496
Args:
10597
native_storage: If True (default), store values as flattened dicts.
106-
If False, store values as JSON strings.
98+
If False, store values as JSON strings.
10799
"""
108-
self.native_storage = native_storage
100+
super().__init__()
109101

110-
def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]:
111-
"""Convert a ManagedEntry to an Elasticsearch document.
102+
self._native_storage = native_storage
103+
self._date_format = "isoformat"
104+
self._value_format = "dict" if native_storage else "string"
112105

113-
Args:
114-
key: The key associated with this entry.
115-
entry: The ManagedEntry to serialize.
116-
collection: The collection name to store in the document.
106+
@override
107+
def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]:
108+
value = data.pop("value")
117109

118-
Returns:
119-
An Elasticsearch document dict with collection, key, value, and metadata.
120-
"""
121-
document: dict[str, Any] = {"collection": collection or "", "key": key, "value": {}}
110+
data["value"] = {}
122111

123-
# Store in appropriate field based on mode
124-
if self.native_storage:
125-
document["value"]["flattened"] = entry.value_as_dict
112+
if self._native_storage:
113+
data["value"]["flattened"] = value
126114
else:
127-
document["value"]["string"] = entry.value_as_json
128-
129-
if entry.created_at:
130-
document["created_at"] = entry.created_at.isoformat()
131-
if entry.expires_at:
132-
document["expires_at"] = entry.expires_at.isoformat()
133-
134-
return document
115+
data["value"]["string"] = value
135116

136-
def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry:
137-
"""Convert an Elasticsearch document back to a ManagedEntry.
117+
return data
138118

139-
This method supports both native (flattened) and string storage modes,
140-
trying the flattened field first and falling back to the string field.
141-
This allows for seamless migration between storage modes.
142-
143-
Args:
144-
data: The Elasticsearch document to deserialize.
145-
146-
Returns:
147-
A ManagedEntry reconstructed from the document.
119+
@override
120+
def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]:
121+
value = data.pop("value")
148122

149-
Raises:
150-
DeserializationError: If data is not a dict or is malformed.
151-
"""
152-
if not isinstance(data, dict):
153-
msg = "Expected Elasticsearch document to be a dict"
154-
raise DeserializationError(msg)
155-
156-
document = data
157-
value: dict[str, Any] = {}
158-
159-
raw_value = document.get("value")
160-
161-
# Try flattened field first, fall back to string field
162-
if not raw_value or not isinstance(raw_value, dict):
163-
msg = "Value field not found or invalid type"
164-
raise DeserializationError(msg)
165-
166-
if value_flattened := raw_value.get("flattened"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
167-
value = verify_dict(obj=value_flattened)
168-
elif value_str := raw_value.get("string"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
169-
if not isinstance(value_str, str):
170-
msg = "Value in `value` field is not a string"
171-
raise DeserializationError(msg)
172-
value = load_from_json(value_str)
123+
if flattened := value.get("flattened"):
124+
data["value"] = flattened
125+
elif string := value.get("string"):
126+
data["value"] = string
173127
else:
174-
msg = "Value field not found or invalid type"
175-
raise DeserializationError(msg)
176-
177-
created_at: datetime | None = try_parse_datetime_str(value=document.get("created_at"))
178-
expires_at: datetime | None = try_parse_datetime_str(value=document.get("expires_at"))
128+
msg = "Value field not found in Elasticsearch document"
129+
raise DeserializationError(message=msg)
179130

180-
return ManagedEntry(
181-
value=value,
182-
created_at=created_at,
183-
expires_at=expires_at,
184-
)
131+
return data
185132

186133

187134
class ElasticsearchStore(
@@ -262,7 +209,7 @@ def __init__(
262209
self._index_prefix = index_prefix
263210
self._native_storage = native_storage
264211
self._is_serverless = False
265-
self._adapter = ElasticsearchAdapter(native_storage=native_storage)
212+
self._adapter = ElasticsearchSerializationAdapter(native_storage=native_storage)
266213

267214
super().__init__(default_collection=default_collection)
268215

@@ -315,7 +262,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
315262
return None
316263

317264
try:
318-
return self._adapter.from_storage(data=source)
265+
return self._adapter.load_dict(data=source)
319266
except DeserializationError:
320267
return None
321268

@@ -348,7 +295,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
348295
continue
349296

350297
try:
351-
entries_by_id[doc_id] = self._adapter.from_storage(data=source)
298+
entries_by_id[doc_id] = self._adapter.load_dict(data=source)
352299
except DeserializationError as e:
353300
logger.error(
354301
"Failed to deserialize Elasticsearch document in batch operation",
@@ -379,10 +326,7 @@ async def _put_managed_entry(
379326
index_name: str = self._sanitize_index_name(collection=collection)
380327
document_id: str = self._sanitize_document_id(key=key)
381328

382-
document: dict[str, Any] = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection)
383-
if not isinstance(document, dict):
384-
msg = "Elasticsearch adapter must return dict"
385-
raise TypeError(msg)
329+
document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry)
386330

387331
try:
388332
_ = await self._client.index(
@@ -420,12 +364,10 @@ async def _put_managed_entries(
420364

421365
index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id)
422366

423-
document: dict[str, Any] = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection)
424-
if not isinstance(document, dict):
425-
msg = "Elasticsearch adapter must return dict"
426-
raise TypeError(msg)
367+
document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry)
427368

428369
operations.extend([index_action, document])
370+
429371
try:
430372
_ = await self._client.bulk(operations=operations, refresh=self._should_refresh_on_put) # pyright: ignore[reportUnknownMemberType]
431373
except ElasticsearchSerializationError as e:

0 commit comments

Comments (0)