diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 2e20c21aaf38..72c4363da3c6 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -19,6 +19,7 @@ import inspect import os import re +import sys import warnings from dataclasses import dataclass from pathlib import Path @@ -540,11 +541,9 @@ def save_pretrained( variant (`str`, *optional*): If specified, weights are saved in the format pytorch_model..bin. """ - self.save_config(save_directory) - model_index_dict = dict(self.config) - model_index_dict.pop("_class_name") - model_index_dict.pop("_diffusers_version") + model_index_dict.pop("_class_name", None) + model_index_dict.pop("_diffusers_version", None) model_index_dict.pop("_module", None) expected_modules, optional_kwargs = self._get_signature_keys(self) @@ -557,7 +556,6 @@ def is_saveable_module(name, value): return True model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)} - for pipeline_component_name in model_index_dict.keys(): sub_model = getattr(self, pipeline_component_name) model_cls = sub_model.__class__ @@ -571,7 +569,13 @@ def is_saveable_module(name, value): save_method_name = None # search for the model's base class in LOADABLE_CLASSES for library_name, library_classes in LOADABLE_CLASSES.items(): - library = importlib.import_module(library_name) + if library_name in sys.modules: + library = importlib.import_module(library_name) + else: + logger.info( + f"{library_name} is not installed. Cannot save {pipeline_component_name} as {library_classes} from {library_name}" + ) + for base_class, save_load_methods in library_classes.items(): class_candidate = getattr(library, base_class, None) if class_candidate is not None and issubclass(model_cls, class_candidate): @@ -581,6 +585,12 @@ def is_saveable_module(name, value): if save_method_name is not None: break + if save_method_name is None: + logger.warn(f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved.") + # make sure that unsaveable components are not tried to be loaded afterward + self.register_to_config(**{pipeline_component_name: (None, None)}) + continue + save_method = getattr(sub_model, save_method_name) # Call the save method with the argument safe_serialization only if it's supported @@ -596,6 +606,9 @@ def is_saveable_module(name, value): save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs) + # finally save the config + self.save_config(save_directory) + def to( self, torch_device: Optional[Union[str, torch.device]] = None,