|
6 | 6 | from elastic_transport import ObjectApiResponse |
7 | 7 | from elastic_transport import SerializationError as ElasticsearchSerializationError |
8 | 8 | from key_value.shared.errors import DeserializationError, SerializationError |
9 | | -from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json, verify_dict |
| 9 | +from key_value.shared.utils.managed_entry import ManagedEntry |
10 | 10 | from key_value.shared.utils.sanitize import ( |
11 | 11 | ALPHANUMERIC_CHARACTERS, |
12 | 12 | LOWERCASE_ALPHABET, |
13 | 13 | NUMBERS, |
14 | 14 | sanitize_string, |
15 | 15 | ) |
16 | 16 | from key_value.shared.utils.serialization import SerializationAdapter |
17 | | -from key_value.shared.utils.time_to_live import now_as_epoch, try_parse_datetime_str |
| 17 | +from key_value.shared.utils.time_to_live import now_as_epoch |
18 | 18 | from typing_extensions import override |
19 | 19 |
|
20 | 20 | from key_value.aio.stores.base import ( |
|
85 | 85 | ALLOWED_INDEX_CHARACTERS: str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "." |
86 | 86 |
|
87 | 87 |
|
88 | | -class ElasticsearchAdapter(SerializationAdapter): |
89 | | - """Adapter for Elasticsearch with support for native and string storage modes. |
| 88 | +class ElasticsearchSerializationAdapter(SerializationAdapter): |
| 89 | + """Adapter for Elasticsearch with support for native and string storage modes.""" |
90 | 90 |
|
91 | | - This adapter supports two storage modes: |
92 | | - - Native mode: Stores values as flattened dicts for efficient querying |
93 | | - - String mode: Stores values as JSON strings for backward compatibility |
94 | | -
|
95 | | - Elasticsearch-specific features: |
96 | | - - Stores collection name in the document for multi-tenancy |
97 | | - - Uses ISO format for datetime fields |
98 | | - - Supports migration between storage modes |
99 | | - """ |
| 91 | + _native_storage: bool |
100 | 92 |
|
101 | 93 | def __init__(self, *, native_storage: bool = True) -> None: |
102 | 94 | """Initialize the Elasticsearch adapter. |
103 | 95 |
|
104 | 96 | Args: |
105 | 97 | native_storage: If True (default), store values as flattened dicts. |
106 | | - If False, store values as JSON strings. |
| 98 | + If False, store values as JSON strings. |
107 | 99 | """ |
108 | | - self.native_storage = native_storage |
| 100 | + super().__init__() |
109 | 101 |
|
110 | | - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]: |
111 | | - """Convert a ManagedEntry to an Elasticsearch document. |
| 102 | + self._native_storage = native_storage |
| 103 | + self._date_format = "isoformat" |
| 104 | + self._value_format = "dict" if native_storage else "string" |
112 | 105 |
|
113 | | - Args: |
114 | | - key: The key associated with this entry. |
115 | | - entry: The ManagedEntry to serialize. |
116 | | - collection: The collection name to store in the document. |
| 106 | + @override |
| 107 | + def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: |
| 108 | + value = data.pop("value") |
117 | 109 |
|
118 | | - Returns: |
119 | | - An Elasticsearch document dict with collection, key, value, and metadata. |
120 | | - """ |
121 | | - document: dict[str, Any] = {"collection": collection or "", "key": key, "value": {}} |
| 110 | + data["value"] = {} |
122 | 111 |
|
123 | | - # Store in appropriate field based on mode |
124 | | - if self.native_storage: |
125 | | - document["value"]["flattened"] = entry.value_as_dict |
| 112 | + if self._native_storage: |
| 113 | + data["value"]["flattened"] = value |
126 | 114 | else: |
127 | | - document["value"]["string"] = entry.value_as_json |
128 | | - |
129 | | - if entry.created_at: |
130 | | - document["created_at"] = entry.created_at.isoformat() |
131 | | - if entry.expires_at: |
132 | | - document["expires_at"] = entry.expires_at.isoformat() |
133 | | - |
134 | | - return document |
| 115 | + data["value"]["string"] = value |
135 | 116 |
|
136 | | - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: |
137 | | - """Convert an Elasticsearch document back to a ManagedEntry. |
| 117 | + return data |
138 | 118 |
|
139 | | - This method supports both native (flattened) and string storage modes, |
140 | | - trying the flattened field first and falling back to the string field. |
141 | | - This allows for seamless migration between storage modes. |
142 | | -
|
143 | | - Args: |
144 | | - data: The Elasticsearch document to deserialize. |
145 | | -
|
146 | | - Returns: |
147 | | - A ManagedEntry reconstructed from the document. |
| 119 | + @override |
| 120 | + def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: |
| 121 | + value = data.pop("value") |
148 | 122 |
|
149 | | - Raises: |
150 | | - DeserializationError: If data is not a dict or is malformed. |
151 | | - """ |
152 | | - if not isinstance(data, dict): |
153 | | - msg = "Expected Elasticsearch document to be a dict" |
154 | | - raise DeserializationError(msg) |
155 | | - |
156 | | - document = data |
157 | | - value: dict[str, Any] = {} |
158 | | - |
159 | | - raw_value = document.get("value") |
160 | | - |
161 | | - # Try flattened field first, fall back to string field |
162 | | - if not raw_value or not isinstance(raw_value, dict): |
163 | | - msg = "Value field not found or invalid type" |
164 | | - raise DeserializationError(msg) |
165 | | - |
166 | | - if value_flattened := raw_value.get("flattened"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] |
167 | | - value = verify_dict(obj=value_flattened) |
168 | | - elif value_str := raw_value.get("string"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] |
169 | | - if not isinstance(value_str, str): |
170 | | - msg = "Value in `value` field is not a string" |
171 | | - raise DeserializationError(msg) |
172 | | - value = load_from_json(value_str) |
| 123 | + if flattened := value.get("flattened"): |
| 124 | + data["value"] = flattened |
| 125 | + elif string := value.get("string"): |
| 126 | + data["value"] = string |
173 | 127 | else: |
174 | | - msg = "Value field not found or invalid type" |
175 | | - raise DeserializationError(msg) |
176 | | - |
177 | | - created_at: datetime | None = try_parse_datetime_str(value=document.get("created_at")) |
178 | | - expires_at: datetime | None = try_parse_datetime_str(value=document.get("expires_at")) |
| 128 | + msg = "Value field not found in Elasticsearch document" |
| 129 | + raise DeserializationError(message=msg) |
179 | 130 |
|
180 | | - return ManagedEntry( |
181 | | - value=value, |
182 | | - created_at=created_at, |
183 | | - expires_at=expires_at, |
184 | | - ) |
| 131 | + return data |
185 | 132 |
|
186 | 133 |
|
187 | 134 | class ElasticsearchStore( |
@@ -262,7 +209,7 @@ def __init__( |
262 | 209 | self._index_prefix = index_prefix |
263 | 210 | self._native_storage = native_storage |
264 | 211 | self._is_serverless = False |
265 | | - self._adapter = ElasticsearchAdapter(native_storage=native_storage) |
| 212 | + self._adapter = ElasticsearchSerializationAdapter(native_storage=native_storage) |
266 | 213 |
|
267 | 214 | super().__init__(default_collection=default_collection) |
268 | 215 |
|
@@ -315,7 +262,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry |
315 | 262 | return None |
316 | 263 |
|
317 | 264 | try: |
318 | | - return self._adapter.from_storage(data=source) |
| 265 | + return self._adapter.load_dict(data=source) |
319 | 266 | except DeserializationError: |
320 | 267 | return None |
321 | 268 |
|
@@ -348,7 +295,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> |
348 | 295 | continue |
349 | 296 |
|
350 | 297 | try: |
351 | | - entries_by_id[doc_id] = self._adapter.from_storage(data=source) |
| 298 | + entries_by_id[doc_id] = self._adapter.load_dict(data=source) |
352 | 299 | except DeserializationError as e: |
353 | 300 | logger.error( |
354 | 301 | "Failed to deserialize Elasticsearch document in batch operation", |
@@ -379,10 +326,7 @@ async def _put_managed_entry( |
379 | 326 | index_name: str = self._sanitize_index_name(collection=collection) |
380 | 327 | document_id: str = self._sanitize_document_id(key=key) |
381 | 328 |
|
382 | | - document: dict[str, Any] = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) |
383 | | - if not isinstance(document, dict): |
384 | | - msg = "Elasticsearch adapter must return dict" |
385 | | - raise TypeError(msg) |
| 329 | + document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry) |
386 | 330 |
|
387 | 331 | try: |
388 | 332 | _ = await self._client.index( |
@@ -420,12 +364,10 @@ async def _put_managed_entries( |
420 | 364 |
|
421 | 365 | index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id) |
422 | 366 |
|
423 | | - document: dict[str, Any] = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) |
424 | | - if not isinstance(document, dict): |
425 | | - msg = "Elasticsearch adapter must return dict" |
426 | | - raise TypeError(msg) |
| 367 | + document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry) |
427 | 368 |
|
428 | 369 | operations.extend([index_action, document]) |
| 370 | + |
429 | 371 | try: |
430 | 372 | _ = await self._client.bulk(operations=operations, refresh=self._should_refresh_on_put) # pyright: ignore[reportUnknownMemberType] |
431 | 373 | except ElasticsearchSerializationError as e: |
|
0 commit comments