Skip to content

Commit 46c52f9

Browse files
[Pipelines] Make sure that None functions are correctly not saved (#3080)
1 parent d06e069 commit 46c52f9

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import inspect
2020
import os
2121
import re
22+
import sys
2223
import warnings
2324
from dataclasses import dataclass
2425
from pathlib import Path
@@ -540,11 +541,9 @@ def save_pretrained(
540541
variant (`str`, *optional*):
541542
If specified, weights are saved in the format pytorch_model.<variant>.bin.
542543
"""
543-
self.save_config(save_directory)
544-
545544
model_index_dict = dict(self.config)
546-
model_index_dict.pop("_class_name")
547-
model_index_dict.pop("_diffusers_version")
545+
model_index_dict.pop("_class_name", None)
546+
model_index_dict.pop("_diffusers_version", None)
548547
model_index_dict.pop("_module", None)
549548

550549
expected_modules, optional_kwargs = self._get_signature_keys(self)
@@ -557,7 +556,6 @@ def is_saveable_module(name, value):
557556
return True
558557

559558
model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
560-
561559
for pipeline_component_name in model_index_dict.keys():
562560
sub_model = getattr(self, pipeline_component_name)
563561
model_cls = sub_model.__class__
@@ -571,7 +569,13 @@ def is_saveable_module(name, value):
571569
save_method_name = None
572570
# search for the model's base class in LOADABLE_CLASSES
573571
for library_name, library_classes in LOADABLE_CLASSES.items():
574-
library = importlib.import_module(library_name)
572+
if library_name in sys.modules:
573+
library = importlib.import_module(library_name)
574+
else:
575+
logger.info(
576+
f"{library_name} is not installed. Cannot save {pipeline_component_name} as {library_classes} from {library_name}"
577+
)
578+
575579
for base_class, save_load_methods in library_classes.items():
576580
class_candidate = getattr(library, base_class, None)
577581
if class_candidate is not None and issubclass(model_cls, class_candidate):
@@ -581,6 +585,12 @@ def is_saveable_module(name, value):
581585
if save_method_name is not None:
582586
break
583587

588+
if save_method_name is None:
589+
logger.warn(f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved.")
590+
# make sure that unsaveable components are not tried to be loaded afterward
591+
self.register_to_config(**{pipeline_component_name: (None, None)})
592+
continue
593+
584594
save_method = getattr(sub_model, save_method_name)
585595

586596
# Call the save method with the argument safe_serialization only if it's supported
@@ -596,6 +606,9 @@ def is_saveable_module(name, value):
596606

597607
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
598608

609+
# finally save the config
610+
self.save_config(save_directory)
611+
599612
def to(
600613
self,
601614
torch_device: Optional[Union[str, torch.device]] = None,

0 commit comments

Comments
 (0)