
Commit 9536ba2

Convert custom VAEs during legacy checkpoint loading (#3010)
- When a legacy checkpoint model is loaded via --convert_ckpt and its models.yaml stanza refers to a custom VAE path (using the 'vae:' key), the custom VAE will be converted and used within the diffusers model. Otherwise the VAE contained within the legacy model will be used.
- Note that the checkpoint import functions in the CLI and Web UI continue to default to the standard stabilityai/sd-vae-ft-mse VAE. This can be fixed after the fact by editing the VAE key using either the CLI or the Web UI.
- Fixes issue #2917
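As an illustration, a models.yaml stanza of roughly this shape would trigger the new behavior. Only the 'vae:' key is what the converter looks for; the model name and the other fields shown here are hypothetical examples:

my-legacy-model:
  description: Legacy checkpoint with a custom VAE        # hypothetical entry
  format: ckpt
  weights: models/ldm/stable-diffusion-v1/my-model.ckpt   # hypothetical path
  config: configs/stable-diffusion/v1-inference.yaml
  vae: models/ldm/stable-diffusion-v1/my-custom-vae.ckpt  # converted and baked into the diffusers model

During conversion the VAE referenced by 'vae:' is spliced into the checkpoint's first_stage_model keys before the diffusers model is built, so the converted model carries the custom VAE with it.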
2 parents: 9bfe2fa + 5503749

File tree: 3 files changed (+53, −60 lines)

invokeai/backend/model_management/convert_ckpt_to_diffusers.py

Lines changed: 25 additions & 5 deletions
@@ -1036,6 +1036,15 @@ def convert_open_clip_checkpoint(checkpoint):
 
     return text_model
 
+def replace_checkpoint_vae(checkpoint, vae_path:str):
+    if vae_path.endswith(".safetensors"):
+        vae_ckpt = load_file(vae_path)
+    else:
+        vae_ckpt = torch.load(vae_path, map_location="cpu")
+    state_dict = vae_ckpt['state_dict'] if "state_dict" in vae_ckpt else vae_ckpt
+    for vae_key in state_dict:
+        new_key = f'first_stage_model.{vae_key}'
+        checkpoint[new_key] = state_dict[vae_key]
 
 def load_pipeline_from_original_stable_diffusion_ckpt(
     checkpoint_path: str,
@@ -1048,6 +1057,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     extract_ema: bool = True,
     upcast_attn: bool = False,
     vae: AutoencoderKL = None,
+    vae_path: str = None,
     precision: torch.dtype = torch.float32,
     return_generator_pipeline: bool = False,
     scan_needed:bool=True,
@@ -1078,6 +1088,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     :param precision: precision to use - torch.float16, torch.float32 or torch.autocast
     :param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
         running stable diffusion 2.1.
+    :param vae: A diffusers VAE to load into the pipeline.
+    :param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
     """
 
     with warnings.catch_warnings():
@@ -1214,9 +1226,19 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
 
     unet.load_state_dict(converted_unet_checkpoint)
 
-    # Convert the VAE model, or use the one passed
-    if not vae:
-        print(" | Using checkpoint model's original VAE")
+    # If a replacement VAE path was specified, we'll incorporate that into
+    # the checkpoint model and then convert it
+    if vae_path:
+        print(f" | Converting VAE {vae_path}")
+        replace_checkpoint_vae(checkpoint,vae_path)
+    # otherwise we use the original VAE, provided that
+    # an externally loaded diffusers VAE was not passed
+    elif not vae:
+        print(" | Using checkpoint model's original VAE")
+
+    if vae:
+        print(" | Using replacement diffusers VAE")
+    else: # convert the original or replacement VAE
         vae_config = create_vae_diffusers_config(
             original_config, image_size=image_size
         )
@@ -1226,8 +1248,6 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
 
         vae = AutoencoderKL(**vae_config)
         vae.load_state_dict(converted_vae_checkpoint)
-    else:
-        print(" | Using external VAE specified in config")
 
     # Convert the text model.
     model_type = pipeline_type
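A caller-side sketch of the new vae_path parameter follows. The keyword arguments match the signature shown above; the import path is taken from the file header, and the file paths are hypothetical:

import torch
from invokeai.backend.model_management.convert_ckpt_to_diffusers import (
    load_pipeline_from_original_stable_diffusion_ckpt,
)

pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
    checkpoint_path="models/ldm/stable-diffusion-v1/my-model.ckpt",        # hypothetical path
    original_config_file="configs/stable-diffusion/v1-inference.yaml",     # hypothetical path
    # Any .ckpt/.safetensors VAE; its weights are remapped under the
    # "first_stage_model." prefix and converted along with the checkpoint.
    vae_path="models/ldm/stable-diffusion-v1/my-custom-vae.safetensors",   # hypothetical path
    return_generator_pipeline=True,
    precision=torch.float16,
)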

invokeai/backend/model_management/model_manager.py

Lines changed: 26 additions & 47 deletions
@@ -45,9 +45,6 @@ class SDLegacyType(Enum):
     UNKNOWN = 99
 
 DEFAULT_MAX_MODELS = 2
-VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
-    "vae-ft-mse-840000-ema-pruned": "stabilityai/sd-vae-ft-mse",
-}
 
 class ModelManager(object):
     '''
@@ -457,15 +454,21 @@ def _load_ckpt_model(self, model_name, mconfig):
 
         from . import load_pipeline_from_original_stable_diffusion_ckpt
 
-        self.offload_model(self.current_model)
-        if vae_config := self._choose_diffusers_vae(model_name):
-            vae = self._load_vae(vae_config)
+        try:
+            if self.list_models()[self.current_model]['status'] == 'active':
+                self.offload_model(self.current_model)
+        except Exception as e:
+            pass
+
+        vae_path = None
+        if vae:
+            vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae))
         if self._has_cuda():
             torch.cuda.empty_cache()
         pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
             checkpoint_path=weights,
             original_config_file=config,
-            vae=vae,
+            vae_path=vae_path,
             return_generator_pipeline=True,
             precision=torch.float16 if self.precision == "float16" else torch.float32,
         )
@@ -512,6 +515,7 @@ def offload_model(self, model_name: str) -> None:
         print(f">> Offloading {model_name} to CPU")
         model = self.models[model_name]["model"]
         model.offload_all()
+        self.current_model = None
 
         gc.collect()
         if self._has_cuda():
@@ -795,15 +799,16 @@ def heuristic_import(
         return model_name
 
     def convert_and_import(
-        self,
-        ckpt_path: Path,
-        diffusers_path: Path,
-        model_name=None,
-        model_description=None,
-        vae=None,
-        original_config_file: Path = None,
-        commit_to_conf: Path = None,
-        scan_needed: bool=True,
+        self,
+        ckpt_path: Path,
+        diffusers_path: Path,
+        model_name=None,
+        model_description=None,
+        vae:dict=None,
+        vae_path:Path=None,
+        original_config_file: Path = None,
+        commit_to_conf: Path = None,
+        scan_needed: bool=True,
     ) -> str:
         """
        Convert a legacy ckpt weights file to diffuser model and import
@@ -831,13 +836,17 @@ def convert_and_import(
         try:
             # By passing the specified VAE to the conversion function, the autoencoder
             # will be built into the model rather than tacked on afterward via the config file
-            vae_model = self._load_vae(vae) if vae else None
+            vae_model=None
+            if vae:
+                vae_model=self._load_vae(vae)
+                vae_path=None
             convert_ckpt_to_diffusers(
                 ckpt_path,
                 diffusers_path,
                 extract_ema=True,
                 original_config_file=original_config_file,
                 vae=vae_model,
+                vae_path=vae_path,
                 scan_needed=scan_needed,
             )
             print(
@@ -884,36 +893,6 @@ def search_models(self, search_folder):
 
         return search_folder, found_models
 
-    def _choose_diffusers_vae(
-        self, model_name: str, vae: str = None
-    ) -> Union[dict, str]:
-        # In the event that the original entry is using a custom ckpt VAE, we try to
-        # map that VAE onto a diffuser VAE using a hard-coded dictionary.
-        # I would prefer to do this differently: We load the ckpt model into memory, swap the
-        # VAE in memory, and then pass that to convert_ckpt_to_diffuser() so that the swapped
-        # VAE is built into the model. However, when I tried this I got obscure key errors.
-        if vae:
-            return vae
-        if model_name in self.config and (
-            vae_ckpt_path := self.model_info(model_name).get("vae", None)
-        ):
-            vae_basename = Path(vae_ckpt_path).stem
-            diffusers_vae = None
-            if diffusers_vae := VAE_TO_REPO_ID.get(vae_basename, None):
-                print(
-                    f">> {vae_basename} VAE corresponds to known {diffusers_vae} diffusers version"
-                )
-                vae = {"repo_id": diffusers_vae}
-            else:
-                print(
-                    f'** Custom VAE "{vae_basename}" found, but corresponding diffusers model unknown'
-                )
-                print(
-                    '** Using "stabilityai/sd-vae-ft-mse"; If this isn\'t right, please edit the model config'
-                )
-                vae = {"repo_id": "stabilityai/sd-vae-ft-mse"}
-        return vae
-
     def _make_cache_room(self) -> None:
         num_loaded_models = len(self.models)
         if num_loaded_models >= self.max_loaded_models:
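A minimal sketch of the updated convert_and_import call, assuming an existing ModelManager instance named manager; the paths and model name are hypothetical. When both vae and vae_path are supplied, the diffusers vae dict takes precedence and vae_path is cleared, as shown in the diff:

from pathlib import Path

manager.convert_and_import(
    ckpt_path=Path("models/ldm/stable-diffusion-v1/my-model.ckpt"),          # hypothetical path
    diffusers_path=Path("models/converted/my-model"),                        # hypothetical path
    model_name="my-model",
    model_description="Legacy checkpoint converted with its custom VAE",
    vae_path=Path("models/ldm/stable-diffusion-v1/my-custom-vae.ckpt"),      # hypothetical path
    original_config_file=Path("configs/stable-diffusion/v1-inference.yaml"),
    commit_to_conf=Path("configs/models.yaml"),                              # hypothetical path
)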

invokeai/frontend/CLI/CLI.py

Lines changed: 2 additions & 8 deletions
@@ -772,16 +772,10 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
         original_config_file = Path(model_info["config"])
         model_name = model_name_or_path
         model_description = model_info["description"]
-        vae = model_info["vae"]
+        vae_path = model_info.get("vae")
     else:
         print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
         return
-    if vae_repo := invokeai.backend.model_management.model_manager.VAE_TO_REPO_ID.get(
-        Path(vae).stem
-    ):
-        vae_repo = dict(repo_id=vae_repo)
-    else:
-        vae_repo = None
     model_name = manager.convert_and_import(
         ckpt_path,
         diffusers_path=Path(
@@ -790,7 +784,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
         model_name=model_name,
         model_description=model_description,
         original_config_file=original_config_file,
-        vae=vae_repo,
+        vae_path=vae_path,
     )
 else:
     try:
