Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Update models on codegate initialization #1027

Merged
merged 1 commit into from
Feb 12, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 47 additions & 9 deletions src/codegate/providers/crud/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,10 @@ async def update_endpoint(
authm = await self._db_reader.get_auth_material_by_provider_id(str(endpoint.id))

models = await self._find_models_for_provider(
endpoint, authm.auth_type, authm.auth_blob, prov
endpoint.endpoint, authm.auth_type, authm.auth_blob, prov
)

await self._update_models_for_provider(dbendpoint, endpoint, prov, models)
await self._update_models_for_provider(dbendpoint, models)

# a model might have been deleted, let's repopulate the cache
await self._ws_crud.repopulate_mux_cache()
Expand Down Expand Up @@ -191,7 +191,7 @@ async def configure_auth_material(
prov = endpoint.get_from_registry(provider_registry)

models = await self._find_models_for_provider(
endpoint, config.auth_type, config.api_key, prov
endpoint.endpoint, config.auth_type, config.api_key, prov
)

await self._db_writer.push_provider_auth_material(
Expand All @@ -202,35 +202,34 @@ async def configure_auth_material(
)
)

await self._update_models_for_provider(dbendpoint, endpoint, models)
await self._update_models_for_provider(dbendpoint, models)

# a model might have been deleted, let's repopulate the cache
await self._ws_crud.repopulate_mux_cache()

async def _find_models_for_provider(
self,
endpoint: apimodelsv1.ProviderEndpoint,
endpoint: str,
auth_type: apimodelsv1.ProviderAuthType,
api_key: str,
prov: BaseProvider,
) -> List[str]:
if auth_type != apimodelsv1.ProviderAuthType.passthrough:
try:
return prov.models(endpoint=endpoint.endpoint, api_key=api_key)
return prov.models(endpoint=endpoint, api_key=api_key)
except Exception as err:
raise ProviderModelsNotFoundError(f"Unable to get models from provider: {err}")
return []

async def _update_models_for_provider(
self,
dbendpoint: dbmodels.ProviderEndpoint,
endpoint: apimodelsv1.ProviderEndpoint,
found_models: List[str],
) -> None:
models_set = set(found_models)

# Get the models from the provider
models_in_db = await self._db_reader.get_provider_models_by_provider_id(str(endpoint.id))
models_in_db = await self._db_reader.get_provider_models_by_provider_id(str(dbendpoint.id))

models_in_db_set = set(model.name for model in models_in_db)

Expand Down Expand Up @@ -318,7 +317,7 @@ async def initialize_provider_endpoints(preg: ProviderRegistry):
dbprovend = await db_reader.get_provider_endpoint_by_name(provend.name)
if dbprovend is not None:
logger.debug(
"Provider already in DB. Not re-adding.",
"Provider already in DB. skipping",
provider=provend.name,
endpoint=provend.endpoint,
)
Expand All @@ -334,6 +333,21 @@ async def initialize_provider_endpoints(preg: ProviderRegistry):
continue
await try_initialize_provider_endpoints(provend, pimpl, db_writer)

provcrud = ProviderCrud()

endpoints = await provcrud.list_endpoints()
for endpoint in endpoints:
dbprovend = await db_reader.get_provider_endpoint_by_name(endpoint.name)
pimpl = endpoint.get_from_registry(preg)
if pimpl is None:
logger.warning(
"Provider not found in registry",
provider=endpoint.name,
endpoint=endpoint.endpoint,
)
continue
await try_update_to_provider(provcrud, pimpl, dbprovend)


async def try_initialize_provider_endpoints(
provend: apimodelsv1.ProviderEndpoint,
Expand Down Expand Up @@ -376,6 +390,30 @@ async def try_initialize_provider_endpoints(
await asyncio.gather(*tasks)


async def try_update_to_provider(
provcrud: ProviderCrud, prov: BaseProvider, dbprovend: dbmodels.ProviderEndpoint
):

authm = await provcrud._db_reader.get_auth_material_by_provider_id(str(dbprovend.id))

try:
models = await provcrud._find_models_for_provider(
dbprovend.endpoint, authm.auth_type, authm.auth_blob, prov
)
except Exception as err:
logger.error(
"Unable to get models from provider. Skipping",
provider=dbprovend.name,
err=str(err),
)
return

await provcrud._update_models_for_provider(dbprovend, models)

# a model might have been deleted, let's repopulate the cache
await provcrud._ws_crud.repopulate_mux_cache()


def __provider_endpoint_from_cfg(
provider_name: str, provider_url: str
) -> Optional[apimodelsv1.ProviderEndpoint]:
Expand Down