Skip to content

Commit fd99815

Browse files
committed
PR Feedback
1 parent d6849e3 commit fd99815

File tree

7 files changed

+80
-66
lines changed

7 files changed

+80
-66
lines changed

key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
BaseEnumerateKeysStore,
2424
BaseStore,
2525
)
26-
from key_value.aio.stores.elasticsearch.utils import LessCapableJsonSerializer, NdjsonSerializer, new_bulk_action
26+
from key_value.aio.stores.elasticsearch.utils import LessCapableJsonSerializer, LessCapableNdjsonSerializer, new_bulk_action
2727

2828
try:
2929
from elasticsearch import AsyncElasticsearch
@@ -99,14 +99,16 @@ def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedE
9999
def source_to_managed_entry(source: dict[str, Any]) -> ManagedEntry:
100100
value: dict[str, Any] = {}
101101

102+
raw_value = source.get("value")
103+
102104
# Try flattened field first, fall back to string field
103-
if not (value := source.get("value")) or not isinstance(value, dict):
105+
if not raw_value or not isinstance(raw_value, dict):
104106
msg = "Value field not found or invalid type"
105107
raise DeserializationError(msg)
106108

107-
if value_flattened := value.get("flattened"):
109+
if value_flattened := raw_value.get("flattened"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
108110
value = verify_dict(obj=value_flattened)
109-
elif value_str := value.get("string"):
111+
elif value_str := raw_value.get("string"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
110112
if not isinstance(value_str, str):
111113
msg = "Value in `value` field is not a string"
112114
raise DeserializationError(msg)
@@ -196,7 +198,7 @@ def __init__(
196198

197199
LessCapableJsonSerializer.install_serializer(client=self._client)
198200
LessCapableJsonSerializer.install_default_serializer(client=self._client)
199-
NdjsonSerializer.install_serializer(client=self._client)
201+
LessCapableNdjsonSerializer.install_serializer(client=self._client)
200202

201203
self._index_prefix = index_prefix
202204
self._native_storage = native_storage

key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/utils.py

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
from typing import Any, ClassVar, TypeVar, cast
22

3-
from elastic_transport import JsonSerializer, NdjsonSerializer, ObjectApiResponse, SerializationError
3+
from elastic_transport import (
4+
JsonSerializer,
5+
NdjsonSerializer,
6+
ObjectApiResponse,
7+
SerializationError,
8+
)
49

510
from elasticsearch import AsyncElasticsearch
611

@@ -29,7 +34,10 @@ def get_aggregations_from_body(body: dict[str, Any]) -> dict[str, Any]:
2934
if not (aggregations := body.get("aggregations")):
3035
return {}
3136

32-
if not isinstance(aggregations, dict) or not all(isinstance(key, str) for key in aggregations): # pyright: ignore[reportUnknownVariableType]
37+
if not isinstance(aggregations, dict) or not all(
38+
isinstance(key, str)
39+
for key in aggregations # pyright: ignore[reportUnknownVariableType]
40+
): # pyright: ignore[reportUnknownVariableType]
3341
return {}
3442

3543
return cast("dict[str, Any]", aggregations)
@@ -113,44 +121,46 @@ def new_bulk_action(action: str, index: str, document_id: str) -> dict[str, Any]
113121
return {action: {"_index": index, "_id": document_id}}
114122

115123

116-
class InstallSerializerMixin:
117-
"""A mixin that installs the serializer into the transport."""
124+
class LessCapableJsonSerializer(JsonSerializer):
125+
"""A JSON Serializer that doesnt try to be smart with datetime, floats, etc."""
118126

119-
mimetype: ClassVar[str]
120-
compatibility_mimetype: ClassVar[str]
127+
mimetype: ClassVar[str] = "application/json"
128+
compatibility_mimetype: ClassVar[str] = "application/vnd.elasticsearch+json"
121129

122-
@classmethod
123-
def install_serializer(cls, client: AsyncElasticsearch) -> None:
124-
client.transport.serializers.serializers.update(
125-
{
126-
cls.mimetype: cls(),
127-
cls.compatibility_mimetype: cls(),
128-
}
130+
def default(self, data: Any) -> Any:
131+
raise SerializationError(
132+
message=f"Unable to serialize to JSON: {data!r} (type: {type(data).__name__})",
129133
)
130134

131135
@classmethod
132136
def install_default_serializer(cls, client: AsyncElasticsearch) -> None:
133137
cls.install_serializer(client=client)
134138
client.transport.serializers.default_serializer = cls()
135139

136-
137-
class LessCapableJsonSerializer(InstallSerializerMixin, JsonSerializer):
138-
"""A JSON Serializer that doesnt try to be smart with datetime, floats, etc."""
139-
140-
mimetype: ClassVar[str] = "application/json"
141-
compatibility_mimetype: ClassVar[str] = "application/vnd.elasticsearch+json"
142-
143-
def default(self, data: Any) -> Any:
144-
raise SerializationError(
145-
message=f"Unable to serialize to JSON: {data!r} (type: {type(data).__name__})",
140+
@classmethod
141+
def install_serializer(cls, client: AsyncElasticsearch) -> None:
142+
client.transport.serializers.serializers.update(
143+
{
144+
cls.mimetype: cls(),
145+
cls.compatibility_mimetype: cls(),
146+
}
146147
)
147148

148149

149-
class NdjsonSerializer(InstallSerializerMixin, NdjsonSerializer):
150+
class LessCapableNdjsonSerializer(NdjsonSerializer):
150151
"""A NDJSON Serializer that doesnt try to be smart with datetime, floats, etc."""
151152

152153
mimetype: ClassVar[str] = "application/x-ndjson"
153154
compatibility_mimetype: ClassVar[str] = "application/vnd.elasticsearch+x-ndjson"
154155

155156
def default(self, data: Any) -> Any:
156-
return LessCapableJsonSerializer.default(self, data)
157+
return LessCapableJsonSerializer.default(data=data) # pyright: ignore[reportCallIssue, reportUnknownVariableType]
158+
159+
@classmethod
160+
def install_serializer(cls, client: AsyncElasticsearch) -> None:
161+
client.transport.serializers.serializers.update(
162+
{
163+
cls.mimetype: cls(),
164+
cls.compatibility_mimetype: cls(),
165+
}
166+
)

key-value/key-value-aio/tests/stores/elasticsearch/test_elasticsearch.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,8 @@ async def test_value_stored_as_flattened_object(self, store: ElasticsearchStore,
185185
}
186186
)
187187

188-
async def test_migration_from_legacy_mode(self, store: ElasticsearchStore, es_client: AsyncElasticsearch):
189-
"""Verify native mode can read legacy JSON string data"""
190-
# Manually insert a legacy document with JSON string value
188+
async def test_migration_from_non_native_mode(self, store: ElasticsearchStore, es_client: AsyncElasticsearch):
189+
"""Verify native mode can read a document with stringified data"""
191190
index_name = store._sanitize_index_name(collection="test") # pyright: ignore[reportPrivateUsage]
192191
doc_id = store._sanitize_document_id(key="legacy_key") # pyright: ignore[reportPrivateUsage]
193192

@@ -197,14 +196,15 @@ async def test_migration_from_legacy_mode(self, store: ElasticsearchStore, es_cl
197196
body={
198197
"collection": "test",
199198
"key": "legacy_key",
200-
"value": '{"legacy": "data"}', # JSON string
199+
"value": {
200+
"string": '{"legacy": "data"}',
201+
},
201202
},
202203
)
203204
await es_client.indices.refresh(index=index_name)
204205

205-
# Should be able to read stringified values too
206206
result = await store.get(collection="test", key="legacy_key")
207-
assert result == snapshot(None)
207+
assert result == snapshot({"legacy": "data"})
208208

209209

210210
class TestElasticsearchStoreNonNativeMode(BaseTestElasticsearchStore):

key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
BaseEnumerateKeysStore,
2222
BaseStore,
2323
)
24-
from key_value.sync.code_gen.stores.elasticsearch.utils import LessCapableJsonSerializer, NdjsonSerializer, new_bulk_action
24+
from key_value.sync.code_gen.stores.elasticsearch.utils import LessCapableJsonSerializer, LessCapableNdjsonSerializer, new_bulk_action
2525

2626
try:
2727
from elasticsearch import Elasticsearch
@@ -79,14 +79,16 @@ def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedE
7979
def source_to_managed_entry(source: dict[str, Any]) -> ManagedEntry:
8080
value: dict[str, Any] = {}
8181

82+
raw_value = source.get("value")
83+
8284
# Try flattened field first, fall back to string field
83-
if not (value := source.get("value")) or not isinstance(value, dict):
85+
if not raw_value or not isinstance(raw_value, dict):
8486
msg = "Value field not found or invalid type"
8587
raise DeserializationError(msg)
8688

87-
if value_flattened := value.get("flattened"):
89+
if value_flattened := raw_value.get("flattened"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
8890
value = verify_dict(obj=value_flattened)
89-
elif value_str := value.get("string"):
91+
elif value_str := raw_value.get("string"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
9092
if not isinstance(value_str, str):
9193
msg = "Value in `value` field is not a string"
9294
raise DeserializationError(msg)
@@ -161,7 +163,7 @@ def __init__(
161163

162164
LessCapableJsonSerializer.install_serializer(client=self._client)
163165
LessCapableJsonSerializer.install_default_serializer(client=self._client)
164-
NdjsonSerializer.install_serializer(client=self._client)
166+
LessCapableNdjsonSerializer.install_serializer(client=self._client)
165167

166168
self._index_prefix = index_prefix
167169
self._native_storage = native_storage

key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/utils.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def get_aggregations_from_body(body: dict[str, Any]) -> dict[str, Any]:
3232
return {}
3333

3434
if not isinstance(aggregations, dict) or not all(isinstance(key, str) for key in aggregations): # pyright: ignore[reportUnknownVariableType]
35+
# pyright: ignore[reportUnknownVariableType]
3536
return {}
3637

3738
return cast("dict[str, Any]", aggregations)
@@ -115,37 +116,34 @@ def new_bulk_action(action: str, index: str, document_id: str) -> dict[str, Any]
115116
return {action: {"_index": index, "_id": document_id}}
116117

117118

118-
class InstallSerializerMixin:
119-
"""A mixin that installs the serializer into the transport."""
119+
class LessCapableJsonSerializer(JsonSerializer):
120+
"""A JSON Serializer that doesnt try to be smart with datetime, floats, etc."""
120121

121-
mimetype: ClassVar[str]
122-
compatibility_mimetype: ClassVar[str]
122+
mimetype: ClassVar[str] = "application/json"
123+
compatibility_mimetype: ClassVar[str] = "application/vnd.elasticsearch+json"
123124

124-
@classmethod
125-
def install_serializer(cls, client: Elasticsearch) -> None:
126-
client.transport.serializers.serializers.update({cls.mimetype: cls(), cls.compatibility_mimetype: cls()})
125+
def default(self, data: Any) -> Any:
126+
raise SerializationError(message=f"Unable to serialize to JSON: {data!r} (type: {type(data).__name__})")
127127

128128
@classmethod
129129
def install_default_serializer(cls, client: Elasticsearch) -> None:
130130
cls.install_serializer(client=client)
131131
client.transport.serializers.default_serializer = cls()
132132

133-
134-
class LessCapableJsonSerializer(InstallSerializerMixin, JsonSerializer):
135-
"""A JSON Serializer that doesnt try to be smart with datetime, floats, etc."""
136-
137-
mimetype: ClassVar[str] = "application/json"
138-
compatibility_mimetype: ClassVar[str] = "application/vnd.elasticsearch+json"
139-
140-
def default(self, data: Any) -> Any:
141-
raise SerializationError(message=f"Unable to serialize to JSON: {data!r} (type: {type(data).__name__})")
133+
@classmethod
134+
def install_serializer(cls, client: Elasticsearch) -> None:
135+
client.transport.serializers.serializers.update({cls.mimetype: cls(), cls.compatibility_mimetype: cls()})
142136

143137

144-
class NdjsonSerializer(InstallSerializerMixin, NdjsonSerializer):
138+
class LessCapableNdjsonSerializer(NdjsonSerializer):
145139
"""A NDJSON Serializer that doesnt try to be smart with datetime, floats, etc."""
146140

147141
mimetype: ClassVar[str] = "application/x-ndjson"
148142
compatibility_mimetype: ClassVar[str] = "application/vnd.elasticsearch+x-ndjson"
149143

150144
def default(self, data: Any) -> Any:
151-
return LessCapableJsonSerializer.default(self, data)
145+
return LessCapableJsonSerializer.default(data=data) # pyright: ignore[reportCallIssue, reportUnknownVariableType]
146+
147+
@classmethod
148+
def install_serializer(cls, client: Elasticsearch) -> None:
149+
client.transport.serializers.serializers.update({cls.mimetype: cls(), cls.compatibility_mimetype: cls()})

key-value/key-value-sync/tests/code_gen/stores/elasticsearch/test_elasticsearch.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -186,19 +186,18 @@ def test_value_stored_as_flattened_object(self, store: ElasticsearchStore, es_cl
186186
}
187187
)
188188

189-
def test_migration_from_legacy_mode(self, store: ElasticsearchStore, es_client: Elasticsearch):
190-
"""Verify native mode can read legacy JSON string data"""
191-
# Manually insert a legacy document with JSON string value
189+
def test_migration_from_non_native_mode(self, store: ElasticsearchStore, es_client: Elasticsearch):
190+
"""Verify native mode can read a document with stringified data"""
192191
index_name = store._sanitize_index_name(collection="test") # pyright: ignore[reportPrivateUsage]
193192
doc_id = store._sanitize_document_id(key="legacy_key") # pyright: ignore[reportPrivateUsage]
194193

195-
# JSON string
196-
es_client.index(index=index_name, id=doc_id, body={"collection": "test", "key": "legacy_key", "value": '{"legacy": "data"}'})
194+
es_client.index(
195+
index=index_name, id=doc_id, body={"collection": "test", "key": "legacy_key", "value": {"string": '{"legacy": "data"}'}}
196+
)
197197
es_client.indices.refresh(index=index_name)
198198

199-
# Should be able to read stringified values too
200199
result = store.get(collection="test", key="legacy_key")
201-
assert result == snapshot(None)
200+
assert result == snapshot({"legacy": "data"})
202201

203202

204203
class TestElasticsearchStoreNonNativeMode(BaseTestElasticsearchStore):

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ typeCheckingMode = "strict"
9999
reportExplicitAny = false
100100
reportMissingTypeStubs = false
101101
#reportUnnecessaryTypeIgnoreComment = "error"
102+
include = [
103+
"key-value"
104+
]
102105
exclude = [
103106
"**/playground/**",
104107
"**/node_modules/**",

0 commit comments

Comments
 (0)