From e11164277e6f2358bc5919313e2e1fc90b9eb696 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Tue, 11 Feb 2025 11:25:36 +0200 Subject: [PATCH] Remove code for mapping origin to destination in muxing The initial muxing implementation assumed that we would need to transform the input body from one provider to another. e.g. from athropic to openai. This changed. The final implementation is assuming that is receiving OpenAI format and will respond with OpenAI format. This PR is removing unused code. --- src/codegate/muxing/adapter.py | 58 +++------------------------------- src/codegate/muxing/router.py | 2 +- 2 files changed, 5 insertions(+), 55 deletions(-) diff --git a/src/codegate/muxing/adapter.py b/src/codegate/muxing/adapter.py index 010d99ce..f076df5e 100644 --- a/src/codegate/muxing/adapter.py +++ b/src/codegate/muxing/adapter.py @@ -22,40 +22,12 @@ class MuxingAdapterError(Exception): class BodyAdapter: """ - Map the body between OpenAI and Anthropic. + Format the body to the destination provider format. - TODO: This are really knaive implementations. We should replace them with more robust ones. + We expect the body to always be in OpenAI format. We need to configure the client + to send and expect OpenAI format. Here we just need to set the destination provider info. """ - def _from_openai_to_antrhopic(self, openai_body: dict) -> dict: - """Map the OpenAI body to the Anthropic body.""" - new_body = copy.deepcopy(openai_body) - messages = new_body.get("messages", []) - system_prompt = None - system_msg_idx = None - if messages: - for i_msg, msg in enumerate(messages): - if msg.get("role", "") == "system": - system_prompt = msg.get("content") - system_msg_idx = i_msg - break - if system_prompt: - new_body["system"] = system_prompt - if system_msg_idx is not None: - del messages[system_msg_idx] - return new_body - - def _from_anthropic_to_openai(self, anthropic_body: dict) -> dict: - """Map the Anthropic body to the OpenAI body.""" - new_body = copy.deepcopy(anthropic_body) - system_prompt = anthropic_body.get("system") - messages = new_body.get("messages", []) - if system_prompt: - messages.insert(0, {"role": "system", "content": system_prompt}) - if "system" in new_body: - del new_body["system"] - return new_body - def _get_provider_formatted_url(self, model_route: rulematcher.ModelRoute) -> str: """Get the provider formatted URL to use in base_url. Note this value comes from DB""" if model_route.endpoint.provider_type in [ @@ -65,35 +37,13 @@ def _get_provider_formatted_url(self, model_route: rulematcher.ModelRoute) -> st return f"{model_route.endpoint.endpoint}/v1" return model_route.endpoint.endpoint - def _set_destination_info(self, data: dict, model_route: rulematcher.ModelRoute) -> dict: + def set_destination_info(self, model_route: rulematcher.ModelRoute, data: dict) -> dict: """Set the destination provider info.""" new_data = copy.deepcopy(data) new_data["model"] = model_route.model.name new_data["base_url"] = self._get_provider_formatted_url(model_route) return new_data - def _identify_provider(self, data: dict) -> db_models.ProviderType: - """Identify the request provider.""" - if "system" in data: - return db_models.ProviderType.anthropic - else: - return db_models.ProviderType.openai - - def map_body_to_dest(self, model_route: rulematcher.ModelRoute, data: dict) -> dict: - """ - Map the body to the destination provider. - - We only need to transform the body if the destination or origin provider is Anthropic. - """ - origin_prov = self._identify_provider(data) - if model_route.endpoint.provider_type == db_models.ProviderType.anthropic: - if origin_prov != db_models.ProviderType.anthropic: - data = self._from_openai_to_antrhopic(data) - else: - if origin_prov == db_models.ProviderType.anthropic: - data = self._from_anthropic_to_openai(data) - return self._set_destination_info(data, model_route) - class StreamChunkFormatter: """ diff --git a/src/codegate/muxing/router.py b/src/codegate/muxing/router.py index 0e5a4676..df3a9d39 100644 --- a/src/codegate/muxing/router.py +++ b/src/codegate/muxing/router.py @@ -97,7 +97,7 @@ async def route_to_dest_provider( # 2. Map the request body to the destination provider format. rest_of_path = self._ensure_path_starts_with_slash(rest_of_path) - new_data = self._body_adapter.map_body_to_dest(model_route, data) + new_data = self._body_adapter.set_destination_info(model_route, data) # 3. Run pipeline. Selecting the correct destination provider. provider = self._provider_registry.get_provider(model_route.endpoint.provider_type)