Skip to content

Commit 1553dfc

Browse files
fix: address CodeRabbit review feedback
- Elasticsearch: disable doc_values for JSON mode, enhance mapping validation, fix empty dict handling
- MongoDB: add inverse TTL validation, explicit type errors, fix expires_at=None bug
- Fix line length issues and apply linting suggestions

Co-authored-by: William Easton <[email protected]>
1 parent f08fe2f commit 1553dfc

File tree

2 files changed

+96
-44
lines changed
  • key-value/key-value-aio/src/key_value/aio/stores

2 files changed

+96
-44
lines changed

key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py

Lines changed: 57 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -64,10 +64,25 @@ class ElasticsearchStore(
6464
_native_storage: bool
6565

6666
@overload
67-
def __init__(self, *, elasticsearch_client: AsyncElasticsearch, index_prefix: str, default_collection: str | None = None, native_storage: bool = False) -> None: ...
67+
def __init__(
68+
self,
69+
*,
70+
elasticsearch_client: AsyncElasticsearch,
71+
index_prefix: str,
72+
default_collection: str | None = None,
73+
native_storage: bool = False,
74+
) -> None: ...
6875

6976
@overload
70-
def __init__(self, *, url: str, api_key: str | None = None, index_prefix: str, default_collection: str | None = None, native_storage: bool = False) -> None: ...
77+
def __init__(
78+
self,
79+
*,
80+
url: str,
81+
api_key: str | None = None,
82+
index_prefix: str,
83+
default_collection: str | None = None,
84+
native_storage: bool = False,
85+
) -> None: ...
7186

7287
def __init__(
7388
self,
@@ -142,7 +157,7 @@ async def _setup_collection(self, *, collection: str) -> None:
142157
},
143158
"value": {
144159
"type": "flattened" if self._native_storage else "keyword",
145-
**({"index": False} if not self._native_storage else {}),
160+
**({"index": False, "doc_values": False} if not self._native_storage else {}),
146161
},
147162
},
148163
}
@@ -154,7 +169,10 @@ async def _validate_index_mapping(self, *, index_name: str, collection: str) ->
154169
try:
155170
mapping_response = await self._client.indices.get_mapping(index=index_name)
156171
mappings = mapping_response.get(index_name, {}).get("mappings", {})
157-
value_field_type = mappings.get("properties", {}).get("value", {}).get("type")
172+
props = mappings.get("properties", {})
173+
value_field_type = props.get("value", {}).get("type")
174+
created_type = props.get("created_at", {}).get("type")
175+
expires_type = props.get("expires_at", {}).get("type")
158176

159177
expected_type = "flattened" if self._native_storage else "keyword"
160178

@@ -166,16 +184,24 @@ async def _validate_index_mapping(self, *, index_name: str, collection: str) ->
166184
f"To fix this, either: 1) Use the correct storage mode when initializing the store, "
167185
f"or 2) Delete and recreate the index with the new mapping."
168186
)
169-
raise ValueError(msg)
170-
except Exception as e:
171-
# If we can't get the mapping, log a warning but don't fail
172-
# This allows the store to work even if mapping validation fails
173-
if not isinstance(e, ValueError):
174-
# Only suppress non-ValueError exceptions (e.g., connection issues)
175-
pass
176-
else:
177-
# Re-raise ValueError from our validation
178-
raise
187+
raise ValueError(msg) # noqa: TRY301
188+
189+
# Enforce date types for timestamps (both modes)
190+
for field_name, field_type in (("created_at", created_type), ("expires_at", expires_type)):
191+
if field_type not in ("date", None): # None => not yet created; will be added on first write
192+
msg = (
193+
f"Index mapping mismatch for collection '{collection}': "
194+
f"'{field_name}' is mapped as '{field_type}', expected 'date'. "
195+
f"Delete and recreate the index or fix the mapping."
196+
)
197+
raise ValueError(msg) # noqa: TRY301
198+
except ValueError:
199+
raise
200+
except Exception:
201+
# Log a warning but do not fail hard (keep behavior)
202+
import logging
203+
204+
logging.getLogger(__name__).warning("Failed to validate mapping for index '%s' (collection '%s')", index_name, collection)
179205

180206
def _sanitize_index_name(self, collection: str) -> str:
181207
return sanitize_string(
@@ -208,8 +234,9 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
208234

209235
if self._native_storage:
210236
# Native storage mode: Get value as flattened object
211-
if not (value := source.get("value")):
237+
if "value" not in source:
212238
return None
239+
value = source["value"]
213240

214241
# Detect if data is in JSON string format
215242
if isinstance(value, str):
@@ -218,7 +245,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
218245
"for native_storage mode. This indicates a storage mode mismatch. "
219246
"You may need to migrate existing data or use the correct storage mode."
220247
)
221-
raise ValueError(msg)
248+
raise TypeError(msg)
222249

223250
if not isinstance(value, dict):
224251
return None
@@ -231,23 +258,22 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
231258
created_at=created_at,
232259
expires_at=expires_at,
233260
)
234-
else:
235-
# JSON string mode: Get value as JSON string and parse it
236-
json_value: str | None = source.get("value")
237-
238-
# Detect if data is in native object format
239-
if isinstance(json_value, dict):
240-
msg = (
241-
f"Data for key '{key}' appears to be in native object format, but store is configured "
242-
"for JSON string mode. This indicates a storage mode mismatch. "
243-
"You may need to migrate existing data or use the correct storage mode."
244-
)
245-
raise ValueError(msg)
261+
# JSON string mode: Get value as JSON string and parse it
262+
json_value: str | None = source.get("value")
263+
264+
# Detect if data is in native object format
265+
if isinstance(json_value, dict):
266+
msg = (
267+
f"Data for key '{key}' appears to be in native object format, but store is configured "
268+
"for JSON string mode. This indicates a storage mode mismatch. "
269+
"You may need to migrate existing data or use the correct storage mode."
270+
)
271+
raise TypeError(msg)
246272

247-
if not isinstance(json_value, str):
248-
return None
273+
if not isinstance(json_value, str):
274+
return None
249275

250-
return ManagedEntry.from_json(json_str=json_value)
276+
return ManagedEntry.from_json(json_str=json_value)
251277

252278
@property
253279
def _should_refresh_on_put(self) -> bool:

key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py

Lines changed: 39 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,13 @@ def __init__(
7171

7272
@overload
7373
def __init__(
74-
self, *, url: str, db_name: str | None = None, coll_name: str | None = None, default_collection: str | None = None, native_storage: bool = False
74+
self,
75+
*,
76+
url: str,
77+
db_name: str | None = None,
78+
coll_name: str | None = None,
79+
default_collection: str | None = None,
80+
native_storage: bool = False,
7581
) -> None:
7682
"""Initialize the MongoDB store.
7783
@@ -175,11 +181,19 @@ async def _validate_collection_indexes(self, *, collection: str) -> None:
175181
f"To fix this, either: 1) Recreate the collection with native_storage=True, "
176182
f"or 2) Manually create the TTL index: db.{collection}.createIndex({{expires_at: 1}}, {{expireAfterSeconds: 0}})"
177183
)
178-
raise ValueError(msg)
184+
raise ValueError(msg) # noqa: TRY301
185+
if not self._native_storage and has_ttl_index:
186+
msg = (
187+
f"Collection '{collection}' has a TTL index on 'expires_at' field, "
188+
f"but store is configured for JSON string mode (native_storage=False). "
189+
f"This may cause unexpected behavior. Consider either: "
190+
f"1) Using native_storage=True, or 2) Dropping the TTL index."
191+
)
192+
raise ValueError(msg) # noqa: TRY301
179193
except ValueError:
180194
# Re-raise our validation errors
181195
raise
182-
except Exception:
196+
except Exception: # noqa: S110
183197
# Suppress other errors (e.g., connection issues) to allow store to work
184198
pass
185199

@@ -197,7 +211,11 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
197211
value: dict[str, Any] | None = doc.get("value")
198212

199213
if not isinstance(value, dict):
200-
return None
214+
msg = (
215+
f"Data for key '{key}' has invalid value type: expected dict but got {type(value).__name__}. "
216+
f"This may indicate the collection contains JSON-mode data but native_storage=True."
217+
)
218+
raise TypeError(msg)
201219

202220
# Parse datetime objects directly and validate types
203221
created_at: datetime | None = doc.get("created_at")
@@ -222,14 +240,17 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
222240
created_at=created_at,
223241
expires_at=expires_at,
224242
)
225-
else:
226-
# JSON string mode: Get value as JSON string and parse it
227-
json_value: str | None = doc.get("value")
243+
# JSON string mode: Get value as JSON string and parse it
244+
json_value: str | None = doc.get("value")
228245

229-
if not isinstance(json_value, str):
230-
return None
246+
if not isinstance(json_value, str):
247+
msg = (
248+
f"Data for key '{key}' has invalid value type: expected str but got {type(json_value).__name__}. "
249+
f"This may indicate the collection contains native-mode data but native_storage=False."
250+
)
251+
raise TypeError(msg)
231252

232-
return ManagedEntry.from_json(json_str=json_value)
253+
return ManagedEntry.from_json(json_str=json_value)
233254

234255
@override
235256
async def _put_managed_entry(
@@ -253,14 +274,19 @@ async def _put_managed_entry(
253274
# Store as datetime objects (use $setOnInsert for immutable fields)
254275
if managed_entry.created_at:
255276
set_on_insert_fields["created_at"] = managed_entry.created_at
256-
if managed_entry.expires_at:
257-
# expires_at can change, so use $set
258-
set_fields["expires_at"] = managed_entry.expires_at
259277

278+
# Build update document
260279
update_doc: dict[str, Any] = {"$set": set_fields}
261280
if set_on_insert_fields:
262281
update_doc["$setOnInsert"] = set_on_insert_fields
263282

283+
# Always handle expires_at to support removing expiration
284+
if managed_entry.expires_at is not None:
285+
set_fields["expires_at"] = managed_entry.expires_at
286+
else:
287+
# Use $unset to remove the field when expires_at is None
288+
update_doc["$unset"] = {"expires_at": ""}
289+
264290
_ = await self._collections_by_name[collection].update_one(
265291
filter={"key": key},
266292
update=update_doc,

0 commit comments

Comments
 (0)