11from collections .abc import Sequence
22from datetime import datetime
3- from typing import Any , overload
3+ from typing import Any , cast , overload
44
55from elastic_transport import ObjectApiResponse # noqa: TC002
66from key_value .shared .errors import DeserializationError
6060 "doc_values" : False ,
6161 "ignore_above" : 256 ,
6262 },
63+ "value_flattened" : {
64+ "type" : "flattened" ,
65+ },
6366 },
6467}
6568
7376ALLOWED_INDEX_CHARACTERS : str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "."
7477
7578
76- def managed_entry_to_document (collection : str , key : str , managed_entry : ManagedEntry ) -> dict [str , Any ]:
79+ def managed_entry_to_document (collection : str , key : str , managed_entry : ManagedEntry , * , native_storage : bool = False ) -> dict [str , Any ]:
7780 document : dict [str , Any ] = {
7881 "collection" : collection ,
7982 "key" : key ,
80- "value" : managed_entry .to_json (include_metadata = False ),
8183 }
8284
85+ # Store in appropriate field based on mode
86+ if native_storage :
87+ document ["value_flattened" ] = dict (managed_entry .value )
88+ else :
89+ document ["value" ] = managed_entry .to_json (include_metadata = False )
90+
8391 if managed_entry .created_at :
8492 document ["created_at" ] = managed_entry .created_at .isoformat ()
8593 if managed_entry .expires_at :
@@ -89,15 +97,26 @@ def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedE
8997
9098
9199def source_to_managed_entry (source : dict [str , Any ]) -> ManagedEntry :
92- if not (value_str := source .get ("value" )) or not isinstance (value_str , str ):
93- msg = "Value is not a string"
100+ # Try flattened field first, fall back to string field
101+ value_flattened = source .get ("value_flattened" )
102+ value_str = source .get ("value" )
103+
104+ value : dict [str , Any ]
105+ if value_flattened and isinstance (value_flattened , dict ):
106+ # Native storage mode - cast to the correct type
107+ value = cast (dict [str , Any ], value_flattened )
108+ elif value_str and isinstance (value_str , str ):
109+ # Legacy JSON string mode
110+ value = load_from_json (value_str )
111+ else :
112+ msg = "Value field not found or invalid type"
94113 raise DeserializationError (msg )
95114
96115 created_at : datetime | None = try_parse_datetime_str (value = source .get ("created_at" ))
97116 expires_at : datetime | None = try_parse_datetime_str (value = source .get ("expires_at" ))
98117
99118 return ManagedEntry (
100- value = load_from_json ( value_str ) ,
119+ value = value ,
101120 created_at = created_at ,
102121 expires_at = expires_at ,
103122 )
@@ -114,11 +133,28 @@ class ElasticsearchStore(
114133
115134 _index_prefix : str
116135
136+ _native_storage : bool
137+
117138 @overload
118- def __init__ (self , * , elasticsearch_client : AsyncElasticsearch , index_prefix : str , default_collection : str | None = None ) -> None : ...
139+ def __init__ (
140+ self ,
141+ * ,
142+ elasticsearch_client : AsyncElasticsearch ,
143+ index_prefix : str ,
144+ native_storage : bool = False ,
145+ default_collection : str | None = None ,
146+ ) -> None : ...
119147
120148 @overload
121- def __init__ (self , * , url : str , api_key : str | None = None , index_prefix : str , default_collection : str | None = None ) -> None : ...
149+ def __init__ (
150+ self ,
151+ * ,
152+ url : str ,
153+ api_key : str | None = None ,
154+ index_prefix : str ,
155+ native_storage : bool = False ,
156+ default_collection : str | None = None ,
157+ ) -> None : ...
122158
123159 def __init__ (
124160 self ,
@@ -127,6 +163,7 @@ def __init__(
127163 url : str | None = None ,
128164 api_key : str | None = None ,
129165 index_prefix : str ,
166+ native_storage : bool = False ,
130167 default_collection : str | None = None ,
131168 ) -> None :
132169 """Initialize the elasticsearch store.
@@ -136,6 +173,8 @@ def __init__(
136173 url: The url of the elasticsearch cluster.
137174 api_key: The api key to use.
138175 index_prefix: The index prefix to use. Collections will be prefixed with this prefix.
176+ native_storage: Whether to use native storage mode (flattened field type) for values.
177+ Defaults to False for backward compatibility.
139178 default_collection: The default collection to use if no collection is provided.
140179 """
141180 if elasticsearch_client is None and url is None :
@@ -153,6 +192,7 @@ def __init__(
153192 raise ValueError (msg )
154193
155194 self ._index_prefix = index_prefix
195+ self ._native_storage = native_storage
156196 self ._is_serverless = False
157197
158198 super ().__init__ (default_collection = default_collection )
@@ -205,18 +245,11 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
205245 if not (source := get_source_from_body (body = body )):
206246 return None
207247
208- if not (value_str := source .get ("value" )) or not isinstance (value_str , str ):
248+ try :
249+ return source_to_managed_entry (source = source )
250+ except DeserializationError :
209251 return None
210252
211- created_at : datetime | None = try_parse_datetime_str (value = source .get ("created_at" ))
212- expires_at : datetime | None = try_parse_datetime_str (value = source .get ("expires_at" ))
213-
214- return ManagedEntry (
215- value = load_from_json (value_str ),
216- created_at = created_at ,
217- expires_at = expires_at ,
218- )
219-
220253 @override
221254 async def _get_managed_entries (self , * , collection : str , keys : Sequence [str ]) -> list [ManagedEntry | None ]:
222255 if not keys :
@@ -265,7 +298,9 @@ async def _put_managed_entry(
265298 index_name : str = self ._sanitize_index_name (collection = collection )
266299 document_id : str = self ._sanitize_document_id (key = key )
267300
268- document : dict [str , Any ] = managed_entry_to_document (collection = collection , key = key , managed_entry = managed_entry )
301+ document : dict [str , Any ] = managed_entry_to_document (
302+ collection = collection , key = key , managed_entry = managed_entry , native_storage = self ._native_storage
303+ )
269304
270305 _ = await self ._client .index (
271306 index = index_name ,
@@ -297,7 +332,9 @@ async def _put_managed_entries(
297332
298333 index_action : dict [str , Any ] = new_bulk_action (action = "index" , index = index_name , document_id = document_id )
299334
300- document : dict [str , Any ] = managed_entry_to_document (collection = collection , key = key , managed_entry = managed_entry )
335+ document : dict [str , Any ] = managed_entry_to_document (
336+ collection = collection , key = key , managed_entry = managed_entry , native_storage = self ._native_storage
337+ )
301338
302339 operations .extend ([index_action , document ])
303340
0 commit comments