1- # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
1+ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
22
33
4+ import pathlib
45from typing import Literal , List , Optional , Union
56
67from fastapi import Body , Path , Query , Response
2223ImportModelResponse = Union [tuple (OPENAPI_MODEL_CONFIGS )]
2324ConvertModelResponse = Union [tuple (OPENAPI_MODEL_CONFIGS )]
2425MergeModelResponse = Union [tuple (OPENAPI_MODEL_CONFIGS )]
26+ ImportModelAttributes = Union [tuple (OPENAPI_MODEL_CONFIGS )]
2527
2628class ModelsList (BaseModel ):
2729 models : list [Union [tuple (OPENAPI_MODEL_CONFIGS )]]
@@ -78,7 +80,7 @@ async def update_model(
7880 return model_response
7981
8082@models_router .post (
81- "/" ,
83+ "/import " ,
8284 operation_id = "import_model" ,
8385 responses = {
8486 201 : {"description" : "The model imported successfully" },
@@ -94,7 +96,7 @@ async def import_model(
9496 prediction_type : Optional [Literal ['v_prediction' ,'epsilon' ,'sample' ]] = \
9597 Body (description = 'Prediction type for SDv2 checkpoint files' , default = "v_prediction" ),
9698) -> ImportModelResponse :
97- """ Add a model using its local path, repo_id, or remote URL """
99+ """ Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
98100
99101 items_to_import = {location }
100102 prediction_types = { x .value : x for x in SchedulerPredictionType }
@@ -126,18 +128,100 @@ async def import_model(
126128 logger .error (str (e ))
127129 raise HTTPException (status_code = 409 , detail = str (e ))
128130
131+ @models_router .post (
132+ "/add" ,
133+ operation_id = "add_model" ,
134+ responses = {
135+ 201 : {"description" : "The model added successfully" },
136+ 404 : {"description" : "The model could not be found" },
137+ 424 : {"description" : "The model appeared to add successfully, but could not be found in the model manager" },
138+ 409 : {"description" : "There is already a model corresponding to this path or repo_id" },
139+ },
140+ status_code = 201 ,
141+ response_model = ImportModelResponse
142+ )
143+ async def add_model (
144+ info : Union [tuple (OPENAPI_MODEL_CONFIGS )] = Body (description = "Model configuration" ),
145+ ) -> ImportModelResponse :
146+ """ Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
147+
148+ logger = ApiDependencies .invoker .services .logger
149+
150+ try :
151+ ApiDependencies .invoker .services .model_manager .add_model (
152+ info .model_name ,
153+ info .base_model ,
154+ info .model_type ,
155+ model_attributes = info .dict ()
156+ )
157+ logger .info (f'Successfully added { info .model_name } ' )
158+ model_raw = ApiDependencies .invoker .services .model_manager .list_model (
159+ model_name = info .model_name ,
160+ base_model = info .base_model ,
161+ model_type = info .model_type
162+ )
163+ return parse_obj_as (ImportModelResponse , model_raw )
164+ except KeyError as e :
165+ logger .error (str (e ))
166+ raise HTTPException (status_code = 404 , detail = str (e ))
167+ except ValueError as e :
168+ logger .error (str (e ))
169+ raise HTTPException (status_code = 409 , detail = str (e ))
129170
171+ @models_router .post (
172+ "/rename/{base_model}/{model_type}/{model_name}" ,
173+ operation_id = "rename_model" ,
174+ responses = {
175+ 201 : {"description" : "The model was renamed successfully" },
176+ 404 : {"description" : "The model could not be found" },
177+ 409 : {"description" : "There is already a model corresponding to the new name" },
178+ },
179+ status_code = 201 ,
180+ response_model = ImportModelResponse
181+ )
182+ async def rename_model (
183+ base_model : BaseModelType = Path (description = "Base model" ),
184+ model_type : ModelType = Path (description = "The type of model" ),
185+ model_name : str = Path (description = "current model name" ),
186+ new_name : Optional [str ] = Query (description = "new model name" , default = None ),
187+ new_base : Optional [BaseModelType ] = Query (description = "new model base" , default = None ),
188+ ) -> ImportModelResponse :
189+ """ Rename a model"""
190+
191+ logger = ApiDependencies .invoker .services .logger
192+
193+ try :
194+ result = ApiDependencies .invoker .services .model_manager .rename_model (
195+ base_model = base_model ,
196+ model_type = model_type ,
197+ model_name = model_name ,
198+ new_name = new_name ,
199+ new_base = new_base ,
200+ )
201+ logger .debug (result )
202+ logger .info (f'Successfully renamed { model_name } =>{ new_name } ' )
203+ model_raw = ApiDependencies .invoker .services .model_manager .list_model (
204+ model_name = new_name or model_name ,
205+ base_model = new_base or base_model ,
206+ model_type = model_type
207+ )
208+ return parse_obj_as (ImportModelResponse , model_raw )
209+ except KeyError as e :
210+ logger .error (str (e ))
211+ raise HTTPException (status_code = 404 , detail = str (e ))
212+ except ValueError as e :
213+ logger .error (str (e ))
214+ raise HTTPException (status_code = 409 , detail = str (e ))
215+
130216@models_router .delete (
131217 "/{base_model}/{model_type}/{model_name}" ,
132218 operation_id = "del_model" ,
133219 responses = {
134- 204 : {
135- "description" : "Model deleted successfully"
136- },
137- 404 : {
138- "description" : "Model not found"
139- }
220+ 204 : { "description" : "Model deleted successfully" },
221+ 404 : { "description" : "Model not found" }
140222 },
223+ status_code = 204 ,
224+ response_model = None ,
141225)
142226async def delete_model (
143227 base_model : BaseModelType = Path (description = "Base model" ),
@@ -173,14 +257,17 @@ async def convert_model(
173257 base_model : BaseModelType = Path (description = "Base model" ),
174258 model_type : ModelType = Path (description = "The type of model" ),
175259 model_name : str = Path (description = "model name" ),
260+ convert_dest_directory : Optional [str ] = Query (default = None , description = "Save the converted model to the designated directory" ),
176261) -> ConvertModelResponse :
177- """Convert a checkpoint model into a diffusers model"""
262+ """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none. """
178263 logger = ApiDependencies .invoker .services .logger
179264 try :
180265 logger .info (f"Converting model: { model_name } " )
266+ dest = pathlib .Path (convert_dest_directory ) if convert_dest_directory else None
181267 ApiDependencies .invoker .services .model_manager .convert_model (model_name ,
182268 base_model = base_model ,
183- model_type = model_type
269+ model_type = model_type ,
270+ convert_dest_directory = dest ,
184271 )
185272 model_raw = ApiDependencies .invoker .services .model_manager .list_model (model_name ,
186273 base_model = base_model ,
@@ -191,6 +278,53 @@ async def convert_model(
191278 except ValueError as e :
192279 raise HTTPException (status_code = 400 , detail = str (e ))
193280 return response
281+
282+ @models_router .get (
283+ "/search" ,
284+ operation_id = "search_for_models" ,
285+ responses = {
286+ 200 : { "description" : "Directory searched successfully" },
287+ 404 : { "description" : "Invalid directory path" },
288+ },
289+ status_code = 200 ,
290+ response_model = List [pathlib .Path ]
291+ )
292+ async def search_for_models (
293+ search_path : pathlib .Path = Query (description = "Directory path to search for models" )
294+ )-> List [pathlib .Path ]:
295+ if not search_path .is_dir ():
296+ raise HTTPException (status_code = 404 , detail = f"The search path '{ search_path } ' does not exist or is not directory" )
297+ return ApiDependencies .invoker .services .model_manager .search_for_models ([search_path ])
298+
299+ @models_router .get (
300+ "/ckpt_confs" ,
301+ operation_id = "list_ckpt_configs" ,
302+ responses = {
303+ 200 : { "description" : "paths retrieved successfully" },
304+ },
305+ status_code = 200 ,
306+ response_model = List [pathlib .Path ]
307+ )
308+ async def list_ckpt_configs (
309+ )-> List [pathlib .Path ]:
310+ """Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
311+ return ApiDependencies .invoker .services .model_manager .list_checkpoint_configs ()
312+
313+
314+ @models_router .get (
315+ "/sync" ,
316+ operation_id = "sync_to_config" ,
317+ responses = {
318+ 201 : { "description" : "synchronization successful" },
319+ },
320+ status_code = 201 ,
321+ response_model = None
322+ )
323+ async def sync_to_config (
324+ )-> None :
325+ """Call after making changes to models.yaml, autoimport directories or models directory to synchronize
326+ in-memory data structures with disk data structures."""
327+ return ApiDependencies .invoker .services .model_manager .sync_to_config ()
194328
195329@models_router .put (
196330 "/merge/{base_model}" ,
@@ -210,17 +344,21 @@ async def merge_models(
210344 alpha : Optional [float ] = Body (description = "Alpha weighting strength to apply to 2d and 3d models" , default = 0.5 ),
211345 interp : Optional [MergeInterpolationMethod ] = Body (description = "Interpolation method" ),
212346 force : Optional [bool ] = Body (description = "Force merging of models created with different versions of diffusers" , default = False ),
347+ merge_dest_directory : Optional [str ] = Body (description = "Save the merged model to the designated directory (with 'merged_model_name' appended)" , default = None )
213348) -> MergeModelResponse :
214349 """Convert a checkpoint model into a diffusers model"""
215350 logger = ApiDependencies .invoker .services .logger
216351 try :
217- logger .info (f"Merging models: { model_names } " )
352+ logger .info (f"Merging models: { model_names } into { merge_dest_directory or '<MODELS>' } /{ merged_model_name } " )
353+ dest = pathlib .Path (merge_dest_directory ) if merge_dest_directory else None
218354 result = ApiDependencies .invoker .services .model_manager .merge_models (model_names ,
219355 base_model ,
220- merged_model_name or "+" .join (model_names ),
221- alpha ,
222- interp ,
223- force )
356+ merged_model_name = merged_model_name or "+" .join (model_names ),
357+ alpha = alpha ,
358+ interp = interp ,
359+ force = force ,
360+ merge_dest_directory = dest
361+ )
224362 model_raw = ApiDependencies .invoker .services .model_manager .list_model (result .name ,
225363 base_model = base_model ,
226364 model_type = ModelType .Main ,
0 commit comments