@@ -99,14 +99,6 @@ async def setup_elasticsearch(self, request: pytest.FixtureRequest) -> AsyncGene
9999
100100 yield
101101
102- @pytest .fixture
103- async def es_client (self ) -> AsyncGenerator [AsyncElasticsearch , None ]:
104- async with AsyncElasticsearch (hosts = [ES_URL ]) as es_client :
105- try :
106- yield es_client
107- finally :
108- await es_client .close ()
109-
110102 @override
111103 @pytest .fixture
112104 async def store (self ) -> ElasticsearchStore :
@@ -122,10 +114,11 @@ async def sanitizing_store(self) -> ElasticsearchStore:
122114 )
123115
124116 @pytest .fixture (autouse = True )
125- async def cleanup_elasticsearch_indices (self , es_client : AsyncElasticsearch ):
126- await cleanup_elasticsearch_indices (elasticsearch_client = es_client )
127- yield
128- await cleanup_elasticsearch_indices (elasticsearch_client = es_client )
117+ async def cleanup_elasticsearch (self ):
118+ async with get_elasticsearch_client () as es_client :
119+ await cleanup_elasticsearch_indices (elasticsearch_client = es_client )
120+ yield
121+ await cleanup_elasticsearch_indices (elasticsearch_client = es_client )
129122
130123 @pytest .mark .skip (reason = "Distributed Caches are unbounded" )
131124 @override
@@ -152,18 +145,19 @@ async def test_long_key_name(self, store: ElasticsearchStore, sanitizing_store:
152145 await sanitizing_store .put (collection = "test_collection" , key = "test_key" * 100 , value = {"test" : "test" })
153146 assert await sanitizing_store .get (collection = "test_collection" , key = "test_key" * 100 ) == {"test" : "test" }
154147
155- async def test_put_put_two_indices (self , store : ElasticsearchStore , es_client : AsyncElasticsearch ):
148+ async def test_put_put_two_indices (self , store : ElasticsearchStore ):
156149 await store .put (collection = "test_collection" , key = "test_key" , value = {"test" : "test" })
157150 await store .put (collection = "test_collection_2" , key = "test_key" , value = {"test" : "test" })
158151 assert await store .get (collection = "test_collection" , key = "test_key" ) == {"test" : "test" }
159152 assert await store .get (collection = "test_collection_2" , key = "test_key" ) == {"test" : "test" }
160153
161- indices = await es_client .options (ignore_status = 404 ).indices .get (index = "kv-store-e2e-test-*" )
162- assert len (indices .body ) == 2
163- index_names : list [str ] = [str (key ) for key in indices ]
164- assert index_names == snapshot (["kv-store-e2e-test-test_collection" , "kv-store-e2e-test-test_collection_2" ])
154+ async with get_elasticsearch_client () as es_client :
155+ indices = await es_client .options (ignore_status = 404 ).indices .get (index = "kv-store-e2e-test-*" )
156+ assert len (indices .body ) == 2
157+ index_names : list [str ] = [str (key ) for key in indices ]
158+ assert index_names == snapshot (["kv-store-e2e-test-test_collection" , "kv-store-e2e-test-test_collection_2" ])
165159
166- async def test_value_stored_as_flattened_object (self , store : ElasticsearchStore , es_client : AsyncElasticsearch ):
160+ async def test_value_stored_as_flattened_object (self , store : ElasticsearchStore ):
167161 """Verify values are stored as flattened objects, not JSON strings"""
168162 await store .put (collection = "test" , key = "test_key" , value = {"name" : "Alice" , "age" : 30 })
169163
@@ -172,30 +166,32 @@ async def test_value_stored_as_flattened_object(self, store: ElasticsearchStore,
172166 index_name = store ._get_index_name (collection = "test" ) # pyright: ignore[reportPrivateUsage]
173167 doc_id = store ._get_document_id (key = "test_key" ) # pyright: ignore[reportPrivateUsage]
174168
175- response = await es_client .get (index = index_name , id = doc_id )
176- assert response .body ["_source" ] == snapshot (
177- {
178- "version" : 1 ,
179- "key" : "test_key" ,
180- "collection" : "test" ,
181- "value" : {"flattened" : {"name" : "Alice" , "age" : 30 }},
182- "created_at" : IsStr (min_length = 20 , max_length = 40 ),
183- }
184- )
185-
186- # Test with TTL
187- await store .put (collection = "test" , key = "test_key" , value = {"name" : "Bob" , "age" : 25 }, ttl = 10 )
188- response = await es_client .get (index = index_name , id = doc_id )
189- assert response .body ["_source" ] == snapshot (
190- {
191- "version" : 1 ,
192- "key" : "test_key" ,
193- "collection" : "test" ,
194- "value" : {"flattened" : {"name" : "Bob" , "age" : 25 }},
195- "created_at" : IsStr (min_length = 20 , max_length = 40 ),
196- "expires_at" : IsStr (min_length = 20 , max_length = 40 ),
197- }
198- )
169+ async with get_elasticsearch_client () as es_client :
170+ response = await es_client .get (index = index_name , id = doc_id )
171+ assert response .body ["_source" ] == snapshot (
172+ {
173+ "version" : 1 ,
174+ "key" : "test_key" ,
175+ "collection" : "test" ,
176+ "value" : {"flattened" : {"name" : "Alice" , "age" : 30 }},
177+ "created_at" : IsStr (min_length = 20 , max_length = 40 ),
178+ }
179+ )
180+
181+ # Test with TTL
182+ await store .put (collection = "test" , key = "test_key" , value = {"name" : "Bob" , "age" : 25 }, ttl = 10 )
183+
184+ response = await es_client .get (index = index_name , id = doc_id )
185+ assert response .body ["_source" ] == snapshot (
186+ {
187+ "version" : 1 ,
188+ "key" : "test_key" ,
189+ "collection" : "test" ,
190+ "value" : {"flattened" : {"name" : "Bob" , "age" : 25 }},
191+ "created_at" : IsStr (min_length = 20 , max_length = 40 ),
192+ "expires_at" : IsStr (min_length = 20 , max_length = 40 ),
193+ }
194+ )
199195
200196 @override
201197 async def test_special_characters_in_collection_name (self , store : ElasticsearchStore , sanitizing_store : ElasticsearchStore ): # pyright: ignore[reportIncompatibleMethodOverride]
0 commit comments