diff --git a/src/codegate/muxing/adapter.py b/src/codegate/muxing/adapter.py index e63ded38..74513e98 100644 --- a/src/codegate/muxing/adapter.py +++ b/src/codegate/muxing/adapter.py @@ -32,11 +32,10 @@ class BodyAdapter: def _get_provider_formatted_url(self, model_route: rulematcher.ModelRoute) -> str: """Get the provider formatted URL to use in base_url. Note this value comes from DB""" - if model_route.endpoint.provider_type in [ - db_models.ProviderType.openai, - db_models.ProviderType.openrouter, - ]: + if model_route.endpoint.provider_type == db_models.ProviderType.openai: return urljoin(model_route.endpoint.endpoint, "/v1") + if model_route.endpoint.provider_type == db_models.ProviderType.openrouter: + return urljoin(model_route.endpoint.endpoint, "/api/v1") return model_route.endpoint.endpoint def set_destination_info(self, model_route: rulematcher.ModelRoute, data: dict) -> dict: @@ -199,7 +198,8 @@ def _format_antropic(self, chunk: str) -> str: ], ) return open_ai_chunk.model_dump_json(exclude_none=True, exclude_unset=True) - except Exception: + except Exception as e: + logger.warning(f"Error formatting Anthropic chunk: {chunk}. Error: {e}") return cleaned_chunk.strip() diff --git a/tests/muxing/__init__.py b/tests/muxing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/muxing/test_adapter.py b/tests/muxing/test_adapter.py new file mode 100644 index 00000000..18b215c2 --- /dev/null +++ b/tests/muxing/test_adapter.py @@ -0,0 +1,31 @@ +import pytest + +from codegate.db.models import ProviderType +from codegate.muxing.adapter import BodyAdapter + + +class MockedEndpoint: + def __init__(self, provider_type: ProviderType, endpoint_route: str): + self.provider_type = provider_type + self.endpoint = endpoint_route + + +class MockedModelRoute: + def __init__(self, provider_type: ProviderType, endpoint_route: str): + self.endpoint = MockedEndpoint(provider_type, endpoint_route) + + +@pytest.mark.parametrize( + "provider_type, endpoint_route, expected_route", + [ + (ProviderType.openai, "https://api.openai.com/", "https://api.openai.com/v1"), + (ProviderType.openrouter, "https://openrouter.ai/api", "https://openrouter.ai/api/v1"), + (ProviderType.openrouter, "https://openrouter.ai/", "https://openrouter.ai/api/v1"), + (ProviderType.ollama, "http://localhost:11434", "http://localhost:11434"), + ], +) +def test_catch_all(provider_type, endpoint_route, expected_route): + body_adapter = BodyAdapter() + model_route = MockedModelRoute(provider_type, endpoint_route) + actual_route = body_adapter._get_provider_formatted_url(model_route) + assert actual_route == expected_route