Commit 454d05b

refactor: model manager v3 (#8607)
* feat(mm): add UnknownModelConfig
* refactor(ui): move model categorisation-ish logic to central location, simplify model manager models list
* refactor(ui): more cleanup of model categories
* refactor(ui): remove unused excludeSubmodels

  I can't remember what this was for and don't see any reference to it. Maybe it's just remnants from a previous implementation?

* feat(nodes): add unknown as model base
* chore(ui): typegen
* feat(ui): add unknown model base support in ui
* feat(ui): allow changing model type in MM, fix up base and variant selects
* feat(mm): omit model description instead of making it "base type filename model"
* feat(app): add setting to allow unknown models
* feat(ui): allow changing model format in MM
* feat(app): add the installed model config to install complete events
* chore(ui): typegen
* feat(ui): toast warning when installed model is unidentified
* docs: update config docstrings
* chore(ui): typegen
* tests(mm): fix test for MM, leave the UnknownModelConfig class in the list of configs
* tidy(ui): prefer types from zod schemas for model attrs
* chore(ui): lint
* fix(ui): wrong translation string
* feat(mm): normalized model storage

  Store models in a flat directory structure. Each model is in a dir named with its unique key (a UUID). Inside that dir is either the model file or the model dir. (A sketch of this layout follows the commit summary below.)

* feat(mm): add migration to flat model storage
* fix(mm): normalized multi-file/diffusers model installation

  no worky now worky

* refactor: port MM probes to new api

  - Add concept of match certainty to new probe
  - Port CLIP Embed models to new API
  - Fiddle with stuff

* feat(mm): port TIs to new API
* tidy(mm): remove unused probes
* feat(mm): port spandrel to new API
* fix(mm): parsing for spandrel
* fix(mm): loader for clip embed
* fix(mm): tis use existing weight_files method
* feat(mm): port vae to new API
* fix(mm): vae class inheritance and config_path
* tidy(mm): patcher types and import paths
* feat(mm): better errors when invalid model config found in db
* feat(mm): port t5 to new API
* feat(mm): make config_path optional
* refactor(mm): simplify model classification process

  Previously, we had a multi-phase strategy to identify models from their files on disk:

  1. Run each model config class's `matches()` method on the files. It checks whether the model could possibly be identified as the candidate model type. This was intended to be a quick check. Break on the first match.
  2. If we have a match, run the config class's `parse()` method. It derives some additional model config attrs from the model files. This was intended to encapsulate heavier operations that may require loading the model into memory.
  3. Derive the common model config attrs, like name, description, hash, etc. Some of these are also heavier operations.

  This strategy has some issues:

  - It is not clear how the pieces fit together. There is some back-and-forth between different methods and the config base class. It is hard to trace the flow of logic until you fully wrap your head around the system, and therefore difficult to add a model architecture to the probe.
  - The assumption that we could do quick, lightweight checks before heavier checks is incorrect. We often _must_ load the model state dict in the `matches()` method, so there is no practical perf benefit to splitting the responsibility between `matches()` and `parse()`.
  - Sometimes we need to do the same checks in `matches()` and `parse()`.
    In these cases, splitting the logic has a negative perf impact because we are doing the same work twice.
  - As we introduce the concept of an "unknown" model config (i.e. a model that we cannot identify, but still record in the db; see #8582), we will _always_ run _all_ the checks for every model. Therefore we need not try to defer heavier checks or resource-intensive ops like hashing. We are going to do them anyway.
  - There are situations where a model may match multiple configs. One known case is SD pipeline models with merged LoRAs. In the old probe API, we relied on the implicit order of checks to know that if a model matched for pipeline _and_ LoRA, we prefer the pipeline match. But in the new API, we do not have this implicit ordering of checks. To resolve this in a resilient way, we need to get all matches up front, then use tie-breaker logic to figure out which should win (or add "differential diagnosis" logic to the matchers).
  - Field overrides weren't handled well by this strategy. They were only applied at the very end, if a model matched successfully. This means we cannot tell the system "Hey, this model is type X with base Y. Trust me bro." We cannot override the match logic. As we move towards letting users correct mis-identified models (see #8582), this is a requirement.

  We can simplify the process significantly and better support "unknown" models.

  Firstly, model config classes now have a single `from_model_on_disk()` method that attempts to construct an instance of the class from the model files. This replaces the `matches()` and `parse()` methods. If we fail to create the config instance, a special exception is raised that indicates why we think the files cannot be identified as the given model config class.

  Next, the flow for model identification is a bit simpler:

  - Derive all the common fields up-front (name, desc, hash, etc).
  - Merge in overrides.
  - Call `from_model_on_disk()` for every config class, passing in the fields. Overrides are handled in this method.
  - Record the results for each config class and choose the best one.

  The identification logic is a bit more verbose, with the special exceptions and handling of overrides, but it is very clear what is happening. (An illustrative sketch of this flow follows the commit summary below.)

  The one downside I can think of for this strategy is that we do need to check every model type, instead of stopping at the first match. It's a bit less efficient. In practice, however, this isn't a hot code path, and the improved clarity is worth far more than perf optimizations that the end user will likely never notice.
* refactor(mm): remove unused methods in config.py
* refactor(mm): add model config parsing utils
* fix(mm): abstractmethod bork
* tidy(mm): clarify that model id utils are private
* fix(mm): fall back to UnknownModelConfig correctly
* feat(mm): port CLIPVisionDiffusersConfig to new api
* feat(mm): port SigLIPDiffusersConfig to new api
* feat(mm): make match helpers more succinct
* feat(mm): port flux redux to new api
* feat(mm): port ip adapter to new api
* tidy(mm): skip optimistic override handling for now
* refactor(mm): continue iterating on config
* feat(mm): port flux "control lora" and t2i adapter to new api
* tidy(ui): use Extract to get model config types
* fix(mm): t2i base determination
* feat(mm): port cnet to new api
* refactor(mm): add config validation utils, make it all consistent and clean
* feat(mm): wip port of main models to new api
* feat(mm): wip port of main models to new api
* feat(mm): wip port of main models to new api
* docs(mm): add todos
* tidy(mm): remove unused model merge class
* feat(mm): wip port main models to new api
* tidy(mm): clean up model heuristic utils
* tidy(mm): clean up ModelOnDisk caching
* tidy(mm): flux lora format util
* refactor(mm): make config classes narrow

  Simpler logic to identify, less complexity to add a new model, fewer useless attrs that do not relate to the model arch, etc.

* refactor(mm): diffusers loras w
* feat(mm): consistent naming for all model config classes
* fix(mm): tag generation & scattered probe fixes
* tidy(mm): consistent class names
* refactor(mm): split configs into separate files
* docs(mm): add comments for identification utils
* chore(ui): typegen
* refactor(mm): remove legacy probe, new configs dir structure, update imports
* fix(mm): inverted condition
* docs(mm): update docstrings in factory.py
* docs(mm): document flux variant attr
* feat(mm): add helper method for legacy configs
* feat(mm): satisfy type checker in flux denoise
* docs(mm): remove extraneous comment
* fix(mm): ensure unknown model configs get unknown attrs
* fix(mm): t5 identification
* fix(mm): sdxl ip adapter identification
* feat(mm): more flexible config matching utils
* fix(mm): clip vision identification
* feat(mm): add sanity checks before probing paths
* docs(mm): add reminder for self for field migrations
* feat(mm): clearer naming for main config class hierarchy
* feat(mm): fix clip vision starter model bases, add ref to actual models
* feat(mm): add model config schema migration logic
* fix(mm): duplicate import
* refactor(mm): split big migration into 3

  Split the big migration that did all of these things into 3:

  - Migration 22: Remove unique constraint on base/name/type in models table
  - Migration 23: Migrate configs to v6.8.0 schemas
  - Migration 24: Normalize file storage

* fix(mm): pop base/type/format when creating unknown model config
* fix(db): migration 22 insert only real cols
* fix(db): migration 23 fall back to unknown model when config change fails
* feat(db): run migrations 23 and 24
* fix(mm): false negative on flux lora
* fix(mm): vae checkpoint probe checking for dir instead of file
* fix(mm): ModelOnDisk skips dirs when looking for weights

  Previously, a path with any of the known weights suffixes would be seen as a weights file, even if it was a directory. We now check that the candidate path is actually a file before adding it to the list of weights. (A sketch of this check follows the commit summary below.)
* feat(mm): add method to get main model defaults from a base
* feat(mm): do not log when multiple non-unknown model matches
* refactor(mm): continued iteration on model identification
* tests(mm): refactor model identification tests

  Overhaul of model identification (probing) tests. Previously we didn't test the correctness of probing except in a few narrow cases - now we do.

  See tests/model_identification/README.md for a detailed overview of the new test setup. It includes instructions for adding a new test case. In brief:

  - Download the model you want to add as a test case
  - Run a script against it to generate the test model files
  - Fill in the expected model type/format/base/etc in the generated test metadata JSON file

  Included test cases:

  - All starter models
  - A handful of other models that I had installed
  - Models present in the previous test cases as smoke tests, now also tested for correctness

* fix(mm): omit type/format/base when creating unknown config instance
* feat(mm): use ValueError for model id sanity checks
* feat(mm): add flag for updating models to allow class changes
* tests(mm): fix remaining MM tests
* feat: allow users to edit models freely
* feat(ui): add warning for model settings edit
* tests(mm): flux state dict tests
* tidy: remove unused file
* fix(mm): lora state dict loading in model id
* feat(ui): use translation string for model edit warning
* docs(db): update version numbers in migration comments
* chore: bump version to v6.9.0a1
* docs: update model id readme
* tests(mm): attempt to fix windows model id tests
* fix(mm): issue with deleting single file models
* feat(mm): just delete the dir w/ rmtree when deleting model
* tests(mm): windows CI issue
* fix(ui): typegen schema sync
* fix(mm): fixes for migration 23

  - Handle CLIP Embed and Main SD models missing the variant field
  - Handle errors when calling the discriminator function; previously we only handled ValidationError, but it could be a ValueError or something else
  - Better logging for config migration

* chore: bump version to v6.9.0a2
* chore: bump version to v6.9.0a3
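The "normalized model storage" item above places each model under a directory named with its UUID key. A minimal sketch of that layout; `install_to_flat_storage` and its signature are hypothetical stand-ins, not the installer's real API:

```python
import shutil
from pathlib import Path
from uuid import uuid4


def install_to_flat_storage(models_dir: Path, source: Path) -> Path:
    """Place a model at models/<key>/<original name>, where <key> is a UUID.

    The key-named directory contains either the single model file or the
    whole model directory (e.g. a multi-file diffusers folder).
    """
    key = uuid4().hex  # the model's unique key
    dest = models_dir / key / source.name
    dest.parent.mkdir(parents=True, exist_ok=False)
    if source.is_dir():
        shutil.copytree(source, dest)  # multi-file model: copy the whole dir
    else:
        shutil.copy2(source, dest)  # single-file model
    return dest
```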
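The "simplify model classification process" item describes the new single-pass identification flow. Here is a rough sketch of that flow as described in the summary; except for `from_model_on_disk()` and the unknown-config fallback, which the commit message itself names, every identifier and signature below is an assumption rather than the code in this commit:

```python
class NotAMatchError(Exception):
    """Raised by a config class to explain why files can't be identified as it."""


def identify_model(mod, overrides: dict, config_classes: list, unknown_cls):
    # 1. Derive all the common fields up-front (name, desc, hash, etc.).
    fields = {"name": mod.name, "hash": mod.compute_hash()}
    # 2. Merge in user-supplied overrides ("this model is type X, base Y").
    fields.update(overrides)

    # 3. Call from_model_on_disk() on *every* config class, recording the
    #    outcome for each rather than breaking on the first match.
    matches, reasons = [], []
    for cls in config_classes:
        try:
            matches.append(cls.from_model_on_disk(mod, fields))
        except NotAMatchError as e:
            reasons.append(e)

    # 4. Choose the best match. With multiple matches (e.g. an SD pipeline
    #    with a merged LoRA matching both Main and LoRA), tie-breaker logic
    #    decides; here we assume config_classes is priority-ordered.
    if matches:
        return matches[0]
    # 5. Nothing matched: record the model anyway as "unknown".
    return unknown_cls(**fields)
```

Checking every class is slightly less efficient than breaking on the first match, but as the summary notes, this is not a hot code path, and collecting all outcomes is what makes the tie-breaking and "unknown" fallback possible.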
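The `ModelOnDisk` fix near the end of the summary boils down to an `is_file()` guard when collecting weight files. A sketch under assumptions: the suffix set and `weight_files` helper below are illustrative, not the actual `ModelOnDisk` internals:

```python
from pathlib import Path

# Assumed suffix set for illustration; the real list lives in ModelOnDisk.
WEIGHT_SUFFIXES = {".safetensors", ".ckpt", ".pt", ".pth", ".bin"}


def weight_files(root: Path) -> list[Path]:
    """Collect candidate weight files under root.

    The is_file() check is the fix: a *directory* named e.g. 'model.ckpt/'
    must not be treated as a weights file.
    """
    return [p for p in root.rglob("*") if p.is_file() and p.suffix in WEIGHT_SUFFIXES]
```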
1 parent a7e1f06 commit 454d05b

548 files changed: +19532 additions, -14038 deletions


.gitattributes

Lines changed: 1 addition & 0 deletions

@@ -4,3 +4,4 @@
 * text=auto
 docker/** text eol=lf
 tests/test_model_probe/stripped_models/** filter=lfs diff=lfs merge=lfs -text
+tests/model_identification/stripped_models/** filter=lfs diff=lfs merge=lfs -text

invokeai/app/api/routers/model_manager.py

Lines changed: 20 additions & 10 deletions

@@ -28,10 +28,12 @@
     UnknownModelException,
 )
 from invokeai.app.util.suppress_output import SuppressOutput
-from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
-from invokeai.backend.model_manager.config import (
-    AnyModelConfig,
-    MainCheckpointConfig,
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
+from invokeai.backend.model_manager.configs.main import (
+    Main_Checkpoint_SD1_Config,
+    Main_Checkpoint_SD2_Config,
+    Main_Checkpoint_SDXL_Config,
+    Main_Checkpoint_SDXLRefiner_Config,
 )
 from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
 from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
@@ -44,6 +46,7 @@
     StarterModelBundle,
     StarterModelWithoutDependencies,
 )
+from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType

 model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])

@@ -297,10 +300,8 @@ async def update_model_record(
     """Update a model's config."""
     logger = ApiDependencies.invoker.services.logger
     record_store = ApiDependencies.invoker.services.model_manager.store
-    installer = ApiDependencies.invoker.services.model_manager.install
     try:
-        record_store.update_model(key, changes=changes)
-        config = installer.sync_model_path(key)
+        config = record_store.update_model(key, changes=changes, allow_class_change=True)
         config = add_cover_image_to_model_config(config, ApiDependencies)
         logger.info(f"Updated model: {key}")
     except UnknownModelException as e:
@@ -743,9 +744,18 @@ async def convert_model(
         logger.error(str(e))
         raise HTTPException(status_code=424, detail=str(e))

-    if not isinstance(model_config, MainCheckpointConfig):
-        logger.error(f"The model with key {key} is not a main checkpoint model.")
-        raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
+    if not isinstance(
+        model_config,
+        (
+            Main_Checkpoint_SD1_Config,
+            Main_Checkpoint_SD2_Config,
+            Main_Checkpoint_SDXL_Config,
+            Main_Checkpoint_SDXLRefiner_Config,
+        ),
+    ):
+        msg = f"The model with key {key} is not a main SD 1/2/XL checkpoint model."
+        logger.error(msg)
+        raise HTTPException(400, msg)

     with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir:
         convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem

invokeai/app/invocations/cogview4_denoise.py

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
-from invokeai.backend.model_manager.config import BaseModelType
+from invokeai.backend.model_manager.taxonomy import BaseModelType
 from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import CogView4ConditioningInfo

invokeai/app/invocations/cogview4_model_loader.py

Lines changed: 1 addition & 2 deletions

@@ -13,8 +13,7 @@
     VAEField,
 )
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.model_manager.config import SubModelType
-from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
+from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType


 @invocation_output("cogview4_model_loader_output")

invokeai/app/invocations/create_gradient_mask.py

Lines changed: 5 additions & 6 deletions

@@ -20,9 +20,7 @@
 from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
 from invokeai.app.invocations.model import UNetField, VAEField
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.model_manager import LoadedModel
-from invokeai.backend.model_manager.config import MainConfigBase
-from invokeai.backend.model_manager.taxonomy import ModelVariantType
+from invokeai.backend.model_manager.taxonomy import FluxVariantType, ModelType, ModelVariantType
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor


@@ -182,10 +180,11 @@ def invoke(self, context: InvocationContext) -> GradientMaskOutput:
         if self.unet is not None and self.vae is not None and self.image is not None:
             # all three fields must be present at the same time
             main_model_config = context.models.get_config(self.unet.unet.key)
-            assert isinstance(main_model_config, MainConfigBase)
-            if main_model_config.variant is ModelVariantType.Inpaint:
+            assert main_model_config.type is ModelType.Main
+            variant = getattr(main_model_config, "variant", None)
+            if variant is ModelVariantType.Inpaint or variant is FluxVariantType.DevFill:
                 mask = dilated_mask_tensor
-            vae_info: LoadedModel = context.models.load(self.vae.vae)
+            vae_info = context.models.load(self.vae.vae)
             image = context.images.get_pil(self.image.image_name)
             image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
             if image_tensor.dim() == 3:

invokeai/app/invocations/denoise_latents.py

Lines changed: 1 addition & 1 deletion

@@ -39,7 +39,7 @@
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.model_manager.config import AnyModelConfig
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
 from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelVariantType
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.patches.layer_patcher import LayerPatcher

invokeai/app/invocations/flux_denoise.py

Lines changed: 4 additions & 3 deletions

@@ -48,7 +48,7 @@
     unpack,
 )
 from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
-from invokeai.backend.model_manager.taxonomy import ModelFormat, ModelVariantType
+from invokeai.backend.model_manager.taxonomy import BaseModelType, FluxVariantType, ModelFormat, ModelType
 from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -232,7 +232,8 @@ def _run_diffusion(
         )

         transformer_config = context.models.get_config(self.transformer.transformer)
-        is_schnell = "schnell" in getattr(transformer_config, "config_path", "")
+        assert transformer_config.base is BaseModelType.Flux and transformer_config.type is ModelType.Main
+        is_schnell = transformer_config.variant is FluxVariantType.Schnell

         # Calculate the timestep schedule.
         timesteps = get_schedule(
@@ -277,7 +278,7 @@ def _run_diffusion(

         # Prepare the extra image conditioning tensor (img_cond) for either FLUX structural control or FLUX Fill.
         img_cond: torch.Tensor | None = None
-        is_flux_fill = transformer_config.variant == ModelVariantType.Inpaint  # type: ignore
+        is_flux_fill = transformer_config.variant is FluxVariantType.DevFill
         if is_flux_fill:
             img_cond = self._prep_flux_fill_img_cond(
                 context, device=TorchDevice.choose_torch_device(), dtype=inference_dtype

invokeai/app/invocations/flux_ip_adapter.py

Lines changed: 2 additions & 5 deletions

@@ -16,10 +16,7 @@
 from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.model_manager.config import (
-    IPAdapterCheckpointConfig,
-    IPAdapterInvokeAIConfig,
-)
+from invokeai.backend.model_manager.configs.ip_adapter import IPAdapter_Checkpoint_FLUX_Config
 from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType


@@ -68,7 +65,7 @@ def validate_begin_end_step_percent(self) -> Self:
     def invoke(self, context: InvocationContext) -> IPAdapterOutput:
         # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
         ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
-        assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
+        assert isinstance(ip_adapter_info, IPAdapter_Checkpoint_FLUX_Config)

         # Note: There is a IPAdapterInvokeAIConfig.image_encoder_model_id field, but it isn't trustworthy.
         image_encoder_starter_model = CLIP_VISION_MODEL_MAP[self.clip_vision_model]

invokeai/app/invocations/flux_model_loader.py

Lines changed: 4 additions & 6 deletions

@@ -13,10 +13,8 @@
     preprocess_t5_encoder_model_identifier,
     preprocess_t5_tokenizer_model_identifier,
 )
-from invokeai.backend.flux.util import max_seq_lengths
-from invokeai.backend.model_manager.config import (
-    CheckpointConfigBase,
-)
+from invokeai.backend.flux.util import get_flux_max_seq_length
+from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base
 from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType


@@ -87,12 +85,12 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
         t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model)

         transformer_config = context.models.get_config(transformer)
-        assert isinstance(transformer_config, CheckpointConfigBase)
+        assert isinstance(transformer_config, Checkpoint_Config_Base)

         return FluxModelLoaderOutput(
             transformer=TransformerField(transformer=transformer, loras=[]),
             clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
             t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]),
             vae=VAEField(vae=vae),
-            max_seq_len=max_seq_lengths[transformer_config.config_path],
+            max_seq_len=get_flux_max_seq_length(transformer_config.variant),
         )

invokeai/app/invocations/flux_redux.py

Lines changed: 2 additions & 2 deletions

@@ -24,9 +24,9 @@
 from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
-from invokeai.backend.model_manager import BaseModelType, ModelType
-from invokeai.backend.model_manager.config import AnyModelConfig
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
 from invokeai.backend.model_manager.starter_models import siglip
+from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
 from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
 from invokeai.backend.util.devices import TorchDevice
