From 2429fffd4f7430b559ac3bad098e0c71c7ef5cb4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 18 Apr 2025 22:28:13 +1000 Subject: [PATCH] fix(mm): disable new model probe API There is a subtle change in behaviour with the new model probe API. Previously, checks for model types was done in a specific order. For example, we did all main model checks before LoRA checks. With the new API, the order of checks has changed. Check ordering is as follows: - New API checks are run first, then legacy API checks. - New API checks categorized by their speed. When we run new API checks, we sort them from fastest to slowest, and run them in that order. This is a performance optimization. Currently, LoRA and LLaVA models are the only model types with the new API. Checks for them are thus run first. LoRA checks involve checking the state dict for presence of keys with specific prefixes. We expect these keys to only exist in LoRAs. It turns out that main models may have some of these keys. For example, this model has keys that match the LoRA prefix `lora_te_`: https://civitai.com/models/134442/helloyoung25d Under the old probe, we'd do the main model checks first and correctly identify this as a main model. But with the new setup, we do the LoRA check first, and those pass. So we import this model as a LoRA. Thankfully, the old probe still exists. For now, the new probe is fully disabled. It was only called in one spot. I've also added the example affected model as a test case for the model probe. Right now, this causes the test to fail, and I've marked the test as xfail. CI will pass. Once we enable the new API again, the xfail will pass, and CI will fail, and we'll be reminded to update the test. --- .../services/model_install/model_install_default.py | 13 ++++++++----- tests/test_model_probe.py | 1 + .../stripped_models/helloyoung25d_V15j.safetensors | 3 +++ 3 files changed, 12 insertions(+), 5 deletions(-) create mode 100644 tests/test_model_probe/stripped_models/helloyoung25d_V15j.safetensors diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 6a780256ed7..19a8789098b 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -38,7 +38,6 @@ AnyModelConfig, CheckpointConfigBase, InvalidModelConfigException, - ModelConfigBase, ) from invokeai.backend.model_manager.legacy_probe import ModelProbe from invokeai.backend.model_manager.metadata import ( @@ -647,10 +646,14 @@ def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None): hash_algo = self._app_config.hashing_algorithm fields = config.model_dump() - try: - return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo, **fields) - except InvalidModelConfigException: - return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore + return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) + + # New model probe API is disabled pending resolution of issue caused by a change of the ordering of checks. + # See commit message for details. + # try: + # return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo, **fields) + # except InvalidModelConfigException: + # return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore def _register( self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index a808b043930..108576716d2 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -137,6 +137,7 @@ def test_minimal_working_example(datadir: Path): assert config.fun_quote == "Minimal working example of a ModelConfigBase subclass" +@pytest.mark.xfail(reason="Known issue with 'helloyoung25d_V15j.safetensors'.", strict=True) def test_regression_against_model_probe(datadir: Path, override_model_loading): """Verifies results from ModelConfigBase.classify are consistent with those from ModelProbe.probe. The test paths are gathered from the 'test_model_probe' directory. diff --git a/tests/test_model_probe/stripped_models/helloyoung25d_V15j.safetensors b/tests/test_model_probe/stripped_models/helloyoung25d_V15j.safetensors new file mode 100644 index 00000000000..4ee620fc4d0 --- /dev/null +++ b/tests/test_model_probe/stripped_models/helloyoung25d_V15j.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0f0547f89bdcbb0dfd8b6ff1d8de63336df20107e9a27afc0934e8d3cce584d7 +size 308563