diff --git a/src/codegate/providers/crud/crud.py b/src/codegate/providers/crud/crud.py index 453af16f..0bffe1a8 100644 --- a/src/codegate/providers/crud/crud.py +++ b/src/codegate/providers/crud/crud.py @@ -154,10 +154,10 @@ async def update_endpoint( authm = await self._db_reader.get_auth_material_by_provider_id(str(endpoint.id)) models = await self._find_models_for_provider( - endpoint, authm.auth_type, authm.auth_blob, prov + endpoint.endpoint, authm.auth_type, authm.auth_blob, prov ) - await self._update_models_for_provider(dbendpoint, endpoint, prov, models) + await self._update_models_for_provider(dbendpoint, models) # a model might have been deleted, let's repopulate the cache await self._ws_crud.repopulate_mux_cache() @@ -191,7 +191,7 @@ async def configure_auth_material( prov = endpoint.get_from_registry(provider_registry) models = await self._find_models_for_provider( - endpoint, config.auth_type, config.api_key, prov + endpoint.endpoint, config.auth_type, config.api_key, prov ) await self._db_writer.push_provider_auth_material( @@ -202,21 +202,21 @@ async def configure_auth_material( ) ) - await self._update_models_for_provider(dbendpoint, endpoint, models) + await self._update_models_for_provider(dbendpoint, models) # a model might have been deleted, let's repopulate the cache await self._ws_crud.repopulate_mux_cache() async def _find_models_for_provider( self, - endpoint: apimodelsv1.ProviderEndpoint, + endpoint: str, auth_type: apimodelsv1.ProviderAuthType, api_key: str, prov: BaseProvider, ) -> List[str]: if auth_type != apimodelsv1.ProviderAuthType.passthrough: try: - return prov.models(endpoint=endpoint.endpoint, api_key=api_key) + return prov.models(endpoint=endpoint, api_key=api_key) except Exception as err: raise ProviderModelsNotFoundError(f"Unable to get models from provider: {err}") return [] @@ -224,13 +224,12 @@ async def _find_models_for_provider( async def _update_models_for_provider( self, dbendpoint: dbmodels.ProviderEndpoint, - endpoint: apimodelsv1.ProviderEndpoint, found_models: List[str], ) -> None: models_set = set(found_models) # Get the models from the provider - models_in_db = await self._db_reader.get_provider_models_by_provider_id(str(endpoint.id)) + models_in_db = await self._db_reader.get_provider_models_by_provider_id(str(dbendpoint.id)) models_in_db_set = set(model.name for model in models_in_db) @@ -318,7 +317,7 @@ async def initialize_provider_endpoints(preg: ProviderRegistry): dbprovend = await db_reader.get_provider_endpoint_by_name(provend.name) if dbprovend is not None: logger.debug( - "Provider already in DB. Not re-adding.", + "Provider already in DB. skipping", provider=provend.name, endpoint=provend.endpoint, ) @@ -334,6 +333,21 @@ async def initialize_provider_endpoints(preg: ProviderRegistry): continue await try_initialize_provider_endpoints(provend, pimpl, db_writer) + provcrud = ProviderCrud() + + endpoints = await provcrud.list_endpoints() + for endpoint in endpoints: + dbprovend = await db_reader.get_provider_endpoint_by_name(endpoint.name) + pimpl = endpoint.get_from_registry(preg) + if pimpl is None: + logger.warning( + "Provider not found in registry", + provider=endpoint.name, + endpoint=endpoint.endpoint, + ) + continue + await try_update_to_provider(provcrud, pimpl, dbprovend) + async def try_initialize_provider_endpoints( provend: apimodelsv1.ProviderEndpoint, @@ -376,6 +390,30 @@ async def try_initialize_provider_endpoints( await asyncio.gather(*tasks) +async def try_update_to_provider( + provcrud: ProviderCrud, prov: BaseProvider, dbprovend: dbmodels.ProviderEndpoint +): + + authm = await provcrud._db_reader.get_auth_material_by_provider_id(str(dbprovend.id)) + + try: + models = await provcrud._find_models_for_provider( + dbprovend.endpoint, authm.auth_type, authm.auth_blob, prov + ) + except Exception as err: + logger.error( + "Unable to get models from provider. Skipping", + provider=dbprovend.name, + err=str(err), + ) + return + + await provcrud._update_models_for_provider(dbprovend, models) + + # a model might have been deleted, let's repopulate the cache + await provcrud._ws_crud.repopulate_mux_cache() + + def __provider_endpoint_from_cfg( provider_name: str, provider_url: str ) -> Optional[apimodelsv1.ProviderEndpoint]: