@@ -144,6 +144,32 @@ async def update_endpoint(
144
144
145
145
dbendpoint = await self ._db_writer .update_provider_endpoint (endpoint .to_db_model ())
146
146
147
+ # If the auth type has not changed or no authentication is needed,
148
+ # we can update the models
149
+ if (
150
+ founddbe .auth_type == endpoint .auth_type
151
+ or endpoint .auth_type == apimodelsv1 .ProviderAuthType .none
152
+ ):
153
+ try :
154
+ authm = await self ._db_reader .get_auth_material_by_provider_id (str (endpoint .id ))
155
+
156
+ models = await self ._find_models_for_provider (
157
+ endpoint , authm .auth_type , authm .auth_blob , prov
158
+ )
159
+
160
+ await self ._update_models_for_provider (dbendpoint , endpoint , prov , models )
161
+
162
+ # a model might have been deleted, let's repopulate the cache
163
+ await self ._ws_crud .repopulate_mux_cache ()
164
+ except Exception as err :
165
+ # This is a non-fatal error. The endpoint might have changed
166
+ # And the user will need to push a new API key anyway.
167
+ logger .error (
168
+ "Unable to update models for provider" ,
169
+ provider = endpoint .name ,
170
+ err = str (err ),
171
+ )
172
+
147
173
return apimodelsv1 .ProviderEndpoint .from_db_model (dbendpoint )
148
174
149
175
async def configure_auth_material (
@@ -164,12 +190,9 @@ async def configure_auth_material(
164
190
provider_registry = get_provider_registry ()
165
191
prov = endpoint .get_from_registry (provider_registry )
166
192
167
- models = []
168
- if config .auth_type != apimodelsv1 .ProviderAuthType .passthrough :
169
- try :
170
- models = prov .models (endpoint = endpoint .endpoint , api_key = config .api_key )
171
- except Exception as err :
172
- raise ProviderModelsNotFoundError (f"Unable to get models from provider: { err } " )
193
+ models = await self ._find_models_for_provider (
194
+ endpoint , config .auth_type , config .api_key , prov
195
+ )
173
196
174
197
await self ._db_writer .push_provider_auth_material (
175
198
dbmodels .ProviderAuthMaterial (
@@ -179,7 +202,32 @@ async def configure_auth_material(
179
202
)
180
203
)
181
204
182
- models_set = set (models )
205
+ await self ._update_models_for_provider (dbendpoint , endpoint , models )
206
+
207
+ # a model might have been deleted, let's repopulate the cache
208
+ await self ._ws_crud .repopulate_mux_cache ()
209
+
210
+ async def _find_models_for_provider (
211
+ self ,
212
+ endpoint : apimodelsv1 .ProviderEndpoint ,
213
+ auth_type : apimodelsv1 .ProviderAuthType ,
214
+ api_key : str ,
215
+ prov : BaseProvider ,
216
+ ) -> List [str ]:
217
+ if auth_type != apimodelsv1 .ProviderAuthType .passthrough :
218
+ try :
219
+ return prov .models (endpoint = endpoint .endpoint , api_key = api_key )
220
+ except Exception as err :
221
+ raise ProviderModelsNotFoundError (f"Unable to get models from provider: { err } " )
222
+ return []
223
+
224
+ async def _update_models_for_provider (
225
+ self ,
226
+ dbendpoint : dbmodels .ProviderEndpoint ,
227
+ endpoint : apimodelsv1 .ProviderEndpoint ,
228
+ found_models : List [str ],
229
+ ) -> None :
230
+ models_set = set (found_models )
183
231
184
232
# Get the models from the provider
185
233
models_in_db = await self ._db_reader .get_provider_models_by_provider_id (str (endpoint .id ))
@@ -202,9 +250,6 @@ async def configure_auth_material(
202
250
model ,
203
251
)
204
252
205
- # a model might have been deleted, let's repopulate the cache
206
- await self ._ws_crud .repopulate_mux_cache ()
207
-
208
253
async def delete_endpoint (self , provider_id : UUID ):
209
254
"""Delete an endpoint."""
210
255
0 commit comments