Skip to content

Commit c16ace7

Browse files
committed
Add scaffolding of API for provider endpoints and mux rules
This providers a scaffolding to start discussions and start drafting the provider endpoints implementation and mux rules. Closes: #753 Signed-off-by: Juan Antonio Osorio <[email protected]>
1 parent e05f49a commit c16ace7

File tree

2 files changed

+199
-0
lines changed

2 files changed

+199
-0
lines changed

src/codegate/api/v1.py

+120
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,86 @@ def uniq_name(route: APIRoute):
2525
return f"v1_{route.name}"
2626

2727

28+
@v1.get("/provider-endpoints", tags=["Providers"], generate_unique_id_function=uniq_name)
29+
async def list_provider_endpoints() -> List[v1_models.ProviderEndpoint]:
30+
"""List all provider endpoints."""
31+
# NOTE: This is a dummy implementation. In the future, we should have a proper
32+
# implementation that fetches the provider endpoints from the database.
33+
return list(
34+
v1_models.ProviderEndpoint(
35+
name="dummy",
36+
description="Dummy provider endpoint",
37+
endpoint="http://example.com",
38+
provider_type=v1_models.ProviderType.openai,
39+
auth_type=v1_models.ProviderAuthType.none,
40+
)
41+
)
42+
43+
44+
# QUESTION: Should we normalize the provider names to lowercase?
45+
@v1.get(
46+
"/provider-endpoints/{provider_name}", tags=["Providers"], generate_unique_id_function=uniq_name
47+
)
48+
async def get_provider_endpoint(provider_name: str) -> v1_models.ProviderEndpoint:
49+
"""Get a provider endpoint by name."""
50+
# NOTE: This is a dummy implementation. In the future, we should have a proper
51+
# implementation that fetches the provider endpoint from the database.
52+
return v1_models.ProviderEndpoint(
53+
name="dummy",
54+
description="Dummy provider endpoint",
55+
endpoint="http://example.com",
56+
provider_type=v1_models.ProviderType.openai,
57+
auth_type=v1_models.ProviderAuthType.none,
58+
)
59+
60+
61+
@v1.post(
62+
"/provider-endpoints",
63+
tags=["Providers"],
64+
generate_unique_id_function=uniq_name,
65+
status_code=201,
66+
)
67+
async def add_provider_endpoint(request: v1_models.ProviderEndpoint) -> v1_models.ProviderEndpoint:
68+
"""Add a provider endpoint."""
69+
# NOTE: This is a dummy implementation. In the future, we should have a proper
70+
# implementation that adds the provider endpoint to the database.
71+
return request
72+
73+
74+
@v1.put(
75+
"/provider-endpoints/{provider_name}", tags=["Providers"], generate_unique_id_function=uniq_name
76+
)
77+
async def update_provider_endpoint(
78+
provider_name: str, request: v1_models.ProviderEndpoint
79+
) -> v1_models.ProviderEndpoint:
80+
"""Update a provider endpoint by name."""
81+
# NOTE: This is a dummy implementation. In the future, we should have a proper
82+
# implementation that updates the provider endpoint in the database.
83+
return request
84+
85+
86+
@v1.delete(
87+
"/provider-endpoints/{provider_name}", tags=["Providers"], generate_unique_id_function=uniq_name
88+
)
89+
async def delete_provider_endpoint(provider_name: str):
90+
"""Delete a provider endpoint by name."""
91+
# NOTE: This is a dummy implementation. In the future, we should have a proper
92+
# implementation that deletes the provider endpoint from the database.
93+
return Response(status_code=204)
94+
95+
96+
@v1.get(
97+
"/provider-endpoints/{provider_name}/models",
98+
tags=["Providers"],
99+
generate_unique_id_function=uniq_name,
100+
)
101+
async def list_models_by_provider(provider_name: str) -> List[v1_models.ModelByProvider]:
102+
"""List models by provider."""
103+
# NOTE: This is a dummy implementation. In the future, we should have a proper
104+
# implementation that fetches the models by provider from the database.
105+
return list(v1_models.ModelByProvider(name="dummy", provider="dummy"))
106+
107+
28108
@v1.get("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name)
29109
async def list_workspaces() -> v1_models.ListWorkspacesResponse:
30110
"""List all workspaces."""
@@ -296,6 +376,46 @@ async def delete_workspace_custom_instructions(workspace_name: str):
296376
return Response(status_code=204)
297377

298378

379+
@v1.get(
380+
"/workspaces/{workspace_name}/muxes",
381+
tags=["Workspaces", "Muxes"],
382+
generate_unique_id_function=uniq_name,
383+
)
384+
async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]:
385+
"""Get the mux rules of a workspace.
386+
387+
The list is ordered in order of priority. That is, the first rule in the list
388+
has the highest priority."""
389+
# TODO: This is a dummy implementation. In the future, we should have a proper
390+
# implementation that fetches the mux rules from the database.
391+
return [
392+
v1_models.MuxRule(
393+
provider="openai",
394+
model="gpt-3.5-turbo",
395+
matcher_type=v1_models.MuxMatcherType.file_regex,
396+
matcher=".*\\.txt",
397+
),
398+
v1_models.MuxRule(
399+
provider="anthropic",
400+
model="davinci",
401+
matcher_type=v1_models.MuxMatcherType.catch_all,
402+
),
403+
]
404+
405+
406+
@v1.put(
407+
"/workspaces/{workspace_name}/muxes",
408+
tags=["Workspaces", "Muxes"],
409+
generate_unique_id_function=uniq_name,
410+
status_code=204,
411+
)
412+
async def set_workspace_muxes(workspace_name: str, request: List[v1_models.MuxRule]):
413+
"""Set the mux rules of a workspace."""
414+
# TODO: This is a dummy implementation. In the future, we should have a proper
415+
# implementation that sets the mux rules in the database.
416+
return Response(status_code=204)
417+
418+
299419
@v1.get("/alerts_notification", tags=["Dashboard"], generate_unique_id_function=uniq_name)
300420
async def stream_sse():
301421
"""

src/codegate/api/v1_models.py

+79
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,82 @@ class AlertConversation(pydantic.BaseModel):
138138
trigger_type: str
139139
trigger_category: Optional[str]
140140
timestamp: datetime.datetime
141+
142+
143+
class ProviderType(str, Enum):
144+
"""
145+
Represents the different types of providers we support.
146+
"""
147+
148+
openai = "openai"
149+
anthropic = "anthropic"
150+
vllm = "vllm"
151+
152+
153+
class ProviderAuthType(str, Enum):
154+
"""
155+
Represents the different types of auth we support for providers.
156+
"""
157+
158+
# No auth required
159+
none = "none"
160+
# Whatever the user provides is passed through
161+
passthrough = "passthrough"
162+
# API key is required
163+
api_key = "api_key"
164+
165+
166+
class ProviderEndpoint(pydantic.BaseModel):
167+
"""
168+
Represents a provider's endpoint configuration. This
169+
allows us to persist the configuration for each provider,
170+
so we can use this for muxing messages.
171+
"""
172+
173+
name: str
174+
description: str = ""
175+
provider_type: ProviderType
176+
endpoint: str
177+
auth_type: ProviderAuthType
178+
179+
180+
class ModelByProvider(pydantic.BaseModel):
181+
"""
182+
Represents a model supported by a provider.
183+
184+
Note that these are auto-discovered by the provider.
185+
"""
186+
187+
name: str
188+
provider: str
189+
190+
def __str__(self):
191+
return f"{self.provider}/{self.name}"
192+
193+
194+
class MuxMatcherType(str, Enum):
195+
"""
196+
Represents the different types of matchers we support.
197+
"""
198+
199+
# Match a regular expression for a file path
200+
# in the prompt. Note that if no file is found,
201+
# the prompt will be passed through.
202+
file_regex = "file_regex"
203+
204+
# Always match this prompt
205+
catch_all = "catch_all"
206+
207+
208+
class MuxRule(pydantic.BaseModel):
209+
"""
210+
Represents a mux rule for a provider.
211+
"""
212+
213+
provider: str
214+
model: str
215+
# The type of matcher to use
216+
matcher_type: MuxMatcherType
217+
# The actual matcher to use. Note that
218+
# this depends on the matcher type.
219+
matcher: Optional[str]

0 commit comments

Comments
 (0)