Skip to content

Commit ed9efe7

Browse files
authored
Port LLaVA to new API (#7817)
## Summary - Port LLaVA model config to new classification API - Add 2 test cases (stripped LLaVA models variants to git-lfs) ## Related Issues / Discussions <!--WHEN APPLICABLE: List any related issues or discussions on github or discord. If this PR closes an issue, please use the "Closes #1234" format, so that the issue will be automatically closed when the PR merges.--> ## QA Instructions <!--WHEN APPLICABLE: Describe how you have tested the changes in this PR. Provide enough detail that a reviewer can reproduce your tests.--> ## Merge Plan <!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like DB schemas, may need some care when merging. For example, a careful rebase by the change author, timing to not interfere with a pending release, or a message to contributors on discord after merging.--> ## Checklist - [ ] _The PR has a short but descriptive title, suitable for a changelog_ - [ ] _Tests added / updated (if applicable)_ - [ ] _Documentation added / updated (if applicable)_ - [ ] _Updated `What's New` copy (if doing a release after this PR)_
2 parents 75d793f + ffa0beb commit ed9efe7

32 files changed

+160
-6
lines changed

invokeai/backend/model_manager/config.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
"""
2222

2323
# pyright: reportIncompatibleVariableOverride=false
24+
import json
2425
import logging
2526
import time
2627
from abc import ABC, abstractmethod
@@ -232,6 +233,23 @@ def component_paths(self):
232233
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
233234
return {f for f in self.path.rglob("*") if f.suffix in extensions}
234235

236+
def repo_variant(self):
237+
if self.format_type == ModelFormat.Checkpoint:
238+
return None
239+
240+
weight_files = list(self.path.glob("**/*.safetensors"))
241+
weight_files.extend(list(self.path.glob("**/*.bin")))
242+
for x in weight_files:
243+
if ".fp16" in x.suffixes:
244+
return ModelRepoVariant.FP16
245+
if "openvino_model" in x.name:
246+
return ModelRepoVariant.OpenVINO
247+
if "flax_model" in x.name:
248+
return ModelRepoVariant.Flax
249+
if x.suffix == ".onnx":
250+
return ModelRepoVariant.ONNX
251+
return ModelRepoVariant.Default
252+
235253
@staticmethod
236254
def load_state_dict(path: Path):
237255
with SilenceWarnings():
@@ -359,21 +377,43 @@ def matches(cls, mod: ModelOnDisk) -> bool:
359377
This doesn't need to be a perfect test - the aim is to eliminate unlikely matches quickly before parsing."""
360378
pass
361379

380+
@staticmethod
381+
def cast_overrides(overrides: dict[str, Any]):
382+
"""Casts user overrides from str to Enum"""
383+
if "type" in overrides:
384+
overrides["type"] = ModelType(overrides["type"])
385+
386+
if "format" in overrides:
387+
overrides["format"] = ModelFormat(overrides["format"])
388+
389+
if "base" in overrides:
390+
overrides["base"] = BaseModelType(overrides["base"])
391+
392+
if "source_type" in overrides:
393+
overrides["source_type"] = ModelSourceType(overrides["source_type"])
394+
362395
@classmethod
363396
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
364397
"""Creates an instance of this config or raises InvalidModelConfigException."""
365398
if not cls.matches(mod):
366399
raise InvalidModelConfigException(f"Path {mod.path} does not match {cls.__name__} format")
367400

368401
fields = cls.parse(mod)
402+
cls.cast_overrides(overrides)
403+
fields.update(overrides)
404+
405+
type = fields.get("type") or cls.model_fields["type"].default
406+
base = fields.get("base") or cls.model_fields["base"].default
369407

370408
fields["path"] = mod.path.as_posix()
371409
fields["source"] = fields.get("source") or fields["path"]
372410
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
373-
fields["name"] = mod.name
411+
fields["name"] = name = fields.get("name") or mod.name
374412
fields["hash"] = fields.get("hash") or mod.hash()
413+
fields["key"] = fields.get("key") or uuid_string()
414+
fields["description"] = fields.get("description") or f"{base.value} {type.value} model {name}"
415+
fields["repo_variant"] = fields.get("repo_variant") or mod.repo_variant()
375416

376-
fields.update(overrides)
377417
return cls(**fields)
378418

379419

@@ -625,12 +665,34 @@ class FluxReduxConfig(LegacyProbeMixin, ModelConfigBase):
625665
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
626666

627667

628-
class LlavaOnevisionConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
668+
class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
629669
"""Model config for Llava Onevision models."""
630670

631671
type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
632672
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
633673

674+
@classmethod
675+
def matches(cls, mod: ModelOnDisk) -> bool:
676+
if mod.format_type == ModelFormat.Checkpoint:
677+
return False
678+
679+
config_path = mod.path / "config.json"
680+
try:
681+
with open(config_path, "r") as file:
682+
config = json.load(file)
683+
except FileNotFoundError:
684+
return False
685+
686+
architectures = config.get("architectures")
687+
return architectures and architectures[0] == "LlavaOnevisionForConditionalGeneration"
688+
689+
@classmethod
690+
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
691+
return {
692+
"base": BaseModelType.Any,
693+
"variant": ModelVariantType.Normal,
694+
}
695+
634696

635697
def get_model_discriminator_value(v: Any) -> str:
636698
"""

tests/test_model_probe.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,22 +148,24 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
148148
configs_with_tests = set()
149149
model_paths = ModelSearch().search(datadir / "stripped_models")
150150
fake_hash = "abcdefgh" # skip hashing to make test quicker
151+
fake_key = "123" # fixed uuid for comparison
151152

152153
for path in model_paths:
153154
legacy_config = new_config = None
154155

155156
try:
156-
legacy_config = ModelProbe.probe(path, {"hash": fake_hash})
157+
legacy_config = ModelProbe.probe(path, {"hash": fake_hash, "key": fake_key})
157158
except InvalidModelConfigException:
158159
pass
159160

160161
try:
161-
new_config = ModelConfigBase.classify(path, hash=fake_hash)
162+
new_config = ModelConfigBase.classify(path, hash=fake_hash, key=fake_key)
162163
except InvalidModelConfigException:
163164
pass
164165

165166
if legacy_config and new_config:
166-
assert legacy_config == new_config
167+
assert type(legacy_config) is type(new_config)
168+
assert legacy_config.model_dump_json() == new_config.model_dump_json()
167169

168170
elif legacy_config:
169171
assert type(legacy_config) in ModelConfigBase._USING_LEGACY_PROBE
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:33e0fb93dadacb864bd2f2e8441e147daa2baceb67f94d3ef5283b495572cea0
3+
size 122
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:2466d1704df30f0067f28d8e30e0190a1bf74e5b430942697af974d162a056bd
3+
size 826
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:839a4fba0bd6949f0db22d4f840935cb0318f6ac28b29c9ce1a5b15735a4a740
3+
size 2591
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:89dc53229f50b59570b6852056dafeac8116c458f1a748bff491b6d4d24d3b51
3+
size 126
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:8831e4f1a044471340f7c0a83d7bd71306a5b867e95fd870f74d0c5308a904d5
3+
size 1671853
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:a0e4b0349d188ce8618b4cfdc01f87d58f912bf9c55cfb8a2d80b9ecae39a870
3+
size 136697
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:3644c108b9f0fa53e62ff422a9be6639642f0e64dab4a71f961c7911d4386384
3+
size 1732
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:04e9899e93f2a412c94e153cab4081457c9a44defb3b2c0b9df673d42c42cdd0
3+
size 178

0 commit comments

Comments
 (0)