Skip to content

Commit 5d5a497

Browse files
Model manager route enhancements (#3768)
# Multiple enhancements to model manager REACT API 1. add a `/sync` route for synchronizing the in-memory model lists to models.yaml, the models directory, and the autoimport directories. 2. added optional destination directories to convert_model and merge_model operations. 3. added a `/ckpt_confs` route for retrieving known legacy checkpoint configuration files. 4. added a `/search` route for finding all models in a directory located in the server filesystem 5. added a `/add` route for manual addition of a local models 6. added a `/rename` route for renaming and/or rebasing models 7. changed the path of the `import_model` route to `/import` # Slightly annoying detail: When adding a model manually using `/add`, the body JSON must exactly match one of the model configurations returned by `list_models` (i.e. there is no defaulting of fields). This includes the `error` field, which should be set to "null".
2 parents 194434d + 808b2de commit 5d5a497

File tree

9 files changed

+482
-94
lines changed

9 files changed

+482
-94
lines changed

invokeai/app/api/routers/models.py

Lines changed: 154 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
1+
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
22

33

4+
import pathlib
45
from typing import Literal, List, Optional, Union
56

67
from fastapi import Body, Path, Query, Response
@@ -22,6 +23,7 @@
2223
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
2324
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
2425
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
26+
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
2527

2628
class ModelsList(BaseModel):
2729
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
@@ -78,7 +80,7 @@ async def update_model(
7880
return model_response
7981

8082
@models_router.post(
81-
"/",
83+
"/import",
8284
operation_id="import_model",
8385
responses= {
8486
201: {"description" : "The model imported successfully"},
@@ -94,7 +96,7 @@ async def import_model(
9496
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
9597
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
9698
) -> ImportModelResponse:
97-
""" Add a model using its local path, repo_id, or remote URL """
99+
""" Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
98100

99101
items_to_import = {location}
100102
prediction_types = { x.value: x for x in SchedulerPredictionType }
@@ -126,18 +128,100 @@ async def import_model(
126128
logger.error(str(e))
127129
raise HTTPException(status_code=409, detail=str(e))
128130

131+
@models_router.post(
132+
"/add",
133+
operation_id="add_model",
134+
responses= {
135+
201: {"description" : "The model added successfully"},
136+
404: {"description" : "The model could not be found"},
137+
424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
138+
409: {"description" : "There is already a model corresponding to this path or repo_id"},
139+
},
140+
status_code=201,
141+
response_model=ImportModelResponse
142+
)
143+
async def add_model(
144+
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
145+
) -> ImportModelResponse:
146+
""" Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
147+
148+
logger = ApiDependencies.invoker.services.logger
149+
150+
try:
151+
ApiDependencies.invoker.services.model_manager.add_model(
152+
info.model_name,
153+
info.base_model,
154+
info.model_type,
155+
model_attributes = info.dict()
156+
)
157+
logger.info(f'Successfully added {info.model_name}')
158+
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
159+
model_name=info.model_name,
160+
base_model=info.base_model,
161+
model_type=info.model_type
162+
)
163+
return parse_obj_as(ImportModelResponse, model_raw)
164+
except KeyError as e:
165+
logger.error(str(e))
166+
raise HTTPException(status_code=404, detail=str(e))
167+
except ValueError as e:
168+
logger.error(str(e))
169+
raise HTTPException(status_code=409, detail=str(e))
129170

171+
@models_router.post(
172+
"/rename/{base_model}/{model_type}/{model_name}",
173+
operation_id="rename_model",
174+
responses= {
175+
201: {"description" : "The model was renamed successfully"},
176+
404: {"description" : "The model could not be found"},
177+
409: {"description" : "There is already a model corresponding to the new name"},
178+
},
179+
status_code=201,
180+
response_model=ImportModelResponse
181+
)
182+
async def rename_model(
183+
base_model: BaseModelType = Path(description="Base model"),
184+
model_type: ModelType = Path(description="The type of model"),
185+
model_name: str = Path(description="current model name"),
186+
new_name: Optional[str] = Query(description="new model name", default=None),
187+
new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
188+
) -> ImportModelResponse:
189+
""" Rename a model"""
190+
191+
logger = ApiDependencies.invoker.services.logger
192+
193+
try:
194+
result = ApiDependencies.invoker.services.model_manager.rename_model(
195+
base_model = base_model,
196+
model_type = model_type,
197+
model_name = model_name,
198+
new_name = new_name,
199+
new_base = new_base,
200+
)
201+
logger.debug(result)
202+
logger.info(f'Successfully renamed {model_name}=>{new_name}')
203+
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
204+
model_name=new_name or model_name,
205+
base_model=new_base or base_model,
206+
model_type=model_type
207+
)
208+
return parse_obj_as(ImportModelResponse, model_raw)
209+
except KeyError as e:
210+
logger.error(str(e))
211+
raise HTTPException(status_code=404, detail=str(e))
212+
except ValueError as e:
213+
logger.error(str(e))
214+
raise HTTPException(status_code=409, detail=str(e))
215+
130216
@models_router.delete(
131217
"/{base_model}/{model_type}/{model_name}",
132218
operation_id="del_model",
133219
responses={
134-
204: {
135-
"description": "Model deleted successfully"
136-
},
137-
404: {
138-
"description": "Model not found"
139-
}
220+
204: { "description": "Model deleted successfully" },
221+
404: { "description": "Model not found" }
140222
},
223+
status_code = 204,
224+
response_model = None,
141225
)
142226
async def delete_model(
143227
base_model: BaseModelType = Path(description="Base model"),
@@ -173,14 +257,17 @@ async def convert_model(
173257
base_model: BaseModelType = Path(description="Base model"),
174258
model_type: ModelType = Path(description="The type of model"),
175259
model_name: str = Path(description="model name"),
260+
convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
176261
) -> ConvertModelResponse:
177-
"""Convert a checkpoint model into a diffusers model"""
262+
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
178263
logger = ApiDependencies.invoker.services.logger
179264
try:
180265
logger.info(f"Converting model: {model_name}")
266+
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
181267
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
182268
base_model = base_model,
183-
model_type = model_type
269+
model_type = model_type,
270+
convert_dest_directory = dest,
184271
)
185272
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
186273
base_model = base_model,
@@ -191,6 +278,53 @@ async def convert_model(
191278
except ValueError as e:
192279
raise HTTPException(status_code=400, detail=str(e))
193280
return response
281+
282+
@models_router.get(
283+
"/search",
284+
operation_id="search_for_models",
285+
responses={
286+
200: { "description": "Directory searched successfully" },
287+
404: { "description": "Invalid directory path" },
288+
},
289+
status_code = 200,
290+
response_model = List[pathlib.Path]
291+
)
292+
async def search_for_models(
293+
search_path: pathlib.Path = Query(description="Directory path to search for models")
294+
)->List[pathlib.Path]:
295+
if not search_path.is_dir():
296+
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
297+
return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
298+
299+
@models_router.get(
300+
"/ckpt_confs",
301+
operation_id="list_ckpt_configs",
302+
responses={
303+
200: { "description" : "paths retrieved successfully" },
304+
},
305+
status_code = 200,
306+
response_model = List[pathlib.Path]
307+
)
308+
async def list_ckpt_configs(
309+
)->List[pathlib.Path]:
310+
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
311+
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
312+
313+
314+
@models_router.get(
315+
"/sync",
316+
operation_id="sync_to_config",
317+
responses={
318+
201: { "description": "synchronization successful" },
319+
},
320+
status_code = 201,
321+
response_model = None
322+
)
323+
async def sync_to_config(
324+
)->None:
325+
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
326+
in-memory data structures with disk data structures."""
327+
return ApiDependencies.invoker.services.model_manager.sync_to_config()
194328

195329
@models_router.put(
196330
"/merge/{base_model}",
@@ -210,17 +344,21 @@ async def merge_models(
210344
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
211345
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
212346
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
347+
merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
213348
) -> MergeModelResponse:
214349
"""Convert a checkpoint model into a diffusers model"""
215350
logger = ApiDependencies.invoker.services.logger
216351
try:
217-
logger.info(f"Merging models: {model_names}")
352+
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
353+
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
218354
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
219355
base_model,
220-
merged_model_name or "+".join(model_names),
221-
alpha,
222-
interp,
223-
force)
356+
merged_model_name=merged_model_name or "+".join(model_names),
357+
alpha=alpha,
358+
interp=interp,
359+
force=force,
360+
merge_dest_directory = dest
361+
)
224362
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
225363
base_model = base_model,
226364
model_type = ModelType.Main,

0 commit comments

Comments
 (0)