From 3c84784ac7ee7d7e83051b1e1775206d4a7c2f9e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 19 Oct 2022 12:23:10 +0200 Subject: [PATCH 1/5] finish --- src/diffusers/_ | 572 ++++++++++++++++++ src/diffusers/__init__.py | 1 + src/diffusers/pipeline_utils.py | 20 + src/diffusers/pipelines/__init__.py | 1 + .../pipelines/stable_diffusion/__init__.py | 1 + .../pipeline_stable_diffusion_inpaint.py | 4 +- ...ipeline_stable_diffusion_inpaint_legacy.py | 407 +++++++++++++ 7 files changed, 1004 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/_ create mode 100644 src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py diff --git a/src/diffusers/_ b/src/diffusers/_ new file mode 100644 index 000000000000..a37b2084929b --- /dev/null +++ b/src/diffusers/_ @@ -0,0 +1,572 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import os +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import torch + +import diffusers +import PIL +from huggingface_hub import snapshot_download +from packaging import version +from PIL import Image +from tqdm.auto import tqdm + +from . import __version__ +from .configuration_utils import ConfigMixin +from .dynamic_modules_utils import get_class_from_dynamic_module +from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME +from .utils import ( + CONFIG_NAME, + DIFFUSERS_CACHE, + ONNX_WEIGHTS_NAME, + WEIGHTS_NAME, + BaseOutput, + is_transformers_available, + logging, +) + + +if is_transformers_available(): + import transformers + from transformers import PreTrainedModel + + +INDEX_FILE = "diffusion_pytorch_model.bin" +CUSTOM_PIPELINE_FILE_NAME = "pipeline.py" +DUMMY_MODULES_FOLDER = "diffusers.utils" + + +logger = logging.get_logger(__name__) + + +LOADABLE_CLASSES = { + "diffusers": { + "ModelMixin": ["save_pretrained", "from_pretrained"], + "SchedulerMixin": ["save_config", "from_config"], + "DiffusionPipeline": ["save_pretrained", "from_pretrained"], + "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], + }, + "transformers": { + "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], + "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], + "PreTrainedModel": ["save_pretrained", "from_pretrained"], + "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], + }, +} + +ALL_IMPORTABLE_CLASSES = {} +for library in LOADABLE_CLASSES: + ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) + + +@dataclass +class ImagePipelineOutput(BaseOutput): + """ + Output class for image pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 
+ """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class DiffusionPipeline(ConfigMixin): + r""" + Base class for all models. + + [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines + and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to: + + - move all PyTorch modules to the device of your choice + - enabling/disabling the progress bar for the denoising iteration + + Class attributes: + + - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all + components of the diffusion pipeline. + """ + config_name = "model_index.json" + + def register_modules(self, **kwargs): + # import it here to avoid circular import + from diffusers import pipelines + + for name, module in kwargs.items(): + # retrieve library + if module is None: + register_dict = {name: (None, None)} + else: + library = module.__module__.split(".")[0] + + # check if the module is a pipeline module + pipeline_dir = module.__module__.split(".")[-2] + path = module.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) + + # if library is not in LOADABLE_CLASSES, then it is a custom module. + # Or if it's a pipeline module, then the module is inside the pipeline + # folder so we set the library to module name. + if library not in LOADABLE_CLASSES or is_pipeline_module: + library = pipeline_dir + + # retrieve class_name + class_name = module.__class__.__name__ + + register_dict = {name: (library, class_name)} + + # save model index config + self.register_to_config(**register_dict) + + # set models + setattr(self, name, module) + + def save_pretrained(self, save_directory: Union[str, os.PathLike]): + """ + Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to + a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading + method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. 
+ """ + self.save_config(save_directory) + + model_index_dict = dict(self.config) + model_index_dict.pop("_class_name") + model_index_dict.pop("_diffusers_version") + model_index_dict.pop("_module", None) + + for pipeline_component_name in model_index_dict.keys(): + sub_model = getattr(self, pipeline_component_name) + model_cls = sub_model.__class__ + + save_method_name = None + # search for the model's base class in LOADABLE_CLASSES + for library_name, library_classes in LOADABLE_CLASSES.items(): + library = importlib.import_module(library_name) + for base_class, save_load_methods in library_classes.items(): + class_candidate = getattr(library, base_class) + if issubclass(model_cls, class_candidate): + # if we found a suitable base class in LOADABLE_CLASSES then grab its save method + save_method_name = save_load_methods[0] + break + if save_method_name is not None: + break + + save_method = getattr(sub_model, save_method_name) + save_method(os.path.join(save_directory, pipeline_component_name)) + + def to(self, torch_device: Optional[Union[str, torch.device]] = None): + if torch_device is None: + return self + + module_names, _ = self.extract_init_dict(dict(self.config)) + for name in module_names.keys(): + module = getattr(self, name) + if isinstance(module, torch.nn.Module): + if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]: + logger.warning( + "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` or `mps` device. It" + " is not recommended to move them to `cpu` or `mps` as running them will fail. Please make" + " sure to use a `cuda` device to run the pipeline in inference. due to the lack of support for" + " `float16` operations on those devices in PyTorch. Please remove the" + " `torch_dtype=torch.float16` argument, or use a `cuda` device to run inference." + ) + module.to(torch_device) + return self + + @property + def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ + module_names, _ = self.extract_init_dict(dict(self.config)) + for name in module_names.keys(): + module = getattr(self, name) + if isinstance(module, torch.nn.Module): + return module.device + return torch.device("cpu") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights. + + The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on + https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like + `CompVis/ldm-text2im-large-256`. + - A path to a *directory* containing pipeline weights saved using + [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. 
If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + custom_pipeline (`str`, *optional*): + + + + This is an experimental feature and is likely to change in the future. + + + + Can be either: + + - A string, the *repo id* of a custom pipeline hosted inside a model repo on + https://huggingface.co/. Valid repo ids have to be located under a user or organization name, + like `hf-internal-testing/diffusers-dummy-pipeline`. + + + + It is required that the model repo has a file, called `pipeline.py` that defines the custom + pipeline. + + + + - A string, the *file name* of a community pipeline hosted on GitHub under + https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to + match exactly the file name without `.py` located under the above link, *e.g.* + `clip_guided_stable_diffusion`. + + + + Community pipelines are always loaded from the current `main` branch of GitHub. + + + + - A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`. + + + + It is required that the directory has a file, called `pipeline.py` that defines the custom + pipeline. + + + + For more information on how to load and create custom pipelines, please have a look at [Loading and + Creating Custom + Pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/custom_pipelines) + + torch_dtype (`str` or `torch.dtype`, *optional*): + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. specify the folder name here. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the + specific pipeline class. The overwritten components are then directly passed to the pipelines + `__init__` method. See example below for more information. 
+ + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"CompVis/stable-diffusion-v1-4"` + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. + + + + Examples: + + ```py + >>> from diffusers import DiffusionPipeline + + >>> # Download pipeline from huggingface.co and cache. + >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") + + >>> # Download pipeline that requires an authorization token + >>> # For more information on access tokens, please refer to this section + >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) + >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") + + >>> # Download pipeline, but overwrite scheduler + >>> from diffusers import LMSDiscreteScheduler + + >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler) + ``` + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + custom_pipeline = kwargs.pop("custom_pipeline", None) + provider = kwargs.pop("provider", None) + sess_options = kwargs.pop("sess_options", None) + device_map = kwargs.pop("device_map", None) + + # 1. Download the checkpoints and configs + # use snapshot download here to get it working from from_pretrained + if not os.path.isdir(pretrained_model_name_or_path): + config_dict = cls.get_config_dict( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + # make sure we only download sub-folders and `diffusers` filenames + folder_names = [k for k in config_dict.keys() if not k.startswith("_")] + allow_patterns = [os.path.join(k, "*") for k in folder_names] + allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name] + + if custom_pipeline is not None: + allow_patterns += [CUSTOM_PIPELINE_FILE_NAME] + + requested_pipeline_class = config_dict.get("_class_name", cls.__name__) + user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class} + if custom_pipeline is not None: + user_agent["custom_pipeline"] = custom_pipeline + + # download all allow_patterns + cached_folder = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + allow_patterns=allow_patterns, + user_agent=user_agent, + ) + else: + cached_folder = pretrained_model_name_or_path + + config_dict = cls.get_config_dict(cached_folder) + + # 2. 
Load the pipeline class, if using custom module then load it from the hub + # if we load from explicit class, let's use it + if custom_pipeline is not None: + pipeline_class = get_class_from_dynamic_module( + custom_pipeline, module_file=CUSTOM_PIPELINE_FILE_NAME, cache_dir=custom_pipeline + ) + elif cls != DiffusionPipeline: + pipeline_class = cls + else: + diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) + pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) + + # To be removed in 1.0.0 + if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and config_dict["_diffusers_version"] + + # some modules can be passed directly to the init + # in this case they are already instantiated in `kwargs` + # extract them here + expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"]) + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + + init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + + import ipdb; ipdb.set_trace() + + init_kwargs = {} + + # import it here to avoid circular import + from diffusers import pipelines + + # 3. Load each module in the pipeline + for name, (library_name, class_name) in init_dict.items(): + # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names + if class_name.startswith("Flax"): + class_name = class_name[4:] + + is_pipeline_module = hasattr(pipelines, library_name) + loaded_sub_model = None + sub_model_should_be_defined = True + + # if the model is in a pipeline module, then we load it from the pipeline + if name in passed_class_obj: + # 1. check that passed_class_obj has correct parent class + if not is_pipeline_module: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} + + expected_class_obj = None + for class_name, class_candidate in class_candidates.items(): + if issubclass(class_obj, class_candidate): + expected_class_obj = class_candidate + + if not issubclass(passed_class_obj[name].__class__, expected_class_obj): + raise ValueError( + f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" + f" {expected_class_obj}" + ) + elif passed_class_obj[name] is None: + logger.warn( + f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note" + f" that this might lead to problems when using {pipeline_class} and is not recommended." + ) + sub_model_should_be_defined = False + else: + logger.warn( + f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" + " has the correct type" + ) + + # set passed class object + loaded_sub_model = passed_class_obj[name] + elif is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = getattr(pipeline_module, class_name) + importable_classes = ALL_IMPORTABLE_CLASSES + class_candidates = {c: class_obj for c in importable_classes.keys()} + else: + # else we just import it from the library. 
+ library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} + + if loaded_sub_model is None and sub_model_should_be_defined: + load_method_name = None + for class_name, class_candidate in class_candidates.items(): + if issubclass(class_obj, class_candidate): + load_method_name = importable_classes[class_name][1] + + if load_method_name is None: + none_module = class_obj.__module__ + if none_module.startswith(DUMMY_MODULES_FOLDER) and "dummy" in none_module: + # call class_obj for nice error message of missing requirements + class_obj() + + raise ValueError( + f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have" + f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}." + ) + + load_method = getattr(class_obj, load_method_name) + loading_kwargs = {} + + if issubclass(class_obj, torch.nn.Module): + loading_kwargs["torch_dtype"] = torch_dtype + if issubclass(class_obj, diffusers.OnnxRuntimeModel): + loading_kwargs["provider"] = provider + loading_kwargs["sess_options"] = sess_options + + is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin) + is_transformers_model = ( + is_transformers_available() + and issubclass(class_obj, PreTrainedModel) + and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0") + ) + + if is_diffusers_model or is_transformers_model: + loading_kwargs["device_map"] = device_map + + # check if the module is in a subdirectory + if os.path.isdir(os.path.join(cached_folder, name)): + loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) + else: + # else load from the root directory + loaded_sub_model = load_method(cached_folder, **loading_kwargs) + + init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) + + # 4. Potentially add passed objects if expected + missing_modules = set(expected_modules) - set(init_kwargs.keys()) + if len(missing_modules) > 0 and missing_modules <= set(passed_class_obj.keys()): + for module in missing_modules: + init_kwargs[module] = passed_class_obj[module] + elif len(missing_modules) > 0: + passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) + raise ValueError( + f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." + ) + + # 5. Instantiate the pipeline + model = pipeline_class(**init_kwargs) + return model + + @staticmethod + def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + def progress_bar(self, iterable): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." 
+ ) + + return tqdm(iterable, **self._progress_bar_config) + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index b12993225759..f36d3284777e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -52,6 +52,7 @@ LDMTextToImagePipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, + StableDiffusionInpaintPipelineLegacy, StableDiffusionPipeline, ) else: diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index d3e6113dac8e..c579e252926b 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -40,6 +40,7 @@ ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, + deprecate, is_transformers_available, logging, ) @@ -413,6 +414,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) + # To be removed in 1.0.0 + if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( + version.parse(config_dict["_diffusers_version"]).base_version + ) <= version.parse("0.5.1"): + from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy + + pipeline_class = StableDiffusionInpaintPipelineLegacy + + deprecation_message = ( + "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the" + f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}.For better" + " inpainting results, we strongly suggest using Stable Diffusion's official inpainting checkpoint:" + " https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your checkpoint" + f" {pretrained_model_name_or_path} to the format of" + " https://huggingface.co/runwayml/stable-diffusion-inpainting.Note that we do not actively maintain" + " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0." 
+ ) + deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False) + # some modules can be passed directly to the init # in this case they are already instantiated in `kwargs` # extract them here diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 93eaf2dac678..3c4f877fa830 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -16,6 +16,7 @@ from .stable_diffusion import ( StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, + StableDiffusionInpaintPipelineLegacy, StableDiffusionPipeline, ) diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index cc1ff35dbad8..5a452138b7a8 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -31,6 +31,7 @@ class StableDiffusionPipelineOutput(BaseOutput): from .pipeline_stable_diffusion import StableDiffusionPipeline from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline + from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy from .safety_checker import StableDiffusionSafetyChecker if is_transformers_available() and is_onnx_available(): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index cb4d552a6f50..077a7b0748d9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -82,8 +82,6 @@ def __init__( feature_extractor: CLIPFeatureExtractor, ): super().__init__() - logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" @@ -223,6 +221,8 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # TODO(Suraj) - adapt to your use case + if isinstance(prompt, str): batch_size = 1 elif isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py new file mode 100644 index 000000000000..038178dd9a0a --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -0,0 +1,407 @@ +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import torch + +import PIL +from tqdm.auto import tqdm +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import deprecate, logging +from . 
import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + +class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None: + logger.warn( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `set_attention_slice` + self.enable_attention_slicing(None) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + init_image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. 
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more + noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The reference number of denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. This parameter will be modulated by `strength`, as explained above. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 
+ When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + + # duplicate text embeddings for each generation per prompt + text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # duplicate unconditional embeddings for each generation per prompt + uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0) + + # For classifier free guidance, we need to do two forward passes. 
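The concatenation below pairs with the guidance step applied later in the denoising loop. As a self-contained sketch with toy tensors standing in for the real U-Net output, the batched prediction is split back into its two halves and recombined as follows:

```py
>>> import torch

>>> # Toy stand-in for the U-Net output on the batched [uncond, text] input.
>>> noise_pred = torch.randn(2, 4, 64, 64)
>>> guidance_scale = 7.5

>>> # Split the batched prediction into its unconditional and text-conditioned halves,
>>> # then push the result away from the unconditional prediction (eq. (2) of the Imagen paper).
>>> noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
>>> guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
>>> guided.shape
torch.Size([1, 4, 64, 64])
```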
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # preprocess image + if not isinstance(init_image, torch.FloatTensor): + init_image = preprocess_image(init_image) + + # encode the init image into latents and scale the latents + latents_dtype = text_embeddings.dtype + init_image = init_image.to(device=self.device, dtype=latents_dtype) + init_latent_dist = self.vae.encode(init_image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents + + # Expand init_latents for batch_size and num_images_per_prompt + init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0) + init_latents_orig = init_latents + + # preprocess mask + if not isinstance(mask_image, torch.FloatTensor): + mask_image = preprocess_mask(mask_image) + mask_image = mask_image.to(device=self.device, dtype=latents_dtype) + mask = torch.cat([mask_image] * batch_size * num_images_per_prompt) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + + # add noise to latents using the timesteps + noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype) + init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
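Before the eta bookkeeping continues below, a quick worked example of the timestep window selected just above (the values are illustrative; `offset` is the scheduler's `steps_offset`, 1 after the deprecation fix earlier in this file):

```py
>>> num_inference_steps, strength, offset = 50, 0.8, 1

>>> # Timestep used to noise `init_latents`, counted from the end of `scheduler.timesteps`.
>>> init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)
>>> init_timestep
41

>>> # Denoising then resumes at this index, so roughly strength * num_inference_steps steps run.
>>> t_start = max(num_inference_steps - init_timestep + offset, 0)
>>> t_start
10
>>> num_inference_steps - t_start
40
```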
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + + t_start = max(num_inference_steps - init_timestep + offset, 0) + + # Some schedulers like PNDM have timesteps as arrays + # It's more optimized to move all timesteps to correct device beforehand + timesteps = self.scheduler.timesteps[t_start:].to(self.device) + + for i, t in tqdm(enumerate(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) + + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( + self.device + ) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) From ec18a126b68973588d6b6215f5da1e9d606b3208 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 19 Oct 2022 12:23:27 +0200 Subject: [PATCH 2/5] up --- src/diffusers/_ | 572 ------------------------------------------------ 1 file changed, 572 deletions(-) delete mode 100644 src/diffusers/_ diff --git a/src/diffusers/_ b/src/diffusers/_ deleted file mode 100644 index a37b2084929b..000000000000 --- a/src/diffusers/_ +++ /dev/null @@ -1,572 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
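For reference, a minimal end-to-end sketch of the `StableDiffusionInpaintPipelineLegacy` class added in the first commit of this patch. The checkpoint, prompt, and synthetic mask are illustrative; a CUDA device with fp16 weights is assumed, so drop `torch_dtype` and `.to("cuda")` to run on CPU.

```py
>>> import torch
>>> from PIL import Image, ImageDraw
>>> from diffusers import StableDiffusionInpaintPipelineLegacy

>>> # Synthetic inputs: a 512x512 init image and a luminance mask where WHITE marks
>>> # the region to repaint (preprocess_mask inverts it, so white -> regenerated,
>>> # black -> preserved, matching the docstring).
>>> init_image = Image.new("RGB", (512, 512), "gray")
>>> mask_image = Image.new("L", (512, 512), 0)
>>> ImageDraw.Draw(mask_image).rectangle([128, 128, 384, 384], fill=255)

>>> pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
...     "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
... ).to("cuda")
>>> result = pipe(
...     prompt="a red apple on a table",
...     init_image=init_image,
...     mask_image=mask_image,
...     strength=0.75,
...     num_inference_steps=50,
... )
>>> result.images[0].save("inpainted.png")
```

At every loop iteration the scheduler re-noises `init_latents_orig` to the current timestep and the result is blended as `init_latents_proper * mask + latents * (1 - mask)`, which is what keeps the preserved (black) regions anchored to the original image while the masked region is repainted.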
- -import importlib -import inspect -import os -from dataclasses import dataclass -from typing import List, Optional, Union - -import numpy as np -import torch - -import diffusers -import PIL -from huggingface_hub import snapshot_download -from packaging import version -from PIL import Image -from tqdm.auto import tqdm - -from . import __version__ -from .configuration_utils import ConfigMixin -from .dynamic_modules_utils import get_class_from_dynamic_module -from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME -from .utils import ( - CONFIG_NAME, - DIFFUSERS_CACHE, - ONNX_WEIGHTS_NAME, - WEIGHTS_NAME, - BaseOutput, - is_transformers_available, - logging, -) - - -if is_transformers_available(): - import transformers - from transformers import PreTrainedModel - - -INDEX_FILE = "diffusion_pytorch_model.bin" -CUSTOM_PIPELINE_FILE_NAME = "pipeline.py" -DUMMY_MODULES_FOLDER = "diffusers.utils" - - -logger = logging.get_logger(__name__) - - -LOADABLE_CLASSES = { - "diffusers": { - "ModelMixin": ["save_pretrained", "from_pretrained"], - "SchedulerMixin": ["save_config", "from_config"], - "DiffusionPipeline": ["save_pretrained", "from_pretrained"], - "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], - }, - "transformers": { - "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], - "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], - "PreTrainedModel": ["save_pretrained", "from_pretrained"], - "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], - }, -} - -ALL_IMPORTABLE_CLASSES = {} -for library in LOADABLE_CLASSES: - ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) - - -@dataclass -class ImagePipelineOutput(BaseOutput): - """ - Output class for image pipelines. - - Args: - images (`List[PIL.Image.Image]` or `np.ndarray`) - List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, - num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. - """ - - images: Union[List[PIL.Image.Image], np.ndarray] - - -class DiffusionPipeline(ConfigMixin): - r""" - Base class for all models. - - [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines - and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to: - - - move all PyTorch modules to the device of your choice - - enabling/disabling the progress bar for the denoising iteration - - Class attributes: - - - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all - components of the diffusion pipeline. - """ - config_name = "model_index.json" - - def register_modules(self, **kwargs): - # import it here to avoid circular import - from diffusers import pipelines - - for name, module in kwargs.items(): - # retrieve library - if module is None: - register_dict = {name: (None, None)} - else: - library = module.__module__.split(".")[0] - - # check if the module is a pipeline module - pipeline_dir = module.__module__.split(".")[-2] - path = module.__module__.split(".") - is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) - - # if library is not in LOADABLE_CLASSES, then it is a custom module. - # Or if it's a pipeline module, then the module is inside the pipeline - # folder so we set the library to module name. 
-                if library not in LOADABLE_CLASSES or is_pipeline_module:
-                    library = pipeline_dir
-
-                # retrieve class_name
-                class_name = module.__class__.__name__
-
-                register_dict = {name: (library, class_name)}
-
-            # save model index config
-            self.register_to_config(**register_dict)
-
-            # set models
-            setattr(self, name, module)
-
-    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
-        """
-        Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
-        a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
-        method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method.
-
-        Arguments:
-            save_directory (`str` or `os.PathLike`):
-                Directory to which to save. Will be created if it doesn't exist.
-        """
-        self.save_config(save_directory)
-
-        model_index_dict = dict(self.config)
-        model_index_dict.pop("_class_name")
-        model_index_dict.pop("_diffusers_version")
-        model_index_dict.pop("_module", None)
-
-        for pipeline_component_name in model_index_dict.keys():
-            sub_model = getattr(self, pipeline_component_name)
-            model_cls = sub_model.__class__
-
-            save_method_name = None
-            # search for the model's base class in LOADABLE_CLASSES
-            for library_name, library_classes in LOADABLE_CLASSES.items():
-                library = importlib.import_module(library_name)
-                for base_class, save_load_methods in library_classes.items():
-                    class_candidate = getattr(library, base_class)
-                    if issubclass(model_cls, class_candidate):
-                        # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
-                        save_method_name = save_load_methods[0]
-                        break
-                if save_method_name is not None:
-                    break
-
-            save_method = getattr(sub_model, save_method_name)
-            save_method(os.path.join(save_directory, pipeline_component_name))
-
-    def to(self, torch_device: Optional[Union[str, torch.device]] = None):
-        if torch_device is None:
-            return self
-
-        module_names, _ = self.extract_init_dict(dict(self.config))
-        for name in module_names.keys():
-            module = getattr(self, name)
-            if isinstance(module, torch.nn.Module):
-                if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
-                    logger.warning(
-                        "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` or `mps` device. It"
-                        " is not recommended to move them to `cpu` or `mps` as running them will fail. Please make"
-                        " sure to use a `cuda` device to run the pipeline in inference. due to the lack of support for"
-                        " `float16` operations on those devices in PyTorch. Please remove the"
-                        " `torch_dtype=torch.float16` argument, or use a `cuda` device to run inference."
-                    )
-                module.to(torch_device)
-        return self
-
-    @property
-    def device(self) -> torch.device:
-        r"""
-        Returns:
-            `torch.device`: The torch device on which the pipeline is located.
-        """
-        module_names, _ = self.extract_init_dict(dict(self.config))
-        for name in module_names.keys():
-            module = getattr(self, name)
-            if isinstance(module, torch.nn.Module):
-                return module.device
-        return torch.device("cpu")
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
-        r"""
-        Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
-
-        The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
-
-        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
-        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
-        task.
-
-        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
-        weights are discarded.
-
-        Parameters:
-            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
-                Can be either:
-
-                    - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
-                      https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
-                      `CompVis/ldm-text2im-large-256`.
-                    - A path to a *directory* containing pipeline weights saved using
-                      [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
-            torch_dtype (`str` or `torch.dtype`, *optional*):
-                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
-                will be automatically derived from the model's weights.
-            custom_pipeline (`str`, *optional*):
-
-                <Tip warning={true}>
-
-                    This is an experimental feature and is likely to change in the future.
-
-                </Tip>
-
-                Can be either:
-
-                    - A string, the *repo id* of a custom pipeline hosted inside a model repo on
-                      https://huggingface.co/. Valid repo ids have to be located under a user or organization name,
-                      like `hf-internal-testing/diffusers-dummy-pipeline`.
-
-                        <Tip>
-
-                         It is required that the model repo has a file, called `pipeline.py` that defines the custom
-                         pipeline.
-
-                        </Tip>
-
-                    - A string, the *file name* of a community pipeline hosted on GitHub under
-                      https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to
-                      match exactly the file name without `.py` located under the above link, *e.g.*
-                      `clip_guided_stable_diffusion`.
-
-                        <Tip>
-
-                         Community pipelines are always loaded from the current `main` branch of GitHub.
-
-                        </Tip>
-
-                    - A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`.
-
-                        <Tip>
-
-                         It is required that the directory has a file, called `pipeline.py` that defines the custom
-                         pipeline.
-
-                        </Tip>
-
-                For more information on how to load and create custom pipelines, please have a look at [Loading and
-                Creating Custom
-                Pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/custom_pipelines)
-
-            torch_dtype (`str` or `torch.dtype`, *optional*):
-            force_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
-                cached versions if they exist.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
-                file exists.
-            proxies (`Dict[str, str]`, *optional*):
-                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
-                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
-            output_loading_info(`bool`, *optional*, defaults to `False`):
-                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
-            local_files_only(`bool`, *optional*, defaults to `False`):
-                Whether or not to only look at local files (i.e., do not try to download the model).
-            use_auth_token (`str` or *bool*, *optional*):
-                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
-                when running `huggingface-cli login` (stored in `~/.huggingface`).
-            revision (`str`, *optional*, defaults to `"main"`):
-                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
-                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
-                identifier allowed by git.
-            mirror (`str`, *optional*):
-                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
-                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
-                Please refer to the mirror site for more information. specify the folder name here.
-
-            kwargs (remaining dictionary of keyword arguments, *optional*):
-                Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
-                specific pipeline class. The overwritten components are then directly passed to the pipelines
-                `__init__` method. See example below for more information.
-
-        <Tip>
-
-         It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
-         models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"CompVis/stable-diffusion-v1-4"`
-
-        </Tip>
-
-        <Tip>
-
-        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
-        this method in a firewalled environment.
-
-        </Tip>
-
-        Examples:
-
-        ```py
-        >>> from diffusers import DiffusionPipeline
-
-        >>> # Download pipeline from huggingface.co and cache.
-        >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
-
-        >>> # Download pipeline that requires an authorization token
-        >>> # For more information on access tokens, please refer to this section
-        >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
-        >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
-
-        >>> # Download pipeline, but overwrite scheduler
-        >>> from diffusers import LMSDiscreteScheduler
-
-        >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
-        >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler)
-        ```
-        """
-        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
-        resume_download = kwargs.pop("resume_download", False)
-        proxies = kwargs.pop("proxies", None)
-        local_files_only = kwargs.pop("local_files_only", False)
-        use_auth_token = kwargs.pop("use_auth_token", None)
-        revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
-        custom_pipeline = kwargs.pop("custom_pipeline", None)
-        provider = kwargs.pop("provider", None)
-        sess_options = kwargs.pop("sess_options", None)
-        device_map = kwargs.pop("device_map", None)
-
-        # 1. Download the checkpoints and configs
-        # use snapshot download here to get it working from from_pretrained
-        if not os.path.isdir(pretrained_model_name_or_path):
-            config_dict = cls.get_config_dict(
-                pretrained_model_name_or_path,
-                cache_dir=cache_dir,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                revision=revision,
-            )
-            # make sure we only download sub-folders and `diffusers` filenames
-            folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
-            allow_patterns = [os.path.join(k, "*") for k in folder_names]
-            allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
-
-            if custom_pipeline is not None:
-                allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
-
-            requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
-            user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class}
-            if custom_pipeline is not None:
-                user_agent["custom_pipeline"] = custom_pipeline
-
-            # download all allow_patterns
-            cached_folder = snapshot_download(
-                pretrained_model_name_or_path,
-                cache_dir=cache_dir,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                allow_patterns=allow_patterns,
-                user_agent=user_agent,
-            )
-        else:
-            cached_folder = pretrained_model_name_or_path
-
-        config_dict = cls.get_config_dict(cached_folder)
-
-        # 2. Load the pipeline class, if using custom module then load it from the hub
-        # if we load from explicit class, let's use it
-        if custom_pipeline is not None:
-            pipeline_class = get_class_from_dynamic_module(
-                custom_pipeline, module_file=CUSTOM_PIPELINE_FILE_NAME, cache_dir=custom_pipeline
-            )
-        elif cls != DiffusionPipeline:
-            pipeline_class = cls
-        else:
-            diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
-            pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
-
-        # To be removed in 1.0.0
-        if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and config_dict["_diffusers_version"]
-
-        # some modules can be passed directly to the init
-        # in this case they are already instantiated in `kwargs`
-        # extract them here
-        expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
-        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
-
-        init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
-
-        import ipdb; ipdb.set_trace()
-
-        init_kwargs = {}
-
-        # import it here to avoid circular import
-        from diffusers import pipelines
-
-        # 3. Load each module in the pipeline
-        for name, (library_name, class_name) in init_dict.items():
-            # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
-            if class_name.startswith("Flax"):
-                class_name = class_name[4:]
-
-            is_pipeline_module = hasattr(pipelines, library_name)
-            loaded_sub_model = None
-            sub_model_should_be_defined = True
-
-            # if the model is in a pipeline module, then we load it from the pipeline
-            if name in passed_class_obj:
-                # 1. check that passed_class_obj has correct parent class
-                if not is_pipeline_module:
-                    library = importlib.import_module(library_name)
-                    class_obj = getattr(library, class_name)
-                    importable_classes = LOADABLE_CLASSES[library_name]
-                    class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
-
-                    expected_class_obj = None
-                    for class_name, class_candidate in class_candidates.items():
-                        if issubclass(class_obj, class_candidate):
-                            expected_class_obj = class_candidate
-
-                    if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
-                        raise ValueError(
-                            f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
-                            f" {expected_class_obj}"
-                        )
-                elif passed_class_obj[name] is None:
-                    logger.warn(
-                        f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
-                        f" that this might lead to problems when using {pipeline_class} and is not recommended."
-                    )
-                    sub_model_should_be_defined = False
-                else:
-                    logger.warn(
-                        f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
-                        " has the correct type"
-                    )
-
-                # set passed class object
-                loaded_sub_model = passed_class_obj[name]
-            elif is_pipeline_module:
-                pipeline_module = getattr(pipelines, library_name)
-                class_obj = getattr(pipeline_module, class_name)
-                importable_classes = ALL_IMPORTABLE_CLASSES
-                class_candidates = {c: class_obj for c in importable_classes.keys()}
-            else:
-                # else we just import it from the library.
-                library = importlib.import_module(library_name)
-                class_obj = getattr(library, class_name)
-                importable_classes = LOADABLE_CLASSES[library_name]
-                class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
-
-            if loaded_sub_model is None and sub_model_should_be_defined:
-                load_method_name = None
-                for class_name, class_candidate in class_candidates.items():
-                    if issubclass(class_obj, class_candidate):
-                        load_method_name = importable_classes[class_name][1]
-
-                if load_method_name is None:
-                    none_module = class_obj.__module__
-                    if none_module.startswith(DUMMY_MODULES_FOLDER) and "dummy" in none_module:
-                        # call class_obj for nice error message of missing requirements
-                        class_obj()
-
-                    raise ValueError(
-                        f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
-                        f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
-                    )
-
-                load_method = getattr(class_obj, load_method_name)
-                loading_kwargs = {}
-
-                if issubclass(class_obj, torch.nn.Module):
-                    loading_kwargs["torch_dtype"] = torch_dtype
-                if issubclass(class_obj, diffusers.OnnxRuntimeModel):
-                    loading_kwargs["provider"] = provider
-                    loading_kwargs["sess_options"] = sess_options
-
-                is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
-                is_transformers_model = (
-                    is_transformers_available()
-                    and issubclass(class_obj, PreTrainedModel)
-                    and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
-                )
-
-                if is_diffusers_model or is_transformers_model:
-                    loading_kwargs["device_map"] = device_map
-
-                # check if the module is in a subdirectory
-                if os.path.isdir(os.path.join(cached_folder, name)):
-                    loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
-                else:
-                    # else load from the root directory
-                    loaded_sub_model = load_method(cached_folder, **loading_kwargs)
-
-            init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)
-
-        # 4. Potentially add passed objects if expected
-        missing_modules = set(expected_modules) - set(init_kwargs.keys())
-        if len(missing_modules) > 0 and missing_modules <= set(passed_class_obj.keys()):
-            for module in missing_modules:
-                init_kwargs[module] = passed_class_obj[module]
-        elif len(missing_modules) > 0:
-            passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys()))
-            raise ValueError(
-                f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
-            )
-
-        # 5. Instantiate the pipeline
-        model = pipeline_class(**init_kwargs)
-        return model
-
-    @staticmethod
-    def numpy_to_pil(images):
-        """
-        Convert a numpy image or a batch of images to a PIL image.
-        """
-        if images.ndim == 3:
-            images = images[None, ...]
-        images = (images * 255).round().astype("uint8")
-        pil_images = [Image.fromarray(image) for image in images]
-
-        return pil_images
-
-    def progress_bar(self, iterable):
-        if not hasattr(self, "_progress_bar_config"):
-            self._progress_bar_config = {}
-        elif not isinstance(self._progress_bar_config, dict):
-            raise ValueError(
-                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
-            )
-
-        return tqdm(iterable, **self._progress_bar_config)
-
-    def set_progress_bar_config(self, **kwargs):
-        self._progress_bar_config = kwargs

From 953e4e1af5d8910e5f02fc139b05c7fc03be8c34 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 19 Oct 2022 12:52:44 +0200
Subject: [PATCH 3/5] Apply suggestions from code review

Co-authored-by: Anton Lozhkov
---
 src/diffusers/pipeline_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index c579e252926b..10d31e38310d 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -428,7 +428,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 " inpainting results, we strongly suggest using Stable Diffusion's official inpainting checkpoint:"
                 " https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your checkpoint"
                 f" {pretrained_model_name_or_path} to the format of"
-                " https://huggingface.co/runwayml/stable-diffusion-inpainting.Note that we do not actively maintain"
+                " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
                 " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
             )
             deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)

From affb64d5b36ad4836bdf270c7c6936abc8ffaa35 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 19 Oct 2022 12:53:20 +0200
Subject: [PATCH 4/5] Update src/diffusers/pipeline_utils.py

---
 src/diffusers/pipeline_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index 10d31e38310d..bb07d4eabdac 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -424,7 +424,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
         deprecation_message = (
             "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
-            f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}.For better"
+            f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For better"
             " inpainting results, we strongly suggest using Stable Diffusion's official inpainting checkpoint:"
             " https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your checkpoint"
            f" {pretrained_model_name_or_path} to the format of"

From 4b82b0907755f569b537e73e20d079718b809951 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 19 Oct 2022 12:55:14 +0200
Subject: [PATCH 5/5] Finish

---
 src/diffusers/pipeline_utils.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index bb07d4eabdac..37bc228b10e6 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -424,10 +424,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
         deprecation_message = (
             "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
-            f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For better"
-            " inpainting results, we strongly suggest using Stable Diffusion's official inpainting checkpoint:"
-            " https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your checkpoint"
-            f" {pretrained_model_name_or_path} to the format of"
+            f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
+            " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
+            " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
+            f" checkpoint {pretrained_model_name_or_path} to the format of"
             " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
             " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
         )
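
The deprecation message reworded across patches 3-5 points users at the dedicated inpainting checkpoint. As a rough, illustrative sketch of the migration it suggests (the checkpoint id comes from the message itself; the prompt, image file names, and sizes are placeholders, and the `image`/`mask_image` call signature is that of the non-legacy `StableDiffusionInpaintPipeline`):

```py
>>> import torch
>>> from PIL import Image
>>> from diffusers import StableDiffusionInpaintPipeline

>>> # the checkpoint recommended by the deprecation message
>>> pipe = StableDiffusionInpaintPipeline.from_pretrained(
...     "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")

>>> # placeholder inputs: an RGB image and a same-sized mask whose white pixels mark the area to repaint
>>> init_image = Image.open("dog_on_bench.png").convert("RGB").resize((512, 512))
>>> mask_image = Image.open("dog_on_bench_mask.png").convert("L").resize((512, 512))

>>> prompt = "a white cat sitting on a park bench"
>>> result = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
>>> result.save("inpainted.png")
```

Unlike the legacy class, this checkpoint's UNet is conditioned on the mask and the masked image directly, which is the main reason the message steers users towards it.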
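A note on the loader machinery that dominates the deleted `src/diffusers/_` scratch file above: both `save_pretrained` and `from_pretrained` dispatch through the `LOADABLE_CLASSES` registry, scanning the registered base classes and picking the matching save or load method for each pipeline component. The snippet below is a minimal, self-contained sketch of that pattern, not the diffusers implementation itself; the registry contents and helper names are trimmed and invented for illustration.

```py
import importlib
import os

# simplified registry: library -> base class name -> (save_method, load_method)
LOADABLE = {
    "diffusers": {"ModelMixin": ("save_pretrained", "from_pretrained")},
    "transformers": {"PreTrainedModel": ("save_pretrained", "from_pretrained")},
}


def resolve_save_method(component):
    """Return the registered, bound save method for `component`, or None if no base class matches."""
    for library_name, classes in LOADABLE.items():
        library = importlib.import_module(library_name)
        for base_name, (save_name, _load_name) in classes.items():
            base_class = getattr(library, base_name, None)
            if base_class is not None and isinstance(component, base_class):
                return getattr(component, save_name)
    return None


def save_components(pipeline_components: dict, save_directory: str):
    """Save every component that has a registered save method into its own sub-folder."""
    for name, component in pipeline_components.items():
        save_method = resolve_save_method(component)
        if save_method is not None:
            save_method(os.path.join(save_directory, name))
```

The design keeps the pipeline class itself agnostic of its components: anything whose base class is registered can be saved and re-loaded, which is what lets `from_pretrained` reassemble an arbitrary mix of diffusers and transformers modules from `model_index.json`.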