|
3 | 3 | import os
|
4 | 4 | import re
|
5 | 5 | import sys
|
6 |
| -from typing import Optional |
| 6 | +from typing import Dict, Optional, Tuple |
7 | 7 |
|
8 | 8 | import requests
|
9 | 9 | import structlog
|
@@ -172,54 +172,56 @@ async def run_test(self, test: dict, test_headers: dict) -> bool:
|
172 | 172 | self.failed_tests.append(test_name)
|
173 | 173 | return False
|
174 | 174 |
|
175 |
| - async def run_tests( |
176 |
| - self, |
177 |
| - testcases_file: str, |
178 |
| - providers: Optional[list[str]] = None, |
179 |
| - test_names: Optional[list[str]] = None, |
180 |
| - ) -> bool: |
181 |
| - with open(testcases_file, "r") as f: |
182 |
| - tests = yaml.safe_load(f) |
| 175 | + async def _get_testcases( |
| 176 | + self, testcases_dict: Dict, test_names: Optional[list[str]] = None |
| 177 | + ) -> Dict: |
| 178 | + testcases: Dict[str, Dict[str, str]] = testcases_dict["testcases"] |
183 | 179 |
|
184 |
| - headers = tests["headers"] |
185 |
| - testcases = tests["testcases"] |
186 |
| - |
187 |
| - if providers or test_names: |
| 180 | + # Filter testcases by provider and test names |
| 181 | + if test_names: |
188 | 182 | filtered_testcases = {}
|
189 | 183 |
|
| 184 | + # Iterate over the original testcases and only keep the ones that match the |
| 185 | + # specified test names |
190 | 186 | for test_id, test_data in testcases.items():
|
191 |
| - if providers: |
192 |
| - if test_data.get("provider", "").lower() not in [p.lower() for p in providers]: |
193 |
| - continue |
194 |
| - |
195 |
| - if test_names: |
196 |
| - if test_data.get("name", "").lower() not in [t.lower() for t in test_names]: |
197 |
| - continue |
| 187 | + if test_data.get("name", "").lower() not in [t.lower() for t in test_names]: |
| 188 | + continue |
198 | 189 |
|
199 | 190 | filtered_testcases[test_id] = test_data
|
200 | 191 |
|
201 | 192 | testcases = filtered_testcases
|
| 193 | + return testcases |
202 | 194 |
|
203 |
| - if not testcases: |
204 |
| - filter_msg = [] |
205 |
| - if providers: |
206 |
| - filter_msg.append(f"providers: {', '.join(providers)}") |
207 |
| - if test_names: |
208 |
| - filter_msg.append(f"test names: {', '.join(test_names)}") |
209 |
| - logger.warning(f"No tests found for {' and '.join(filter_msg)}") |
210 |
| - return True # No tests is not a failure |
| 195 | + async def _setup( |
| 196 | + self, testcases_file: str, test_names: Optional[list[str]] = None |
| 197 | + ) -> Tuple[Dict, Dict]: |
| 198 | + with open(testcases_file, "r") as f: |
| 199 | + testcases_dict = yaml.safe_load(f) |
| 200 | + |
| 201 | + headers = testcases_dict["headers"] |
| 202 | + testcases = await self._get_testcases(testcases_dict, test_names) |
| 203 | + return headers, testcases |
| 204 | + |
| 205 | + async def run_tests( |
| 206 | + self, |
| 207 | + testcases_file: str, |
| 208 | + provider: str, |
| 209 | + test_names: Optional[list[str]] = None, |
| 210 | + ) -> bool: |
| 211 | + headers, testcases = await self._setup(testcases_file, test_names) |
| 212 | + |
| 213 | + if not testcases: |
| 214 | + logger.warning( |
| 215 | + f"No tests found for provider {provider} in file: {testcases_file} " |
| 216 | + f"and specific testcases: {test_names}" |
| 217 | + ) |
| 218 | + return True # No tests is not a failure |
211 | 219 |
|
212 | 220 | test_count = len(testcases)
|
213 |
| - filter_msg = [] |
214 |
| - if providers: |
215 |
| - filter_msg.append(f"providers: {', '.join(providers)}") |
| 221 | + logging_msg = f"Running {test_count} tests for provider {provider}" |
216 | 222 | if test_names:
|
217 |
| - filter_msg.append(f"test names: {', '.join(test_names)}") |
218 |
| - |
219 |
| - logger.info( |
220 |
| - f"Running {test_count} tests" |
221 |
| - + (f" for {' and '.join(filter_msg)}" if filter_msg else "") |
222 |
| - ) |
| 223 | + logging_msg += f" and test names: {', '.join(test_names)}" |
| 224 | + logger.info(logging_msg) |
223 | 225 |
|
224 | 226 | all_tests_passed = True
|
225 | 227 | for test_id, test_data in testcases.items():
|
@@ -285,10 +287,12 @@ async def main():
|
285 | 287 | logger.warning(f"No testcases.yaml found for provider {provider}")
|
286 | 288 | continue
|
287 | 289 |
|
| 290 | + # Run tests for the provider. The provider has already been selected when |
| 291 | + # reading the testcases.yaml file. |
288 | 292 | logger.info(f"Running tests for provider: {provider}")
|
289 | 293 | provider_tests_passed = await test_runner.run_tests(
|
290 | 294 | provider_test_file,
|
291 |
| - providers=[provider], # Only run tests for current provider |
| 295 | + provider=provider, |
292 | 296 | test_names=test_names,
|
293 | 297 | )
|
294 | 298 | all_tests_passed = all_tests_passed and provider_tests_passed
|
|
0 commit comments