@@ -68,7 +68,6 @@ class SDModelComponent(Enum):
     feature_extractor = "feature_extractor"
 
 DEFAULT_MAX_MODELS = 2
-config = get_invokeai_config()
 
 class ModelManager(object):
     """
@@ -99,6 +98,7 @@ def __init__(
         if not isinstance(config, DictConfig):
             config = OmegaConf.load(config)
         self.config = config
+        self.globals = get_invokeai_config()
         self.precision = precision
         self.device = torch.device(device_type)
         self.max_loaded_models = max_loaded_models
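Note: the constructor now resolves the application-wide settings object once, via get_invokeai_config(), and caches it as self.globals; the hunks below switch every method from the deleted module-level `config` name to that attribute. A minimal sketch of the pattern, with a stand-in config class (the fields shown are assumptions for illustration, not InvokeAI's actual schema):

    from pathlib import Path

    class AppConfig:
        """Stand-in for the object returned by get_invokeai_config()."""
        root_dir = Path.home() / "invokeai"
        cache_dir = root_dir / "models" / "hub"
        internet_available = True

    _app_config = AppConfig()

    def get_invokeai_config() -> AppConfig:
        return _app_config  # one shared instance per process

    class ModelManager:
        def __init__(self) -> None:
            # Resolve once at construction; methods read self.globals thereafter,
            # avoiding a hidden dependency on module import order.
            self.globals = get_invokeai_config()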
@@ -291,7 +291,7 @@ def is_legacy(self, model_name: str) -> bool:
         """
         # if we are converting legacy files automatically, then
         # there are no legacy ckpts!
-        if config.ckpt_convert:
+        if self.globals.ckpt_convert:
             return False
         info = self.model_info(model_name)
         if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")):
@@ -501,13 +501,13 @@ def _load_diffusers_model(self, mconfig):
 
         # TODO: scan weights maybe?
         pipeline_args: dict[str, Any] = dict(
-            safety_checker=None, local_files_only=not config.internet_available
+            safety_checker=None, local_files_only=not self.globals.internet_available
         )
         if "vae" in mconfig and mconfig["vae"] is not None:
             if vae := self._load_vae(mconfig["vae"]):
                 pipeline_args.update(vae=vae)
         if not isinstance(name_or_path, Path):
-            pipeline_args.update(cache_dir=config.cache_dir)
+            pipeline_args.update(cache_dir=self.globals.cache_dir)
         if using_fp16:
             pipeline_args.update(torch_dtype=torch.float16)
             fp_args_list = [{"revision": "fp16"}, {}]
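The fp_args_list pair sets up a fallback: try the repo's "fp16" revision first, then retry with the default weights if that revision is absent. A sketch of the retry loop this presumably feeds (StableDiffusionPipeline stands in for InvokeAI's own pipeline class):

    from diffusers import StableDiffusionPipeline

    pipeline = None
    for fp_args in fp_args_list:
        try:
            pipeline = StableDiffusionPipeline.from_pretrained(
                name_or_path, **pipeline_args, **fp_args
            )
        except OSError:
            continue  # no fp16 revision published for this repo; fall back
        break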
@@ -559,10 +559,9 @@ def _load_ckpt_model(self, model_name, mconfig):
         width = mconfig.width
         height = mconfig.height
 
-        if not os.path.isabs(config):
-            config = os.path.join(config.root, config)
-        if not os.path.isabs(weights):
-            weights = os.path.normpath(os.path.join(config.root, weights))
+        root_dir = self.globals.root_dir
+        config = str(root_dir / config)
+        weights = str(root_dir / weights)
 
         # Convert to diffusers and return a diffusers pipeline
         self.logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
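The os.path.isabs checks on the removed lines are no longer needed because pathlib's / operator already does the right thing: joining an absolute right-hand path onto root_dir simply yields that absolute path. A quick POSIX illustration (paths are made up):

    from pathlib import Path

    root_dir = Path("/opt/invokeai")
    print(root_dir / "configs/v1-inference.yaml")  # /opt/invokeai/configs/v1-inference.yaml
    print(root_dir / "/models/custom.ckpt")        # /models/custom.ckpt (absolute input wins)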
@@ -577,11 +576,7 @@ def _load_ckpt_model(self, model_name, mconfig):
 
         vae_path = None
         if vae:
-            vae_path = (
-                vae
-                if os.path.isabs(vae)
-                else os.path.normpath(os.path.join(config.root, vae))
-            )
+            vae_path = str(root_dir / vae)
         if self._has_cuda():
             torch.cuda.empty_cache()
         pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
@@ -613,9 +608,7 @@ def model_name_or_path(self, model_name: Union[str, DictConfig]) -> str | Path:
         )
 
         if "path" in mconfig and mconfig["path"] is not None:
-            path = Path(mconfig["path"])
-            if not path.is_absolute():
-                path = Path(config.root, path).resolve()
+            path = self.globals.root_dir / Path(mconfig["path"])
             return path
         elif "repo_id" in mconfig:
             return mconfig["repo_id"]
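For context, this method distinguishes local "path" entries from Hugging Face "repo_id" entries in the models config. A hypothetical OmegaConf stanza exercising both branches (model names and values are illustrative, not a real models.yaml excerpt):

    from omegaconf import OmegaConf

    mconfig = OmegaConf.create({
        # relative path: joined onto self.globals.root_dir by the branch above
        "my-local-model": {"format": "diffusers", "path": "models/converted_ckpts/my-model"},
        # hub entry: returned as-is for diffusers to download
        "stable-diffusion-1.5": {"format": "diffusers", "repo_id": "runwayml/stable-diffusion-v1-5"},
    })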
@@ -863,16 +856,16 @@ def heuristic_import(
             model_type = self.probe_model_type(checkpoint)
             if model_type == SDLegacyType.V1:
                 self.logger.debug("SD-v1 model detected")
-                model_config_file = config.legacy_conf_path / "v1-inference.yaml"
+                model_config_file = self.globals.legacy_conf_path / "v1-inference.yaml"
             elif model_type == SDLegacyType.V1_INPAINT:
                 self.logger.debug("SD-v1 inpainting model detected")
-                model_config_file = config.legacy_conf_path / "v1-inpainting-inference.yaml",
+                model_config_file = self.globals.legacy_conf_path / "v1-inpainting-inference.yaml"
             elif model_type == SDLegacyType.V2_v:
                 self.logger.debug("SD-v2-v model detected")
-                model_config_file = config.legacy_conf_path / "v2-inference-v.yaml"
+                model_config_file = self.globals.legacy_conf_path / "v2-inference-v.yaml"
             elif model_type == SDLegacyType.V2_e:
                 self.logger.debug("SD-v2-e model detected")
-                model_config_file = config.legacy_conf_path / "v2-inference.yaml"
+                model_config_file = self.globals.legacy_conf_path / "v2-inference.yaml"
             elif model_type == SDLegacyType.V2:
                 self.logger.warning(
                     f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
@@ -899,7 +892,7 @@ def heuristic_import(
             self.logger.debug(f"Using VAE file {vae_path.name}")
         vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
 
-        diffuser_path = config.root / "models/converted_ckpts" / model_path.stem
+        diffuser_path = self.globals.root_dir / "models/converted_ckpts" / model_path.stem
         model_name = self.convert_and_import(
             model_path,
             diffusers_path=diffuser_path,
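The converted pipeline lands under <root_dir>/models/converted_ckpts/, named after the checkpoint's stem. A small worked example of the path arithmetic (paths are illustrative):

    from pathlib import Path

    root_dir = Path("/opt/invokeai")
    model_path = Path("/downloads/analog-diffusion-1.0.safetensors")
    diffuser_path = root_dir / "models/converted_ckpts" / model_path.stem
    print(diffuser_path)  # /opt/invokeai/models/converted_ckpts/analog-diffusion-1.0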
@@ -1032,7 +1025,7 @@ def commit(self, config_file_path: str) -> None:
         """
         yaml_str = OmegaConf.to_yaml(self.config)
         if not os.path.isabs(config_file_path):
-            config_file_path = config.model_conf_path
+            config_file_path = self.globals.model_conf_path
         tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
         with open(tmpfile, "w", encoding="utf-8") as outfile:
             outfile.write(self.preamble())
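The write-to-a-temp-file step guards against truncating the models config if the process dies mid-write; the finished file is swapped into place afterwards. The usual idiom, sketched (the helper name is hypothetical):

    import os

    def atomic_write(path: str, text: str) -> None:
        tmp = os.path.join(os.path.dirname(path), "new_config.tmp")
        with open(tmp, "w", encoding="utf-8") as f:
            f.write(text)
        os.replace(tmp, path)  # atomic rename on both POSIX and Windows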
@@ -1064,7 +1057,8 @@ def migrate_models(cls):
         """
         # Three transformer models to check: bert, clip and safety checker, and
         # the diffusers as well
-        models_dir = config.root / "models"
+        config = get_invokeai_config()
+        models_dir = config.root_dir / "models"
         legacy_locations = [
             Path(
                 models_dir,
@@ -1138,13 +1132,12 @@ def _resolve_path(
         if str(source).startswith(("http:", "https:", "ftp:")):
             dest_directory = Path(dest_directory)
             if not dest_directory.is_absolute():
-                dest_directory = config.root / dest_directory
+                dest_directory = self.globals.root_dir / dest_directory
             dest_directory.mkdir(parents=True, exist_ok=True)
             resolved_path = download_with_resume(str(source), dest_directory)
         else:
-            if not os.path.isabs(source):
-                source = config.root / source
-            resolved_path = Path(source)
+            source = self.globals.root_dir / source
+            resolved_path = source
         return resolved_path
 
     def _invalidate_cached_model(self, model_name: str) -> None:
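After this change the local branch of _resolve_path relies on the same pathlib join semantics shown earlier, so absolute sources pass through untouched while relative ones resolve against the runtime root. Hypothetical call sites (arguments are illustrative; download_with_resume is the project's own helper):

    mgr._resolve_path("models/ldm/sd-v1-5.ckpt", "models/downloads")
    # -> <root_dir>/models/ldm/sd-v1-5.ckpt

    mgr._resolve_path("https://example.com/sd-v1-5.ckpt", "models/downloads")
    # -> downloads with resume support into <root_dir>/models/downloads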
@@ -1194,7 +1187,7 @@ def _diffuser_sha256(
             path = name_or_path
         else:
             owner, repo = name_or_path.split("/")
-            path = Path(config.cache_dir / f"models--{owner}--{repo}")
+            path = self.globals.cache_dir / f"models--{owner}--{repo}"
         if not path.exists():
             return None
         hashpath = path / "checksum.sha256"
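The models--{owner}--{repo} folder name matches the Hugging Face hub cache layout, where the slash in a repo id becomes a "--" separator. A small helper showing the mapping (hub_cache_folder is hypothetical; the layout itself is huggingface_hub's):

    from pathlib import Path

    def hub_cache_folder(cache_dir: Path, repo_id: str) -> Path:
        owner, repo = repo_id.split("/")
        return cache_dir / f"models--{owner}--{repo}"

    print(hub_cache_folder(Path("/opt/invokeai/models/hub"), "stabilityai/sd-vae-ft-mse"))
    # /opt/invokeai/models/hub/models--stabilityai--sd-vae-ft-mse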
@@ -1255,8 +1248,8 @@ def _load_vae(self, vae_config) -> AutoencoderKL:
         using_fp16 = self.precision == "float16"
 
         vae_args.update(
-            cache_dir=config.cache_dir,
-            local_files_only=not config.internet_available,
+            cache_dir=self.globals.cache_dir,
+            local_files_only=not self.globals.internet_available,
         )
 
         self.logger.debug(f"Loading diffusers VAE from {name_or_path}")
@@ -1294,7 +1287,7 @@ def _load_vae(self, vae_config) -> AutoencoderKL:
 
     @classmethod
     def _delete_model_from_cache(cls, repo_id):
-        cache_info = scan_cache_dir(config.cache_dir)
+        cache_info = scan_cache_dir(get_invokeai_config().cache_dir)
 
         # I'm sure there is a way to do this with comprehensions
         # but the code quickly became incomprehensible!
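For reference, huggingface_hub's cache-scanning API gives the deletion loop a straightforward shape: collect the commit hashes of every revision of the matching repo, then delete them with one strategy object. A sketch under that assumption (function name is hypothetical):

    from huggingface_hub import scan_cache_dir

    def delete_cached_repo(cache_dir, repo_id: str) -> None:
        cache_info = scan_cache_dir(cache_dir)
        hashes = [
            rev.commit_hash
            for repo in cache_info.repos
            if repo.repo_id == repo_id
            for rev in repo.revisions
        ]
        if hashes:
            # delete_revisions returns a strategy; execute() frees the disk space
            cache_info.delete_revisions(*hashes).execute()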
@@ -1311,9 +1304,10 @@ def _delete_model_from_cache(cls, repo_id):
 
     @staticmethod
     def _abs_path(path: str | Path) -> Path:
+        globals = get_invokeai_config()
         if path is None or Path(path).is_absolute():
             return path
-        return Path(config.root, path).resolve()
+        return Path(globals.root_dir, path).resolve()
 
     @staticmethod
     def _is_huggingface_hub_directory_present() -> bool: