33# DO NOT CHANGE! Change the original file instead.
44from collections .abc import Generator
55from datetime import datetime , timedelta , timezone
6+ from typing import TYPE_CHECKING , Any
67
78import pytest
89from dirty_equals import IsFloat , IsStr
2223from tests .code_gen .conftest import docker_container , should_skip_docker_tests
2324from tests .code_gen .stores .base import BaseStoreTests , ContextManagerStoreTestMixin
2425
26+ if TYPE_CHECKING :
27+ from elastic_transport ._response import ObjectApiResponse
28+
2529TEST_SIZE_LIMIT = 1 * 1024 * 1024 # 1MB
2630ES_HOST = "localhost"
27- CONTAINER_PORT = 9200
28- HOST_PORT = 19200
29- ES_URL = f"http:// { ES_HOST } : { HOST_PORT } "
31+ ES_PORT = 9200
32+ ES_URL = f"http:// { ES_HOST } : { ES_PORT } "
33+ ES_CONTAINER_PORT = 9200
3034
3135WAIT_FOR_ELASTICSEARCH_TIMEOUT = 30
3236# Released Apr 2025
@@ -42,7 +46,12 @@ def ping_elasticsearch() -> bool:
4246 es_client : Elasticsearch = get_elasticsearch_client ()
4347
4448 with es_client :
45- return es_client .ping ()
49+ if not es_client .ping ():
50+ return False
51+
52+ status : ObjectApiResponse [dict [str , Any ]] = es_client .options (ignore_status = 404 ).cluster .health (wait_for_status = "green" )
53+
54+ return status .body .get ("status" ) == "green"
4655
4756
4857def cleanup_elasticsearch_indices (elasticsearch_client : Elasticsearch ):
@@ -91,7 +100,7 @@ def setup_elasticsearch(self, request: pytest.FixtureRequest) -> Generator[None,
91100 with docker_container (
92101 f"elasticsearch-test-{ version } " ,
93102 es_image ,
94- {str (CONTAINER_PORT ): HOST_PORT },
103+ {str (ES_CONTAINER_PORT ): ES_PORT },
95104 {"discovery.type" : "single-node" , "xpack.security.enabled" : "false" },
96105 ):
97106 if not wait_for_true (bool_fn = ping_elasticsearch , tries = WAIT_FOR_ELASTICSEARCH_TIMEOUT , wait_time = 2 ):
@@ -100,6 +109,11 @@ def setup_elasticsearch(self, request: pytest.FixtureRequest) -> Generator[None,
100109
101110 yield
102111
112+ @pytest .fixture
113+ def es_client (self ) -> Generator [Elasticsearch , None , None ]:
114+ with Elasticsearch (hosts = [ES_URL ]) as es_client :
115+ yield es_client
116+
103117 @override
104118 @pytest .fixture
105119 def store (self ) -> ElasticsearchStore :
@@ -115,11 +129,10 @@ def sanitizing_store(self) -> ElasticsearchStore:
115129 )
116130
117131 @pytest .fixture (autouse = True )
118- def cleanup_elasticsearch (self ):
119- with get_elasticsearch_client () as es_client :
120- cleanup_elasticsearch_indices (elasticsearch_client = es_client )
121- yield
122- cleanup_elasticsearch_indices (elasticsearch_client = es_client )
132+ def cleanup_elasticsearch_indices (self , es_client : Elasticsearch ):
133+ cleanup_elasticsearch_indices (elasticsearch_client = es_client )
134+ yield
135+ cleanup_elasticsearch_indices (elasticsearch_client = es_client )
123136
124137 @pytest .mark .skip (reason = "Distributed Caches are unbounded" )
125138 @override
@@ -146,19 +159,18 @@ def test_long_key_name(self, store: ElasticsearchStore, sanitizing_store: Elasti
146159 sanitizing_store .put (collection = "test_collection" , key = "test_key" * 100 , value = {"test" : "test" })
147160 assert sanitizing_store .get (collection = "test_collection" , key = "test_key" * 100 ) == {"test" : "test" }
148161
149- def test_put_put_two_indices (self , store : ElasticsearchStore ):
162+ def test_put_put_two_indices (self , store : ElasticsearchStore , es_client : Elasticsearch ):
150163 store .put (collection = "test_collection" , key = "test_key" , value = {"test" : "test" })
151164 store .put (collection = "test_collection_2" , key = "test_key" , value = {"test" : "test" })
152165 assert store .get (collection = "test_collection" , key = "test_key" ) == {"test" : "test" }
153166 assert store .get (collection = "test_collection_2" , key = "test_key" ) == {"test" : "test" }
154167
155- with get_elasticsearch_client () as es_client :
156- indices = es_client .options (ignore_status = 404 ).indices .get (index = "kv-store-e2e-test-*" )
157- assert len (indices .body ) == 2
158- index_names : list [str ] = [str (key ) for key in indices ]
159- assert index_names == snapshot (["kv-store-e2e-test-test_collection" , "kv-store-e2e-test-test_collection_2" ])
168+ indices = es_client .options (ignore_status = 404 ).indices .get (index = "kv-store-e2e-test-*" )
169+ assert len (indices .body ) == 2
170+ index_names : list [str ] = [str (key ) for key in indices ]
171+ assert index_names == snapshot (["kv-store-e2e-test-test_collection" , "kv-store-e2e-test-test_collection_2" ])
160172
161- def test_value_stored_as_flattened_object (self , store : ElasticsearchStore ):
173+ def test_value_stored_as_flattened_object (self , store : ElasticsearchStore , es_client : Elasticsearch ):
162174 """Verify values are stored as flattened objects, not JSON strings"""
163175 store .put (collection = "test" , key = "test_key" , value = {"name" : "Alice" , "age" : 30 })
164176
@@ -167,32 +179,30 @@ def test_value_stored_as_flattened_object(self, store: ElasticsearchStore):
167179 index_name = store ._get_index_name (collection = "test" ) # pyright: ignore[reportPrivateUsage]
168180 doc_id = store ._get_document_id (key = "test_key" ) # pyright: ignore[reportPrivateUsage]
169181
170- with get_elasticsearch_client () as es_client :
171- response = es_client .get (index = index_name , id = doc_id )
172- assert response .body ["_source" ] == snapshot (
173- {
174- "version" : 1 ,
175- "key" : "test_key" ,
176- "collection" : "test" ,
177- "value" : {"flattened" : {"name" : "Alice" , "age" : 30 }},
178- "created_at" : IsStr (min_length = 20 , max_length = 40 ),
179- }
180- )
181-
182- # Test with TTL
183- store .put (collection = "test" , key = "test_key" , value = {"name" : "Bob" , "age" : 25 }, ttl = 10 )
184-
185- response = es_client .get (index = index_name , id = doc_id )
186- assert response .body ["_source" ] == snapshot (
187- {
188- "version" : 1 ,
189- "key" : "test_key" ,
190- "collection" : "test" ,
191- "value" : {"flattened" : {"name" : "Bob" , "age" : 25 }},
192- "created_at" : IsStr (min_length = 20 , max_length = 40 ),
193- "expires_at" : IsStr (min_length = 20 , max_length = 40 ),
194- }
195- )
182+ response = es_client .get (index = index_name , id = doc_id )
183+ assert response .body ["_source" ] == snapshot (
184+ {
185+ "version" : 1 ,
186+ "key" : "test_key" ,
187+ "collection" : "test" ,
188+ "value" : {"flattened" : {"name" : "Alice" , "age" : 30 }},
189+ "created_at" : IsStr (min_length = 20 , max_length = 40 ),
190+ }
191+ )
192+
193+ # Test with TTL
194+ store .put (collection = "test" , key = "test_key" , value = {"name" : "Bob" , "age" : 25 }, ttl = 10 )
195+ response = es_client .get (index = index_name , id = doc_id )
196+ assert response .body ["_source" ] == snapshot (
197+ {
198+ "version" : 1 ,
199+ "key" : "test_key" ,
200+ "collection" : "test" ,
201+ "value" : {"flattened" : {"name" : "Bob" , "age" : 25 }},
202+ "created_at" : IsStr (min_length = 20 , max_length = 40 ),
203+ "expires_at" : IsStr (min_length = 20 , max_length = 40 ),
204+ }
205+ )
196206
197207 @override
198208 def test_special_characters_in_collection_name (self , store : ElasticsearchStore , sanitizing_store : ElasticsearchStore ): # pyright: ignore[reportIncompatibleMethodOverride]
0 commit comments