Skip to content

[Pipelines] Make sure that None functions are correctly not saved #3080

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 12, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import inspect
import os
import re
import sys
import warnings
from dataclasses import dataclass
from pathlib import Path
Expand Down Expand Up @@ -540,11 +541,9 @@ def save_pretrained(
variant (`str`, *optional*):
If specified, weights are saved in the format pytorch_model.<variant>.bin.
"""
self.save_config(save_directory)

model_index_dict = dict(self.config)
model_index_dict.pop("_class_name")
model_index_dict.pop("_diffusers_version")
model_index_dict.pop("_class_name", None)
model_index_dict.pop("_diffusers_version", None)
model_index_dict.pop("_module", None)

expected_modules, optional_kwargs = self._get_signature_keys(self)
Expand All @@ -557,7 +556,6 @@ def is_saveable_module(name, value):
return True

model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}

for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
model_cls = sub_model.__class__
Expand All @@ -571,7 +569,13 @@ def is_saveable_module(name, value):
save_method_name = None
# search for the model's base class in LOADABLE_CLASSES
for library_name, library_classes in LOADABLE_CLASSES.items():
library = importlib.import_module(library_name)
if library_name in sys.modules:
library = importlib.import_module(library_name)
else:
logger.info(
f"{library_name} is not installed. Cannot save {pipeline_component_name} as {library_classes} from {library_name}"
)

for base_class, save_load_methods in library_classes.items():
class_candidate = getattr(library, base_class, None)
if class_candidate is not None and issubclass(model_cls, class_candidate):
Expand All @@ -581,6 +585,12 @@ def is_saveable_module(name, value):
if save_method_name is not None:
break

if save_method_name is None:
logger.warn(f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved.")
# make sure that unsaveable components are not tried to be loaded afterward
self.register_to_config(**{pipeline_component_name: (None, None)})
continue

save_method = getattr(sub_model, save_method_name)

# Call the save method with the argument safe_serialization only if it's supported
Expand All @@ -596,6 +606,9 @@ def is_saveable_module(name, value):

save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)

# finally save the config
self.save_config(save_directory)

def to(
self,
torch_device: Optional[Union[str, torch.device]] = None,
Expand Down