Skip to content

Commit f361178

Browse files
authored
Enable provider model updating when updating provider itself (#1014)
This now calls the model updating logic when updating the provider itself. Thus allowing us to have a way to update the model list. Signed-off-by: Juan Antonio Osorio <[email protected]>
1 parent d86a965 commit f361178

File tree

2 files changed

+56
-11
lines changed

2 files changed

+56
-11
lines changed

src/codegate/providers/crud/crud.py

+55-10
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,32 @@ async def update_endpoint(
144144

145145
dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model())
146146

147+
# If the auth type has not changed or no authentication is needed,
148+
# we can update the models
149+
if (
150+
founddbe.auth_type == endpoint.auth_type
151+
or endpoint.auth_type == apimodelsv1.ProviderAuthType.none
152+
):
153+
try:
154+
authm = await self._db_reader.get_auth_material_by_provider_id(str(endpoint.id))
155+
156+
models = await self._find_models_for_provider(
157+
endpoint, authm.auth_type, authm.auth_blob, prov
158+
)
159+
160+
await self._update_models_for_provider(dbendpoint, endpoint, prov, models)
161+
162+
# a model might have been deleted, let's repopulate the cache
163+
await self._ws_crud.repopulate_mux_cache()
164+
except Exception as err:
165+
# This is a non-fatal error. The endpoint might have changed
166+
# And the user will need to push a new API key anyway.
167+
logger.error(
168+
"Unable to update models for provider",
169+
provider=endpoint.name,
170+
err=str(err),
171+
)
172+
147173
return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)
148174

149175
async def configure_auth_material(
@@ -164,12 +190,9 @@ async def configure_auth_material(
164190
provider_registry = get_provider_registry()
165191
prov = endpoint.get_from_registry(provider_registry)
166192

167-
models = []
168-
if config.auth_type != apimodelsv1.ProviderAuthType.passthrough:
169-
try:
170-
models = prov.models(endpoint=endpoint.endpoint, api_key=config.api_key)
171-
except Exception as err:
172-
raise ProviderModelsNotFoundError(f"Unable to get models from provider: {err}")
193+
models = await self._find_models_for_provider(
194+
endpoint, config.auth_type, config.api_key, prov
195+
)
173196

174197
await self._db_writer.push_provider_auth_material(
175198
dbmodels.ProviderAuthMaterial(
@@ -179,7 +202,32 @@ async def configure_auth_material(
179202
)
180203
)
181204

182-
models_set = set(models)
205+
await self._update_models_for_provider(dbendpoint, endpoint, models)
206+
207+
# a model might have been deleted, let's repopulate the cache
208+
await self._ws_crud.repopulate_mux_cache()
209+
210+
async def _find_models_for_provider(
211+
self,
212+
endpoint: apimodelsv1.ProviderEndpoint,
213+
auth_type: apimodelsv1.ProviderAuthType,
214+
api_key: str,
215+
prov: BaseProvider,
216+
) -> List[str]:
217+
if auth_type != apimodelsv1.ProviderAuthType.passthrough:
218+
try:
219+
return prov.models(endpoint=endpoint.endpoint, api_key=api_key)
220+
except Exception as err:
221+
raise ProviderModelsNotFoundError(f"Unable to get models from provider: {err}")
222+
return []
223+
224+
async def _update_models_for_provider(
225+
self,
226+
dbendpoint: dbmodels.ProviderEndpoint,
227+
endpoint: apimodelsv1.ProviderEndpoint,
228+
found_models: List[str],
229+
) -> None:
230+
models_set = set(found_models)
183231

184232
# Get the models from the provider
185233
models_in_db = await self._db_reader.get_provider_models_by_provider_id(str(endpoint.id))
@@ -202,9 +250,6 @@ async def configure_auth_material(
202250
model,
203251
)
204252

205-
# a model might have been deleted, let's repopulate the cache
206-
await self._ws_crud.repopulate_mux_cache()
207-
208253
async def delete_endpoint(self, provider_id: UUID):
209254
"""Delete an endpoint."""
210255

src/codegate/providers/ollama/provider.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import List, Optional
2+
from typing import List
33

44
import httpx
55
import structlog

0 commit comments

Comments
 (0)