From b1b35d9e9bdb65e58dd9107b7af745d00496f207 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Thu, 13 Feb 2025 13:25:48 +0200 Subject: [PATCH] Add integration tests with muxing Replicate the current integration tests but instead of using the specific provider URL, e.g. `/ollama` use the muxing URL, i.e. `/v1/mux/`. Muxing functionality should take care of routing the request to the correct model and provider. For the moment we're only going to test with the "catch_all" rule. Meaning, all the requests will be directed to the same model. In future iterations we can expand the integration tests to check for multiple rules across different providers. --- src/codegate/muxing/adapter.py | 88 ++++++++++-------- tests/integration/anthropic/testcases.yaml | 25 ++++++ tests/integration/integration_tests.py | 99 ++++++++++++++++++--- tests/integration/llamacpp/testcases.yaml | 24 +++++ tests/integration/ollama/testcases.yaml | 24 +++++ tests/integration/openai/testcases.yaml | 25 ++++++ tests/integration/openrouter/testcases.yaml | 25 ++++++ tests/integration/requesters.py | 24 +++-- tests/integration/vllm/testcases.yaml | 24 +++++ tests/muxing/test_adapter.py | 1 + 10 files changed, 304 insertions(+), 55 deletions(-) diff --git a/src/codegate/muxing/adapter.py b/src/codegate/muxing/adapter.py index 74513e98..5a4a70c1 100644 --- a/src/codegate/muxing/adapter.py +++ b/src/codegate/muxing/adapter.py @@ -32,7 +32,10 @@ class BodyAdapter: def _get_provider_formatted_url(self, model_route: rulematcher.ModelRoute) -> str: """Get the provider formatted URL to use in base_url. 
Note this value comes from DB""" - if model_route.endpoint.provider_type == db_models.ProviderType.openai: + if model_route.endpoint.provider_type in [ + db_models.ProviderType.openai, + db_models.ProviderType.vllm, + ]: return urljoin(model_route.endpoint.endpoint, "/v1") if model_route.endpoint.provider_type == db_models.ProviderType.openrouter: return urljoin(model_route.endpoint.endpoint, "/api/v1") @@ -90,6 +93,47 @@ def _format_openai(self, chunk: str) -> str: cleaned_chunk = chunk.split("data:")[1].strip() return cleaned_chunk + def _format_antropic(self, chunk: str) -> str: + """ + Format the Anthropic chunk to OpenAI format. + + This function is used by both chat and FIM formatters + """ + cleaned_chunk = chunk.split("data:")[1].strip() + try: + chunk_dict = json.loads(cleaned_chunk) + msg_type = chunk_dict.get("type", "") + + finish_reason = None + if msg_type == "message_stop": + finish_reason = "stop" + + # In type == "content_block_start" the content comes in "content_block" + # In type == "content_block_delta" the content comes in "delta" + msg_content_dict = chunk_dict.get("delta", {}) or chunk_dict.get("content_block", {}) + # We couldn't obtain the content from the chunk. Skip it. + if not msg_content_dict: + return "" + + msg_content = msg_content_dict.get("text", "") + open_ai_chunk = ModelResponse( + id=f"anthropic-chat-{str(uuid.uuid4())}", + model="anthropic-muxed-model", + object="chat.completion.chunk", + choices=[ + StreamingChoices( + finish_reason=finish_reason, + index=0, + delta=Delta(content=msg_content, role="assistant"), + logprobs=None, + ) + ], + ) + return open_ai_chunk.model_dump_json(exclude_none=True, exclude_unset=True) + except Exception as e: + logger.warning(f"Error formatting Anthropic chunk: {chunk}. Error: {e}") + return cleaned_chunk.strip() + def _format_as_openai_chunk(self, formatted_chunk: str) -> str: """Format the chunk as OpenAI chunk. 
This is the format how the clients expect the data.""" chunk_to_send = f"data:{formatted_chunk}\n\n" @@ -148,6 +192,8 @@ def provider_format_funcs(self) -> Dict[str, Callable]: db_models.ProviderType.llamacpp: self._format_openai, # OpenRouter is a dialect of OpenAI db_models.ProviderType.openrouter: self._format_openai, + # VLLM is a dialect of OpenAI + db_models.ProviderType.vllm: self._format_openai, } def _format_ollama(self, chunk: str) -> str: @@ -165,43 +211,6 @@ def _format_ollama(self, chunk: str) -> str: logger.warning(f"Error formatting Ollama chunk: {chunk}. Error: {e}") return chunk - def _format_antropic(self, chunk: str) -> str: - """Format the Anthropic chunk to OpenAI format.""" - cleaned_chunk = chunk.split("data:")[1].strip() - try: - chunk_dict = json.loads(cleaned_chunk) - msg_type = chunk_dict.get("type", "") - - finish_reason = None - if msg_type == "message_stop": - finish_reason = "stop" - - # In type == "content_block_start" the content comes in "content_block" - # In type == "content_block_delta" the content comes in "delta" - msg_content_dict = chunk_dict.get("delta", {}) or chunk_dict.get("content_block", {}) - # We couldn't obtain the content from the chunk. Skip it. - if not msg_content_dict: - return "" - - msg_content = msg_content_dict.get("text", "") - open_ai_chunk = ModelResponse( - id=f"anthropic-chat-{str(uuid.uuid4())}", - model="anthropic-muxed-model", - object="chat.completion.chunk", - choices=[ - StreamingChoices( - finish_reason=finish_reason, - index=0, - delta=Delta(content=msg_content, role="assistant"), - logprobs=None, - ) - ], - ) - return open_ai_chunk.model_dump_json(exclude_none=True, exclude_unset=True) - except Exception as e: - logger.warning(f"Error formatting Anthropic chunk: {chunk}. 
Error: {e}") - return cleaned_chunk.strip() - class FimStreamChunkFormatter(StreamChunkFormatter): @@ -218,6 +227,9 @@ def provider_format_funcs(self) -> Dict[str, Callable]: db_models.ProviderType.llamacpp: self._format_openai, # OpenRouter is a dialect of OpenAI db_models.ProviderType.openrouter: self._format_openai, + # VLLM is a dialect of OpenAI + db_models.ProviderType.vllm: self._format_openai, + db_models.ProviderType.anthropic: self._format_antropic, } def _format_ollama(self, chunk: str) -> str: diff --git a/tests/integration/anthropic/testcases.yaml b/tests/integration/anthropic/testcases.yaml index 6ad992de..03f8f666 100644 --- a/tests/integration/anthropic/testcases.yaml +++ b/tests/integration/anthropic/testcases.yaml @@ -2,6 +2,31 @@ headers: anthropic: x-api-key: ENV_ANTHROPIC_KEY +muxing: + mux_url: http://127.0.0.1:8989/v1/mux/ + trimm_from_testcase_url: http://127.0.0.1:8989/anthropic/ + provider_endpoint: + url: http://127.0.0.1:8989/api/v1/provider-endpoints + headers: + Content-Type: application/json + data: | + { + "name": "anthropic_muxing", + "description": "Muxing testing endpoint", + "provider_type": "anthropic", + "endpoint": "https://api.anthropic.com/", + "auth_type": "api_key", + "api_key": "ENV_ANTHROPIC_KEY" + } + muxes: + url: http://127.0.0.1:8989/api/v1/workspaces/default/muxes + headers: + Content-Type: application/json + rules: + - model: claude-3-5-haiku-20241022 + matcher_type: catch_all + matcher: "" + testcases: anthropic_chat: name: Anthropic Chat diff --git a/tests/integration/integration_tests.py b/tests/integration/integration_tests.py index befb3ff3..efc2e105 100644 --- a/tests/integration/integration_tests.py +++ b/tests/integration/integration_tests.py @@ -1,9 +1,10 @@ import asyncio +import copy import json import os import re import sys -from typing import Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import requests import structlog @@ -21,7 +22,7 @@ def __init__(self): self.failed_tests = 
[] # Track failed tests def call_codegate( - self, url: str, headers: dict, data: dict, provider: str + self, url: str, headers: dict, data: dict, provider: str, method: str = "POST" ) -> Optional[requests.Response]: logger.debug(f"Creating requester for provider: {provider}") requester = self.requester_factory.create_requester(provider) @@ -31,12 +32,12 @@ def call_codegate( logger.debug(f"Headers: {headers}") logger.debug(f"Data: {data}") - response = requester.make_request(url, headers, data) + response = requester.make_request(url, headers, data, method=method) # Enhanced response logging if response is not None: - if response.status_code != 200: + if response.status_code not in [200, 201, 204]: logger.debug(f"Response error status: {response.status_code}") logger.debug(f"Response error headers: {dict(response.headers)}") try: @@ -174,7 +175,7 @@ async def run_test(self, test: dict, test_headers: dict) -> bool: async def _get_testcases( self, testcases_dict: Dict, test_names: Optional[list[str]] = None - ) -> Dict: + ) -> Dict[str, Dict[str, str]]: testcases: Dict[str, Dict[str, str]] = testcases_dict["testcases"] # Filter testcases by provider and test names @@ -192,15 +193,94 @@ async def _get_testcases( testcases = filtered_testcases return testcases + async def _setup_muxing( + self, provider: str, muxing_config: Optional[Dict] + ) -> Optional[Tuple[str, str]]: + """ + Muxing setup. Create the provider endpoints and the muxing rules + + Return + """ + # The muxing section was not found in the testcases.yaml file. Nothing to do. 
+ if not muxing_config: + return + + # Create the provider endpoint + provider_endpoint = muxing_config.get("provider_endpoint") + try: + data_with_api_keys = self.replace_env_variables(provider_endpoint["data"], os.environ) + response_create_provider = self.call_codegate( + provider=provider, + url=provider_endpoint["url"], + headers=provider_endpoint["headers"], + data=json.loads(data_with_api_keys), + ) + created_provider_endpoint = response_create_provider.json() + except Exception as e: + logger.warning(f"Could not setup provider endpoint for muxing: {e}") + return + logger.info("Created provider endpoint for muixing") + + muxes_rules: Dict[str, Any] = muxing_config.get("muxes", {}) + try: + # We need to first update all the muxes with the provider_id + for mux in muxes_rules.get("rules", []): + mux["provider_id"] = created_provider_endpoint["id"] + + # The endpoint actually takes a list + self.call_codegate( + provider=provider, + url=muxes_rules["url"], + headers=muxes_rules["headers"], + data=muxes_rules.get("rules", []), + method="PUT", + ) + except Exception as e: + logger.warning(f"Could not setup muxing rules: {e}") + return + logger.info("Created muxing rules") + + return muxing_config["mux_url"], muxing_config["trimm_from_testcase_url"] + + async def _augment_testcases_with_muxing( + self, testcases: Dict, mux_url: str, trimm_from_testcase_url: str + ) -> Dict: + """ + Augment the testcases with the muxing information. Copy the testcases + and execute them through the muxing endpoint. 
+ """ + test_cases_with_muxing = copy.deepcopy(testcases) + for test_id, test_data in testcases.items(): + # Replace the provider in the URL with the muxed URL + rest_of_path = test_data["url"].replace(trimm_from_testcase_url, "") + new_url = f"{mux_url}{rest_of_path}" + new_test_data = copy.deepcopy(test_data) + new_test_data["url"] = new_url + new_test_id = f"{test_id}_muxed" + test_cases_with_muxing[new_test_id] = new_test_data + + logger.info("Augmented testcases with muxing") + return test_cases_with_muxing + async def _setup( - self, testcases_file: str, test_names: Optional[list[str]] = None + self, testcases_file: str, provider: str, test_names: Optional[list[str]] = None ) -> Tuple[Dict, Dict]: with open(testcases_file, "r") as f: - testcases_dict = yaml.safe_load(f) + testcases_dict: Dict = yaml.safe_load(f) headers = testcases_dict["headers"] testcases = await self._get_testcases(testcases_dict, test_names) - return headers, testcases + muxing_result = await self._setup_muxing(provider, testcases_dict.get("muxing", {})) + # We don't have any muxing setup, return the headers and testcases + if not muxing_result: + return headers, testcases + + mux_url, trimm_from_testcase_url = muxing_result + test_cases_with_muxing = await self._augment_testcases_with_muxing( + testcases, mux_url, trimm_from_testcase_url + ) + + return headers, test_cases_with_muxing async def run_tests( self, @@ -208,8 +288,7 @@ async def run_tests( provider: str, test_names: Optional[list[str]] = None, ) -> bool: - headers, testcases = await self._setup(testcases_file, test_names) - + headers, testcases = await self._setup(testcases_file, provider, test_names) if not testcases: logger.warning( f"No tests found for provider {provider} in file: {testcases_file} " diff --git a/tests/integration/llamacpp/testcases.yaml b/tests/integration/llamacpp/testcases.yaml index b4c8bbd0..69ec72df 100644 --- a/tests/integration/llamacpp/testcases.yaml +++ b/tests/integration/llamacpp/testcases.yaml 
@@ -2,6 +2,30 @@ headers: llamacpp: Content-Type: application/json +muxing: + mux_url: http://127.0.0.1:8989/v1/mux/ + trimm_from_testcase_url: http://127.0.0.1:8989/llamacpp/ + provider_endpoint: + url: http://127.0.0.1:8989/api/v1/provider-endpoints + headers: + Content-Type: application/json + data: | + { + "name": "llamacpp_muxing", + "description": "Muxing testing endpoint", + "provider_type": "llamacpp", + "endpoint": "./codegate_volume/models", + "auth_type": "none" + } + muxes: + url: http://127.0.0.1:8989/api/v1/workspaces/default/muxes + headers: + Content-Type: application/json + rules: + - model: qwen2.5-coder-0.5b-instruct-q5_k_m + matcher_type: catch_all + matcher: "" + testcases: llamacpp_chat: name: LlamaCPP Chat diff --git a/tests/integration/ollama/testcases.yaml b/tests/integration/ollama/testcases.yaml index 38c8ba7a..9931ecdb 100644 --- a/tests/integration/ollama/testcases.yaml +++ b/tests/integration/ollama/testcases.yaml @@ -2,6 +2,30 @@ headers: ollama: Content-Type: application/json +muxing: + mux_url: http://127.0.0.1:8989/v1/mux/ + trimm_from_testcase_url: http://127.0.0.1:8989/ollama/ + provider_endpoint: + url: http://127.0.0.1:8989/api/v1/provider-endpoints + headers: + Content-Type: application/json + data: | + { + "name": "ollama_muxing", + "description": "Muxing testing endpoint", + "provider_type": "ollama", + "endpoint": "http://127.0.0.1:11434", + "auth_type": "none" + } + muxes: + url: http://127.0.0.1:8989/api/v1/workspaces/default/muxes + headers: + Content-Type: application/json + rules: + - model: qwen2.5-coder:1.5b + matcher_type: catch_all + matcher: "" + testcases: ollama_chat: name: Ollama Chat diff --git a/tests/integration/openai/testcases.yaml b/tests/integration/openai/testcases.yaml index 603a69e7..452dcce6 100644 --- a/tests/integration/openai/testcases.yaml +++ b/tests/integration/openai/testcases.yaml @@ -2,6 +2,31 @@ headers: openai: Authorization: Bearer ENV_OPENAI_KEY +muxing: + mux_url: 
http://127.0.0.1:8989/v1/mux/ + trimm_from_testcase_url: http://127.0.0.1:8989/openai/ + provider_endpoint: + url: http://127.0.0.1:8989/api/v1/provider-endpoints + headers: + Content-Type: application/json + data: | + { + "name": "openai_muxing", + "description": "Muxing testing endpoint", + "provider_type": "openai", + "endpoint": "https://api.openai.com/", + "auth_type": "api_key", + "api_key": "ENV_OPENAI_KEY" + } + muxes: + url: http://127.0.0.1:8989/api/v1/workspaces/default/muxes + headers: + Content-Type: application/json + rules: + - model: gpt-4o-mini + matcher_type: catch_all + matcher: "" + testcases: openai_chat: name: OpenAI Chat diff --git a/tests/integration/openrouter/testcases.yaml b/tests/integration/openrouter/testcases.yaml index 1ced50b7..d64e0266 100644 --- a/tests/integration/openrouter/testcases.yaml +++ b/tests/integration/openrouter/testcases.yaml @@ -2,6 +2,31 @@ headers: openrouter: Authorization: Bearer ENV_OPENROUTER_KEY +muxing: + mux_url: http://127.0.0.1:8989/v1/mux/ + trimm_from_testcase_url: http://localhost:8989/openrouter/ + provider_endpoint: + url: http://127.0.0.1:8989/api/v1/provider-endpoints + headers: + Content-Type: application/json + data: | + { + "name": "openrouter_muxing", + "description": "Muxing testing endpoint", + "provider_type": "openrouter", + "endpoint": "https://openrouter.ai/api", + "auth_type": "api_key", + "api_key": "ENV_OPENROUTER_KEY" + } + muxes: + url: http://127.0.0.1:8989/api/v1/workspaces/default/muxes + headers: + Content-Type: application/json + rules: + - model: anthropic/claude-3.5-haiku + matcher_type: catch_all + matcher: "" + testcases: anthropic_chat: name: Openrouter Chat diff --git a/tests/integration/requesters.py b/tests/integration/requesters.py index 8441a51f..60ee9572 100644 --- a/tests/integration/requesters.py +++ b/tests/integration/requesters.py @@ -11,33 +11,43 @@ class BaseRequester(ABC): @abstractmethod - def make_request(self, url: str, headers: dict, data: dict) -> 
class StandardRequester(BaseRequester):
    def make_request(
        self, url: str, headers: dict, data: dict, method: str = "POST"
    ) -> Optional[requests.Response]:
        """Send *data* as a JSON body to *url* using the given HTTP method."""
        # Ensure Content-Type is always set correctly
        headers["Content-Type"] = "application/json"

        # Serialize explicitly (rather than using the json= parameter) so
        # the body bytes are exactly the string we produce here.
        payload = json.dumps(data)

        return requests.request(
            method=method,
            url=url,
            headers=headers,
            data=payload,  # Use data instead of json parameter
        )
"description": "Muxing testing endpoint", + "provider_type": "vllm", + "endpoint": "http://127.0.0.1:8000", + "auth_type": "none" + } + muxes: + url: http://127.0.0.1:8989/api/v1/workspaces/default/muxes + headers: + Content-Type: application/json + rules: + - model: Qwen/Qwen2.5-Coder-0.5B-Instruct + matcher_type: catch_all + matcher: "" + testcases: vllm_chat: name: VLLM Chat diff --git a/tests/muxing/test_adapter.py b/tests/muxing/test_adapter.py index 18b215c2..ba510ef0 100644 --- a/tests/muxing/test_adapter.py +++ b/tests/muxing/test_adapter.py @@ -22,6 +22,7 @@ def __init__(self, provider_type: ProviderType, endpoint_route: str): (ProviderType.openrouter, "https://openrouter.ai/api", "https://openrouter.ai/api/v1"), (ProviderType.openrouter, "https://openrouter.ai/", "https://openrouter.ai/api/v1"), (ProviderType.ollama, "http://localhost:11434", "http://localhost:11434"), + (ProviderType.vllm, "http://localhost:8000", "http://localhost:8000/v1"), ], ) def test_catch_all(provider_type, endpoint_route, expected_route):