Skip to content

Commit e68580c

Browse files
authored
Addd ability to change authentication for model (#809)
This allows us to reset authentication or even push API keys to codegate. Signed-off-by: Juan Antonio Osorio <[email protected]>
1 parent 6d5d895 commit e68580c

File tree

4 files changed

+54
-1
lines changed

4 files changed

+54
-1
lines changed

src/codegate/api/v1.py

+21
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,27 @@ async def add_provider_endpoint(
128128
return provend
129129

130130

131+
@v1.put(
132+
"/provider-endpoints/{provider_id}/auth-material",
133+
tags=["Providers"],
134+
generate_unique_id_function=uniq_name,
135+
status_code=204,
136+
)
137+
async def configure_auth_material(
138+
provider_id: UUID,
139+
request: v1_models.ConfigureAuthMaterial,
140+
):
141+
"""Configure auth material for a provider."""
142+
try:
143+
await pcrud.configure_auth_material(provider_id, request)
144+
except provendcrud.ProviderNotFoundError:
145+
raise HTTPException(status_code=404, detail="Provider endpoint not found")
146+
except Exception:
147+
raise HTTPException(status_code=500, detail="Internal server error")
148+
149+
return Response(status_code=204)
150+
151+
131152
@v1.put(
132153
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
133154
)

src/codegate/api/v1_models.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ class ProviderEndpoint(pydantic.BaseModel):
223223
description: str = ""
224224
provider_type: ProviderType
225225
endpoint: str
226-
auth_type: ProviderAuthType
226+
auth_type: Optional[ProviderAuthType] = ProviderAuthType.none
227227

228228
@staticmethod
229229
def from_db_model(db_model: db_models.ProviderEndpoint) -> "ProviderEndpoint":
@@ -250,6 +250,15 @@ def get_from_registry(self, registry: ProviderRegistry) -> Optional[BaseProvider
250250
return registry.get_provider(self.provider_type)
251251

252252

253+
class ConfigureAuthMaterial(pydantic.BaseModel):
254+
"""
255+
Represents a request to configure auth material for a provider.
256+
"""
257+
258+
auth_type: ProviderAuthType
259+
api_key: Optional[str] = None
260+
261+
253262
class ModelByProvider(pydantic.BaseModel):
254263
"""
255264
Represents a model supported by a provider.

src/codegate/db/connection.py

+2
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,10 @@ async def push_provider_auth_material(self, auth_material: ProviderAuthMaterial)
441441
UPDATE provider_endpoints
442442
SET auth_type = :auth_type, auth_blob = :auth_blob
443443
WHERE id = :provider_endpoint_id
444+
RETURNING id as provider_endpoint_id, auth_type, auth_blob
444445
"""
445446
)
447+
# Here we DONT want to return the result
446448
_ = await self._execute_update_pydantic_model(auth_material, sql, should_raise=True)
447449
return
448450

src/codegate/providers/crud/crud.py

+21
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,27 @@ async def update_endpoint(
8181
dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model())
8282
return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)
8383

84+
async def configure_auth_material(
85+
self, provider_id: UUID, config: apimodelsv1.ConfigureAuthMaterial
86+
):
87+
"""Add an API key."""
88+
if config.auth_type == apimodelsv1.ProviderAuthType.api_key and not config.api_key:
89+
raise ValueError("API key must be provided for API auth type")
90+
elif config.auth_type != apimodelsv1.ProviderAuthType.api_key and config.api_key:
91+
raise ValueError("API key provided for non-API auth type")
92+
93+
dbendpoint = await self._db_reader.get_provider_endpoint_by_id(str(provider_id))
94+
if dbendpoint is None:
95+
raise ProviderNotFoundError("Provider not found")
96+
97+
await self._db_writer.push_provider_auth_material(
98+
dbmodels.ProviderAuthMaterial(
99+
provider_endpoint_id=dbendpoint.id,
100+
auth_type=config.auth_type,
101+
auth_blob=config.api_key if config.api_key else "",
102+
)
103+
)
104+
84105
async def delete_endpoint(self, provider_id: UUID):
85106
"""Delete an endpoint."""
86107

0 commit comments

Comments
 (0)