diff --git a/tests/integration/integration_tests.py b/tests/integration/integration_tests.py index 1326d193..befb3ff3 100644 --- a/tests/integration/integration_tests.py +++ b/tests/integration/integration_tests.py @@ -3,7 +3,7 @@ import os import re import sys -from typing import Optional +from typing import Dict, Optional, Tuple import requests import structlog @@ -172,54 +172,56 @@ async def run_test(self, test: dict, test_headers: dict) -> bool: self.failed_tests.append(test_name) return False - async def run_tests( - self, - testcases_file: str, - providers: Optional[list[str]] = None, - test_names: Optional[list[str]] = None, - ) -> bool: - with open(testcases_file, "r") as f: - tests = yaml.safe_load(f) + async def _get_testcases( + self, testcases_dict: Dict, test_names: Optional[list[str]] = None + ) -> Dict: + testcases: Dict[str, Dict[str, str]] = testcases_dict["testcases"] - headers = tests["headers"] - testcases = tests["testcases"] - - if providers or test_names: + # Filter testcases by provider and test names + if test_names: filtered_testcases = {} + # Iterate over the original testcases and only keep the ones that match the + # specified test names for test_id, test_data in testcases.items(): - if providers: - if test_data.get("provider", "").lower() not in [p.lower() for p in providers]: - continue - - if test_names: - if test_data.get("name", "").lower() not in [t.lower() for t in test_names]: - continue + if test_data.get("name", "").lower() not in [t.lower() for t in test_names]: + continue filtered_testcases[test_id] = test_data testcases = filtered_testcases + return testcases - if not testcases: - filter_msg = [] - if providers: - filter_msg.append(f"providers: {', '.join(providers)}") - if test_names: - filter_msg.append(f"test names: {', '.join(test_names)}") - logger.warning(f"No tests found for {' and '.join(filter_msg)}") - return True # No tests is not a failure + async def _setup( + self, testcases_file: str, test_names: Optional[list[str]] = None + ) -> Tuple[Dict, Dict]: + with open(testcases_file, "r") as f: + testcases_dict = yaml.safe_load(f) + + headers = testcases_dict["headers"] + testcases = await self._get_testcases(testcases_dict, test_names) + return headers, testcases + + async def run_tests( + self, + testcases_file: str, + provider: str, + test_names: Optional[list[str]] = None, + ) -> bool: + headers, testcases = await self._setup(testcases_file, test_names) + + if not testcases: + logger.warning( + f"No tests found for provider {provider} in file: {testcases_file} " + f"and specific testcases: {test_names}" + ) + return True # No tests is not a failure test_count = len(testcases) - filter_msg = [] - if providers: - filter_msg.append(f"providers: {', '.join(providers)}") + logging_msg = f"Running {test_count} tests for provider {provider}" if test_names: - filter_msg.append(f"test names: {', '.join(test_names)}") - - logger.info( - f"Running {test_count} tests" - + (f" for {' and '.join(filter_msg)}" if filter_msg else "") - ) + logging_msg += f" and test names: {', '.join(test_names)}" + logger.info(logging_msg) all_tests_passed = True for test_id, test_data in testcases.items(): @@ -285,10 +287,12 @@ async def main(): logger.warning(f"No testcases.yaml found for provider {provider}") continue + # Run tests for the provider. The provider has already been selected when + # reading the testcases.yaml file. logger.info(f"Running tests for provider: {provider}") provider_tests_passed = await test_runner.run_tests( provider_test_file, - providers=[provider], # Only run tests for current provider + provider=provider, test_names=test_names, ) all_tests_passed = all_tests_passed and provider_tests_passed