Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Remove code for mapping origin to destination in muxing #1010

Merged
merged 1 commit into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 4 additions & 54 deletions src/codegate/muxing/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,40 +22,12 @@ class MuxingAdapterError(Exception):

class BodyAdapter:
"""
Map the body between OpenAI and Anthropic.
Format the body to the destination provider format.

TODO: This are really knaive implementations. We should replace them with more robust ones.
We expect the body to always be in OpenAI format. We need to configure the client
to send and expect OpenAI format. Here we just need to set the destination provider info.
"""

def _from_openai_to_antrhopic(self, openai_body: dict) -> dict:
"""Map the OpenAI body to the Anthropic body."""
new_body = copy.deepcopy(openai_body)
messages = new_body.get("messages", [])
system_prompt = None
system_msg_idx = None
if messages:
for i_msg, msg in enumerate(messages):
if msg.get("role", "") == "system":
system_prompt = msg.get("content")
system_msg_idx = i_msg
break
if system_prompt:
new_body["system"] = system_prompt
if system_msg_idx is not None:
del messages[system_msg_idx]
return new_body

def _from_anthropic_to_openai(self, anthropic_body: dict) -> dict:
"""Map the Anthropic body to the OpenAI body."""
new_body = copy.deepcopy(anthropic_body)
system_prompt = anthropic_body.get("system")
messages = new_body.get("messages", [])
if system_prompt:
messages.insert(0, {"role": "system", "content": system_prompt})
if "system" in new_body:
del new_body["system"]
return new_body

def _get_provider_formatted_url(self, model_route: rulematcher.ModelRoute) -> str:
"""Get the provider formatted URL to use in base_url. Note this value comes from DB"""
if model_route.endpoint.provider_type in [
Expand All @@ -65,35 +37,13 @@ def _get_provider_formatted_url(self, model_route: rulematcher.ModelRoute) -> st
return f"{model_route.endpoint.endpoint}/v1"
return model_route.endpoint.endpoint

def _set_destination_info(self, data: dict, model_route: rulematcher.ModelRoute) -> dict:
def set_destination_info(self, model_route: rulematcher.ModelRoute, data: dict) -> dict:
"""Set the destination provider info."""
new_data = copy.deepcopy(data)
new_data["model"] = model_route.model.name
new_data["base_url"] = self._get_provider_formatted_url(model_route)
return new_data

def _identify_provider(self, data: dict) -> db_models.ProviderType:
"""Identify the request provider."""
if "system" in data:
return db_models.ProviderType.anthropic
else:
return db_models.ProviderType.openai

def map_body_to_dest(self, model_route: rulematcher.ModelRoute, data: dict) -> dict:
"""
Map the body to the destination provider.

We only need to transform the body if the destination or origin provider is Anthropic.
"""
origin_prov = self._identify_provider(data)
if model_route.endpoint.provider_type == db_models.ProviderType.anthropic:
if origin_prov != db_models.ProviderType.anthropic:
data = self._from_openai_to_antrhopic(data)
else:
if origin_prov == db_models.ProviderType.anthropic:
data = self._from_anthropic_to_openai(data)
return self._set_destination_info(data, model_route)


class StreamChunkFormatter:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/codegate/muxing/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ async def route_to_dest_provider(

# 2. Map the request body to the destination provider format.
rest_of_path = self._ensure_path_starts_with_slash(rest_of_path)
new_data = self._body_adapter.map_body_to_dest(model_route, data)
new_data = self._body_adapter.set_destination_info(model_route, data)

# 3. Run pipeline. Selecting the correct destination provider.
provider = self._provider_registry.get_provider(model_route.endpoint.provider_type)
Expand Down