11
11
from codegate .db import models as dbmodels
12
12
from codegate .db .connection import DbReader , DbRecorder
13
13
from codegate .providers .base import BaseProvider
14
- from codegate .providers .registry import ProviderRegistry
14
+ from codegate .providers .registry import ProviderRegistry , get_provider_registry
15
15
16
16
logger = structlog .get_logger ("codegate" )
17
17
@@ -62,23 +62,106 @@ async def get_endpoint_by_name(self, name: str) -> Optional[apimodelsv1.Provider
62
62
return apimodelsv1 .ProviderEndpoint .from_db_model (dbendpoint )
63
63
64
64
async def add_endpoint (
65
- self , endpoint : apimodelsv1 .ProviderEndpoint
65
+ self , endpoint : apimodelsv1 .AddProviderEndpointRequest
66
66
) -> apimodelsv1 .ProviderEndpoint :
67
67
"""Add an endpoint."""
68
+
69
+ if not endpoint .endpoint :
70
+ endpoint .endpoint = provider_default_endpoints (endpoint .provider_type )
71
+
72
+ # If we STILL don't have an endpoint, we can't continue
73
+ if not endpoint .endpoint :
74
+ raise ValueError ("No endpoint provided and no default found for provider type" )
75
+
68
76
dbend = endpoint .to_db_model ()
77
+ provider_registry = get_provider_registry ()
69
78
70
79
# We override the ID here, as we want to generate it.
71
80
dbend .id = str (uuid4 ())
72
81
73
- dbendpoint = await self ._db_writer .add_provider_endpoint ()
82
+ prov = endpoint .get_from_registry (provider_registry )
83
+ if prov is None :
84
+ raise ValueError ("Unknown provider type: {}" .format (endpoint .provider_type ))
85
+
86
+ models = []
87
+ if endpoint .auth_type == apimodelsv1 .ProviderAuthType .api_key and not endpoint .api_key :
88
+ raise ValueError ("API key must be provided for API auth type" )
89
+ if endpoint .auth_type != apimodelsv1 .ProviderAuthType .passthrough :
90
+ try :
91
+ models = prov .models (endpoint = endpoint .endpoint , api_key = endpoint .api_key )
92
+ except Exception as err :
93
+ raise ValueError ("Unable to get models from provider: {}" .format (str (err )))
94
+
95
+ dbendpoint = await self ._db_writer .add_provider_endpoint (dbend )
96
+
97
+ await self ._db_writer .push_provider_auth_material (
98
+ dbmodels .ProviderAuthMaterial (
99
+ provider_endpoint_id = dbendpoint .id ,
100
+ auth_type = endpoint .auth_type ,
101
+ auth_blob = endpoint .api_key if endpoint .api_key else "" ,
102
+ )
103
+ )
104
+
105
+ for model in models :
106
+ await self ._db_writer .add_provider_model (
107
+ dbmodels .ProviderModel (
108
+ provider_endpoint_id = dbendpoint .id ,
109
+ name = model ,
110
+ )
111
+ )
74
112
return apimodelsv1 .ProviderEndpoint .from_db_model (dbendpoint )
75
113
76
114
async def update_endpoint (
77
- self , endpoint : apimodelsv1 .ProviderEndpoint
115
+ self , endpoint : apimodelsv1 .AddProviderEndpointRequest
78
116
) -> apimodelsv1 .ProviderEndpoint :
79
117
"""Update an endpoint."""
80
118
119
+ if not endpoint .endpoint :
120
+ endpoint .endpoint = provider_default_endpoints (endpoint .provider_type )
121
+
122
+ # If we STILL don't have an endpoint, we can't continue
123
+ if not endpoint .endpoint :
124
+ raise ValueError ("No endpoint provided and no default found for provider type" )
125
+
126
+ provider_registry = get_provider_registry ()
127
+ prov = endpoint .get_from_registry (provider_registry )
128
+ if prov is None :
129
+ raise ValueError ("Unknown provider type: {}" .format (endpoint .provider_type ))
130
+
131
+ founddbe = await self ._db_reader .get_provider_endpoint_by_id (str (endpoint .id ))
132
+ if founddbe is None :
133
+ raise ProviderNotFoundError ("Provider not found" )
134
+
135
+ models = []
136
+ if endpoint .auth_type == apimodelsv1 .ProviderAuthType .api_key and not endpoint .api_key :
137
+ raise ValueError ("API key must be provided for API auth type" )
138
+ if endpoint .auth_type != apimodelsv1 .ProviderAuthType .passthrough :
139
+ try :
140
+ models = prov .models (endpoint = endpoint .endpoint , api_key = endpoint .api_key )
141
+ except Exception as err :
142
+ raise ValueError ("Unable to get models from provider: {}" .format (str (err )))
143
+
144
+ # Reset all provider models.
145
+ await self ._db_writer .delete_provider_models (str (endpoint .id ))
146
+
147
+ for model in models :
148
+ await self ._db_writer .add_provider_model (
149
+ dbmodels .ProviderModel (
150
+ provider_endpoint_id = founddbe .id ,
151
+ name = model ,
152
+ )
153
+ )
154
+
81
155
dbendpoint = await self ._db_writer .update_provider_endpoint (endpoint .to_db_model ())
156
+
157
+ await self ._db_writer .push_provider_auth_material (
158
+ dbmodels .ProviderAuthMaterial (
159
+ provider_endpoint_id = dbendpoint .id ,
160
+ auth_type = endpoint .auth_type ,
161
+ auth_blob = endpoint .api_key if endpoint .api_key else "" ,
162
+ )
163
+ )
164
+
82
165
return apimodelsv1 .ProviderEndpoint .from_db_model (dbendpoint )
83
166
84
167
async def configure_auth_material (
@@ -175,6 +258,13 @@ async def initialize_provider_endpoints(preg: ProviderRegistry):
175
258
continue
176
259
177
260
pimpl = provend .get_from_registry (preg )
261
+ if pimpl is None :
262
+ logger .warning (
263
+ "Provider not found in registry" ,
264
+ provider = provend .name ,
265
+ endpoint = provend .endpoint ,
266
+ )
267
+ continue
178
268
await try_initialize_provider_endpoints (provend , pimpl , db_writer )
179
269
180
270
@@ -240,7 +330,7 @@ def __provider_endpoint_from_cfg(
240
330
description = ("Endpoint for the {} provided via the CodeGate configuration." ).format (
241
331
provider_name
242
332
),
243
- provider_type = provider_name ,
333
+ provider_type = provider_overrides ( provider_name ) ,
244
334
auth_type = apimodelsv1 .ProviderAuthType .passthrough ,
245
335
)
246
336
except ValidationError as err :
@@ -251,3 +341,24 @@ def __provider_endpoint_from_cfg(
251
341
err = str (err ),
252
342
)
253
343
return None
344
+
345
+
346
+ def provider_default_endpoints (provider_type : str ) -> str :
347
+ defaults = {
348
+ "openai" : "https://api.openai.com" ,
349
+ "anthropic" : "https://api.anthropic.com" ,
350
+ }
351
+
352
+ # If we have a default, we return it
353
+ # Otherwise, we return an empty string
354
+ return defaults .get (provider_type , "" )
355
+
356
+
357
+ def provider_overrides (provider_type : str ) -> str :
358
+ overrides = {
359
+ "lm_studio" : "openai" ,
360
+ }
361
+
362
+ # If we have an override, we return it
363
+ # Otherwise, we return the type
364
+ return overrides .get (provider_type , provider_type )
0 commit comments