|
16 | 16 | logger = structlog.get_logger("codegate")
|
17 | 17 |
|
18 | 18 |
|
19 |
| -# call_directly is a function to call the model directly bypassing codegate |
20 |
| -def call_directly(url: str, headers: dict, data: dict) -> Optional[requests.Response]: |
21 |
| - try: |
22 |
| - headers["Content-Type"] = "application/json" |
23 |
| - stream = data.get("stream", False) |
24 |
| - response = requests.post(url, headers=headers, json=data, stream=stream) |
25 |
| - response.raise_for_status() |
26 |
| - return response |
27 |
| - except Exception as e: |
28 |
| - logger.error(f"Error making direct request to {url}: {str(e)}") |
29 |
| - return None |
30 |
| - |
31 |
| - |
32 | 19 | class CodegateTestRunner:
|
33 | 20 | def __init__(self):
|
34 | 21 | self.requester_factory = RequesterFactory()
|
35 | 22 | self.failed_tests = [] # Track failed tests
|
36 | 23 |
|
37 |
| - def call_codegate( |
| 24 | + def call_provider( |
38 | 25 | self, url: str, headers: dict, data: dict, provider: str, method: str = "POST"
|
39 | 26 | ) -> Optional[requests.Response]:
|
40 | 27 | logger.debug(f"Creating requester for provider: {provider}")
|
@@ -146,21 +133,23 @@ def replacement(match):
|
146 | 133 | async def run_test(self, test: dict, test_headers: dict) -> bool:
|
147 | 134 | test_name = test["name"]
|
148 | 135 | data = json.loads(test["data"])
|
| 136 | + codegate_url = test["url"] |
149 | 137 | streaming = data.get("stream", False)
|
150 | 138 | provider = test["provider"]
|
151 | 139 | logger.info(f"Starting test: {test_name}")
|
152 | 140 |
|
153 | 141 | # Call Codegate
|
154 |
| - response = self.call_codegate(test["url"], test_headers, data, provider) |
| 142 | + response = self.call_provider(codegate_url, test_headers, data, provider) |
155 | 143 | if not response:
|
156 | 144 | logger.error(f"Test {test_name} failed: No response received")
|
157 | 145 | return False
|
158 | 146 |
|
159 | 147 | # Call model directly if specified
|
160 | 148 | direct_response = None
|
161 | 149 | if test.get(CodeGateEnrichment.KEY) is not None:
|
162 |
| - direct_response = call_directly( |
163 |
| - test.get(CodeGateEnrichment.KEY)["provider_url"], test_headers, data |
| 150 | + direct_provider_url = test.get(CodeGateEnrichment.KEY)["provider_url"] |
| 151 | + direct_response = self.call_provider( |
| 152 | + direct_provider_url, test_headers, data, "not-codegate" |
164 | 153 | )
|
165 | 154 | if not direct_response:
|
166 | 155 | logger.error(f"Test {test_name} failed: No direct response received")
|
@@ -412,6 +401,7 @@ async def main():
|
412 | 401 | # Exit with status code 1 if any tests failed
|
413 | 402 | if not all_tests_passed:
|
414 | 403 | sys.exit(1)
|
| 404 | + logger.info("All tests passed") |
415 | 405 |
|
416 | 406 |
|
417 | 407 | if __name__ == "__main__":
|
|
0 commit comments