
Commit f8e7f01

feat: add native storage mode for Elasticsearch using flattened fields
Add optional `native_storage` parameter to ElasticsearchStore that stores values as Elasticsearch flattened objects instead of JSON strings.

Key changes:
- Add `native_storage` parameter to constructor (defaults to False)
- Update mapping to include `value_flattened` field (flattened type)
- Update `managed_entry_to_document` to store in the appropriate field
- Update `source_to_managed_entry` to read from both fields for migration
- Add comprehensive tests for native storage mode and migration

Benefits:
- More efficient storage and retrieval
- Better aligned with Elasticsearch's document model
- Enables basic querying on nested fields

The dual-field approach (value + value_flattened) enables backward compatibility and gradual migration from the legacy JSON string mode.

Related to #87

Co-authored-by: William Easton <[email protected]>
1 parent db43abc commit f8e7f01
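
Before the diff, a minimal usage sketch of the option this commit adds. The import path and connection details below are assumptions inferred from the file paths and tests in this commit, not something the commit itself documents:

import asyncio

# Assumed import path, mirroring src/key_value/aio/stores/elasticsearch/ in this repo.
from key_value.aio.stores.elasticsearch import ElasticsearchStore


async def main() -> None:
    # native_storage=True writes values into the new flattened `value_flattened` field;
    # the default (False) keeps the legacy JSON-string `value` field.
    async with ElasticsearchStore(
        url="http://localhost:9200",  # placeholder cluster URL
        index_prefix="kv-store",
        native_storage=True,
    ) as store:
        await store.put(collection="users", key="alice", value={"name": "Alice", "age": 30})
        print(await store.get(collection="users", key="alice"))


asyncio.run(main())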

File tree: 2 files changed (+168, -20 lines)


key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py

Lines changed: 57 additions & 20 deletions
@@ -1,6 +1,6 @@
 from collections.abc import Sequence
 from datetime import datetime
-from typing import Any, overload
+from typing import Any, cast, overload
 
 from elastic_transport import ObjectApiResponse  # noqa: TC002
 from key_value.shared.errors import DeserializationError
@@ -60,6 +60,9 @@
             "doc_values": False,
             "ignore_above": 256,
         },
+        "value_flattened": {
+            "type": "flattened",
+        },
     },
 }
 
@@ -73,13 +76,18 @@
 ALLOWED_INDEX_CHARACTERS: str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "."
 
 
-def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry) -> dict[str, Any]:
+def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry, *, native_storage: bool = False) -> dict[str, Any]:
     document: dict[str, Any] = {
         "collection": collection,
         "key": key,
-        "value": managed_entry.to_json(include_metadata=False),
     }
 
+    # Store in appropriate field based on mode
+    if native_storage:
+        document["value_flattened"] = dict(managed_entry.value)
+    else:
+        document["value"] = managed_entry.to_json(include_metadata=False)
+
     if managed_entry.created_at:
         document["created_at"] = managed_entry.created_at.isoformat()
     if managed_entry.expires_at:
@@ -89,15 +97,26 @@ def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedE
 
 
 def source_to_managed_entry(source: dict[str, Any]) -> ManagedEntry:
-    if not (value_str := source.get("value")) or not isinstance(value_str, str):
-        msg = "Value is not a string"
+    # Try flattened field first, fall back to string field
+    value_flattened = source.get("value_flattened")
+    value_str = source.get("value")
+
+    value: dict[str, Any]
+    if value_flattened and isinstance(value_flattened, dict):
+        # Native storage mode - cast to the correct type
+        value = cast(dict[str, Any], value_flattened)
+    elif value_str and isinstance(value_str, str):
+        # Legacy JSON string mode
+        value = load_from_json(value_str)
+    else:
+        msg = "Value field not found or invalid type"
         raise DeserializationError(msg)
 
     created_at: datetime | None = try_parse_datetime_str(value=source.get("created_at"))
     expires_at: datetime | None = try_parse_datetime_str(value=source.get("expires_at"))
 
     return ManagedEntry(
-        value=load_from_json(value_str),
+        value=value,
         created_at=created_at,
         expires_at=expires_at,
     )
@@ -114,11 +133,28 @@ class ElasticsearchStore(
 
     _index_prefix: str
 
+    _native_storage: bool
+
     @overload
-    def __init__(self, *, elasticsearch_client: AsyncElasticsearch, index_prefix: str, default_collection: str | None = None) -> None: ...
+    def __init__(
+        self,
+        *,
+        elasticsearch_client: AsyncElasticsearch,
+        index_prefix: str,
+        native_storage: bool = False,
+        default_collection: str | None = None,
+    ) -> None: ...
 
     @overload
-    def __init__(self, *, url: str, api_key: str | None = None, index_prefix: str, default_collection: str | None = None) -> None: ...
+    def __init__(
+        self,
+        *,
+        url: str,
+        api_key: str | None = None,
+        index_prefix: str,
+        native_storage: bool = False,
+        default_collection: str | None = None,
+    ) -> None: ...
 
     def __init__(
         self,
@@ -127,6 +163,7 @@ def __init__(
         url: str | None = None,
         api_key: str | None = None,
         index_prefix: str,
+        native_storage: bool = False,
         default_collection: str | None = None,
     ) -> None:
         """Initialize the elasticsearch store.
@@ -136,6 +173,8 @@ def __init__(
             url: The url of the elasticsearch cluster.
            api_key: The api key to use.
            index_prefix: The index prefix to use. Collections will be prefixed with this prefix.
+            native_storage: Whether to use native storage mode (flattened field type) for values.
+                Defaults to False for backward compatibility.
            default_collection: The default collection to use if no collection is provided.
        """
        if elasticsearch_client is None and url is None:
@@ -153,6 +192,7 @@ def __init__(
             raise ValueError(msg)
 
         self._index_prefix = index_prefix
+        self._native_storage = native_storage
         self._is_serverless = False
 
         super().__init__(default_collection=default_collection)
@@ -205,18 +245,11 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
         if not (source := get_source_from_body(body=body)):
             return None
 
-        if not (value_str := source.get("value")) or not isinstance(value_str, str):
+        try:
+            return source_to_managed_entry(source=source)
+        except DeserializationError:
             return None
 
-        created_at: datetime | None = try_parse_datetime_str(value=source.get("created_at"))
-        expires_at: datetime | None = try_parse_datetime_str(value=source.get("expires_at"))
-
-        return ManagedEntry(
-            value=load_from_json(value_str),
-            created_at=created_at,
-            expires_at=expires_at,
-        )
-
     @override
     async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ManagedEntry | None]:
         if not keys:
@@ -265,7 +298,9 @@ async def _put_managed_entry(
         index_name: str = self._sanitize_index_name(collection=collection)
         document_id: str = self._sanitize_document_id(key=key)
 
-        document: dict[str, Any] = managed_entry_to_document(collection=collection, key=key, managed_entry=managed_entry)
+        document: dict[str, Any] = managed_entry_to_document(
+            collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage
+        )
 
         _ = await self._client.index(
             index=index_name,
@@ -297,7 +332,9 @@ async def _put_managed_entries(
 
             index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id)
 
-            document: dict[str, Any] = managed_entry_to_document(collection=collection, key=key, managed_entry=managed_entry)
+            document: dict[str, Any] = managed_entry_to_document(
+                collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage
+            )
 
             operations.extend([index_action, document])
 
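The commit message lists "basic querying on nested fields" as a benefit of flattened storage. The change itself adds no query API, so the sketch below is only an illustration of what such a query could look like against the raw index with the elasticsearch-py async client; the index name comes from the same private helper the tests below use, and the query shape relies on Elasticsearch's flattened field type rather than on anything in this commit:

from typing import Any

from elasticsearch import AsyncElasticsearch

from key_value.aio.stores.elasticsearch import ElasticsearchStore  # assumed import path


async def find_by_name(store: ElasticsearchStore, es_client: AsyncElasticsearch, name: str) -> list[dict[str, Any]]:
    # Flattened fields index every leaf value as a keyword, so a term query on
    # `value_flattened.name` matches documents written with native_storage=True.
    index_name = store._sanitize_index_name(collection="users")  # pyright: ignore[reportPrivateUsage]
    response = await es_client.search(index=index_name, query={"term": {"value_flattened.name": name}})
    return [hit["_source"]["value_flattened"] for hit in response["hits"]["hits"]]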
key-value/key-value-aio/tests/stores/elasticsearch/test_elasticsearch.py

Lines changed: 111 additions & 0 deletions
@@ -69,6 +69,31 @@ def test_managed_entry_document_conversion():
     assert round_trip_managed_entry.expires_at == expires_at
 
 
+def test_managed_entry_document_conversion_native_storage():
+    created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc)
+    expires_at = created_at + timedelta(seconds=10)
+
+    managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at)
+    document = managed_entry_to_document(collection="test_collection", key="test_key", managed_entry=managed_entry, native_storage=True)
+
+    assert document == snapshot(
+        {
+            "collection": "test_collection",
+            "key": "test_key",
+            "value_flattened": {"test": "test"},
+            "created_at": "2025-01-01T00:00:00+00:00",
+            "expires_at": "2025-01-01T00:00:10+00:00",
+        }
+    )
+
+    round_trip_managed_entry = source_to_managed_entry(source=document)
+
+    assert round_trip_managed_entry.value == managed_entry.value
+    assert round_trip_managed_entry.created_at == created_at
+    assert round_trip_managed_entry.ttl == IsFloat(lt=0)
+    assert round_trip_managed_entry.expires_at == expires_at
+
+
 @pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not running")
 class TestElasticsearchStore(ContextManagerStoreTestMixin, BaseStoreTests):
     @pytest.fixture(autouse=True, scope="session", params=ELASTICSEARCH_VERSIONS_TO_TEST)
@@ -121,3 +146,89 @@ async def test_put_put_two_indices(self, store: ElasticsearchStore, es_client: A
         assert len(indices.body) == 2
         assert "kv-store-e2e-test-test_collection" in indices
         assert "kv-store-e2e-test-test_collection_2" in indices
+
+
+@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not running")
+class TestElasticsearchStoreNativeMode(ContextManagerStoreTestMixin, BaseStoreTests):
+    """Test ElasticsearchStore with native_storage=True"""
+
+    @pytest.fixture(autouse=True, scope="session", params=ELASTICSEARCH_VERSIONS_TO_TEST)
+    async def setup_elasticsearch(self, request: pytest.FixtureRequest) -> AsyncGenerator[None, None]:
+        version = request.param
+        es_image = f"docker.elastic.co/elasticsearch/elasticsearch:{version}"
+
+        with docker_container(
+            f"elasticsearch-test-native-{version}",
+            es_image,
+            {str(ES_CONTAINER_PORT): ES_PORT},
+            {"discovery.type": "single-node", "xpack.security.enabled": "false"},
+        ):
+            if not await async_wait_for_true(bool_fn=ping_elasticsearch, tries=WAIT_FOR_ELASTICSEARCH_TIMEOUT, wait_time=2):
+                msg = f"Elasticsearch {version} failed to start"
+                raise ElasticsearchFailedToStartError(msg)
+
+            yield
+
+    @pytest.fixture
+    async def es_client(self) -> AsyncGenerator[AsyncElasticsearch, None]:
+        async with AsyncElasticsearch(hosts=[ES_URL]) as es_client:
+            yield es_client
+
+    @override
+    @pytest.fixture
+    async def store(self) -> AsyncGenerator[ElasticsearchStore, None]:
+        async with get_elasticsearch_client() as es_client:
+            indices = await es_client.options(ignore_status=404).indices.get(index="kv-store-native-test-*")
+            for index in indices:
+                _ = await es_client.options(ignore_status=404).indices.delete(index=index)
+        async with ElasticsearchStore(url=ES_URL, index_prefix="kv-store-native-test", native_storage=True) as store:
+            yield store
+
+    @pytest.mark.skip(reason="Distributed Caches are unbounded")
+    @override
+    async def test_not_unbounded(self, store: BaseStore): ...
+
+    @pytest.mark.skip(reason="Skip concurrent tests on distributed caches")
+    @override
+    async def test_concurrent_operations(self, store: BaseStore): ...
+
+    async def test_value_stored_as_flattened_object(self, store: ElasticsearchStore, es_client: AsyncElasticsearch):
+        """Verify values are stored as flattened objects, not JSON strings"""
+        await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30})
+
+        # Check raw Elasticsearch document using public sanitization methods
+        # Note: We need to access these internal methods for testing the storage format
+        index_name = store._sanitize_index_name(collection="test")  # pyright: ignore[reportPrivateUsage]
+        doc_id = store._sanitize_document_id(key="test_key")  # pyright: ignore[reportPrivateUsage]
+
+        response = await es_client.get(index=index_name, id=doc_id)
+        source = response.body["_source"]
+
+        # Should have value_flattened field as dict
+        assert "value_flattened" in source
+        assert isinstance(source["value_flattened"], dict)
+        assert source["value_flattened"] == {"name": "Alice", "age": 30}
+
+        # Should NOT have value field (JSON string)
+        assert "value" not in source or source.get("value") is None
+
+    async def test_migration_from_legacy_mode(self, store: ElasticsearchStore, es_client: AsyncElasticsearch):
+        """Verify native mode can read legacy JSON string data"""
+        # Manually insert a legacy document with JSON string value
+        index_name = store._sanitize_index_name(collection="test")  # pyright: ignore[reportPrivateUsage]
+        doc_id = store._sanitize_document_id(key="legacy_key")  # pyright: ignore[reportPrivateUsage]
+
+        await es_client.index(
+            index=index_name,
+            id=doc_id,
+            body={
+                "collection": "test",
+                "key": "legacy_key",
+                "value": '{"legacy": "data"}',  # JSON string
+            },
+        )
+        await es_client.indices.refresh(index=index_name)
+
+        # Should be able to read it in native mode
+        result = await store.get(collection="test", key="legacy_key")
+        assert result == {"legacy": "data"}

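Because `source_to_managed_entry` now reads both fields, a store opened with `native_storage=True` can serve legacy JSON-string documents (as the migration test above verifies), but those documents stay in the old format until they are written again. Below is a hedged sketch of a one-off backfill, assuming you already know which keys to migrate (the store API in this commit has no key-listing method) and accepting that the public get/put round trip shown here does not carry TTL metadata:

from key_value.aio.stores.elasticsearch import ElasticsearchStore  # assumed import path


async def backfill_to_flattened(store: ElasticsearchStore, collection: str, keys: list[str]) -> None:
    """Rewrite legacy JSON-string documents in flattened form by reading and re-putting them."""
    for key in keys:
        value = await store.get(collection=collection, key=key)  # dual-read falls back to the legacy `value` field
        if value is not None:
            # Re-put writes via value_flattened when the store was opened with native_storage=True.
            # Note: any expires_at on the original document is not preserved by this simple round trip.
            await store.put(collection=collection, key=key, value=value)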