diff --git a/src/codegate/muxing/adapter.py b/src/codegate/muxing/adapter.py index 010d99ce..f076df5e 100644 --- a/src/codegate/muxing/adapter.py +++ b/src/codegate/muxing/adapter.py @@ -22,40 +22,12 @@ class MuxingAdapterError(Exception): class BodyAdapter: """ - Map the body between OpenAI and Anthropic. + Format the body to the destination provider format. - TODO: This are really knaive implementations. We should replace them with more robust ones. + We expect the body to always be in OpenAI format. We need to configure the client + to send and expect OpenAI format. Here we just need to set the destination provider info. """ - def _from_openai_to_antrhopic(self, openai_body: dict) -> dict: - """Map the OpenAI body to the Anthropic body.""" - new_body = copy.deepcopy(openai_body) - messages = new_body.get("messages", []) - system_prompt = None - system_msg_idx = None - if messages: - for i_msg, msg in enumerate(messages): - if msg.get("role", "") == "system": - system_prompt = msg.get("content") - system_msg_idx = i_msg - break - if system_prompt: - new_body["system"] = system_prompt - if system_msg_idx is not None: - del messages[system_msg_idx] - return new_body - - def _from_anthropic_to_openai(self, anthropic_body: dict) -> dict: - """Map the Anthropic body to the OpenAI body.""" - new_body = copy.deepcopy(anthropic_body) - system_prompt = anthropic_body.get("system") - messages = new_body.get("messages", []) - if system_prompt: - messages.insert(0, {"role": "system", "content": system_prompt}) - if "system" in new_body: - del new_body["system"] - return new_body - def _get_provider_formatted_url(self, model_route: rulematcher.ModelRoute) -> str: """Get the provider formatted URL to use in base_url. Note this value comes from DB""" if model_route.endpoint.provider_type in [ @@ -65,35 +37,13 @@ def _get_provider_formatted_url(self, model_route: rulematcher.ModelRoute) -> st return f"{model_route.endpoint.endpoint}/v1" return model_route.endpoint.endpoint - def _set_destination_info(self, data: dict, model_route: rulematcher.ModelRoute) -> dict: + def set_destination_info(self, model_route: rulematcher.ModelRoute, data: dict) -> dict: """Set the destination provider info.""" new_data = copy.deepcopy(data) new_data["model"] = model_route.model.name new_data["base_url"] = self._get_provider_formatted_url(model_route) return new_data - def _identify_provider(self, data: dict) -> db_models.ProviderType: - """Identify the request provider.""" - if "system" in data: - return db_models.ProviderType.anthropic - else: - return db_models.ProviderType.openai - - def map_body_to_dest(self, model_route: rulematcher.ModelRoute, data: dict) -> dict: - """ - Map the body to the destination provider. - - We only need to transform the body if the destination or origin provider is Anthropic. - """ - origin_prov = self._identify_provider(data) - if model_route.endpoint.provider_type == db_models.ProviderType.anthropic: - if origin_prov != db_models.ProviderType.anthropic: - data = self._from_openai_to_antrhopic(data) - else: - if origin_prov == db_models.ProviderType.anthropic: - data = self._from_anthropic_to_openai(data) - return self._set_destination_info(data, model_route) - class StreamChunkFormatter: """ diff --git a/src/codegate/muxing/router.py b/src/codegate/muxing/router.py index 0e5a4676..df3a9d39 100644 --- a/src/codegate/muxing/router.py +++ b/src/codegate/muxing/router.py @@ -97,7 +97,7 @@ async def route_to_dest_provider( # 2. Map the request body to the destination provider format. rest_of_path = self._ensure_path_starts_with_slash(rest_of_path) - new_data = self._body_adapter.map_body_to_dest(model_route, data) + new_data = self._body_adapter.set_destination_info(model_route, data) # 3. Run pipeline. Selecting the correct destination provider. provider = self._provider_registry.get_provider(model_route.endpoint.provider_type)