19
19
import inspect
20
20
import os
21
21
import re
22
+ import sys
22
23
import warnings
23
24
from dataclasses import dataclass
24
25
from pathlib import Path
@@ -540,11 +541,9 @@ def save_pretrained(
540
541
variant (`str`, *optional*):
541
542
If specified, weights are saved in the format pytorch_model.<variant>.bin.
542
543
"""
543
- self .save_config (save_directory )
544
-
545
544
model_index_dict = dict (self .config )
546
- model_index_dict .pop ("_class_name" )
547
- model_index_dict .pop ("_diffusers_version" )
545
+ model_index_dict .pop ("_class_name" , None )
546
+ model_index_dict .pop ("_diffusers_version" , None )
548
547
model_index_dict .pop ("_module" , None )
549
548
550
549
expected_modules , optional_kwargs = self ._get_signature_keys (self )
@@ -557,7 +556,6 @@ def is_saveable_module(name, value):
557
556
return True
558
557
559
558
model_index_dict = {k : v for k , v in model_index_dict .items () if is_saveable_module (k , v )}
560
-
561
559
for pipeline_component_name in model_index_dict .keys ():
562
560
sub_model = getattr (self , pipeline_component_name )
563
561
model_cls = sub_model .__class__
@@ -571,7 +569,13 @@ def is_saveable_module(name, value):
571
569
save_method_name = None
572
570
# search for the model's base class in LOADABLE_CLASSES
573
571
for library_name , library_classes in LOADABLE_CLASSES .items ():
574
- library = importlib .import_module (library_name )
572
+ if library_name in sys .modules :
573
+ library = importlib .import_module (library_name )
574
+ else :
575
+ logger .info (
576
+ f"{ library_name } is not installed. Cannot save { pipeline_component_name } as { library_classes } from { library_name } "
577
+ )
578
+
575
579
for base_class , save_load_methods in library_classes .items ():
576
580
class_candidate = getattr (library , base_class , None )
577
581
if class_candidate is not None and issubclass (model_cls , class_candidate ):
@@ -581,6 +585,12 @@ def is_saveable_module(name, value):
581
585
if save_method_name is not None :
582
586
break
583
587
588
+ if save_method_name is None :
589
+ logger .warn (f"self.{ pipeline_component_name } ={ sub_model } of type { type (sub_model )} cannot be saved." )
590
+ # make sure that unsaveable components are not tried to be loaded afterward
591
+ self .register_to_config (** {pipeline_component_name : (None , None )})
592
+ continue
593
+
584
594
save_method = getattr (sub_model , save_method_name )
585
595
586
596
# Call the save method with the argument safe_serialization only if it's supported
@@ -596,6 +606,9 @@ def is_saveable_module(name, value):
596
606
597
607
save_method (os .path .join (save_directory , pipeline_component_name ), ** save_kwargs )
598
608
609
+ # finally save the config
610
+ self .save_config (save_directory )
611
+
599
612
def to (
600
613
self ,
601
614
torch_device : Optional [Union [str , torch .device ]] = None ,
0 commit comments