diff --git a/docs/source/en/api/loaders.mdx b/docs/source/en/api/loaders.mdx index 8cbf21b8e0cf..20134a0afe66 100644 --- a/docs/source/en/api/loaders.mdx +++ b/docs/source/en/api/loaders.mdx @@ -36,3 +36,7 @@ API to load such adapter neural networks via the [`loaders.py` module](https://g ### LoraLoaderMixin [[autodoc]] loaders.LoraLoaderMixin + +### FromCkptMixin + +[[autodoc]] loaders.FromCkptMixin diff --git a/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx b/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx index af859177c002..dabd3ded31ce 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx @@ -308,6 +308,7 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h - disable_vae_slicing - enable_xformers_memory_efficient_attention - disable_xformers_memory_efficient_attention + - load_textual_inversion ## FlaxStableDiffusionControlNetPipeline [[autodoc]] FlaxStableDiffusionControlNetPipeline diff --git a/docs/source/en/api/pipelines/stable_diffusion/depth2img.mdx b/docs/source/en/api/pipelines/stable_diffusion/depth2img.mdx index c46576ff2887..a91167bac58c 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/depth2img.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/depth2img.mdx @@ -30,4 +30,7 @@ Available Checkpoints are: - enable_attention_slicing - disable_attention_slicing - enable_xformers_memory_efficient_attention - - disable_xformers_memory_efficient_attention \ No newline at end of file + - disable_xformers_memory_efficient_attention + - load_textual_inversion + - load_lora_weights + - save_lora_weights diff --git a/docs/source/en/api/pipelines/stable_diffusion/img2img.mdx b/docs/source/en/api/pipelines/stable_diffusion/img2img.mdx index 09bfb853f9c9..7959c588608b 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/img2img.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/img2img.mdx @@ -30,7 +30,11 @@ proposed by Chenlin Meng, Yutong He, Yang Song, Jiaming Song, Jiajun Wu, Jun-Yan - disable_attention_slicing - enable_xformers_memory_efficient_attention - disable_xformers_memory_efficient_attention + - load_textual_inversion + - from_ckpt + - load_lora_weights + - save_lora_weights [[autodoc]] FlaxStableDiffusionImg2ImgPipeline - all - - __call__ \ No newline at end of file + - __call__ diff --git a/docs/source/en/api/pipelines/stable_diffusion/inpaint.mdx b/docs/source/en/api/pipelines/stable_diffusion/inpaint.mdx index 33e84a63261f..39e5ae0fd37d 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/inpaint.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/inpaint.mdx @@ -31,7 +31,10 @@ Available checkpoints are: - disable_attention_slicing - enable_xformers_memory_efficient_attention - disable_xformers_memory_efficient_attention + - load_textual_inversion + - load_lora_weights + - save_lora_weights [[autodoc]] FlaxStableDiffusionInpaintPipeline - all - - __call__ \ No newline at end of file + - __call__ diff --git a/docs/source/en/api/pipelines/stable_diffusion/pix2pix.mdx b/docs/source/en/api/pipelines/stable_diffusion/pix2pix.mdx index 42cd4b896b2e..d01f1df23385 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/pix2pix.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/pix2pix.mdx @@ -68,3 +68,6 @@ images[0].save("snowy_mountains.png") [[autodoc]] StableDiffusionInstructPix2PixPipeline - __call__ - all + - load_textual_inversion + - load_lora_weights + - save_lora_weights diff --git 
a/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx b/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx index 6b8d53bf6510..ce78434fdbaa 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx @@ -39,6 +39,10 @@ Available Checkpoints are: - disable_xformers_memory_efficient_attention - enable_vae_tiling - disable_vae_tiling + - load_textual_inversion + - from_ckpt + - load_lora_weights + - save_lora_weights [[autodoc]] FlaxStableDiffusionPipeline - all diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 07c17100e0e0..40029fcecfd1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -109,7 +109,6 @@ except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .loaders import TextualInversionLoaderMixin from .pipelines import ( AltDiffusionImg2ImgPipeline, AltDiffusionPipeline, diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e814981a85c9..3133da117390 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -13,9 +13,11 @@ # limitations under the License. import os from collections import defaultdict +from pathlib import Path from typing import Callable, Dict, List, Optional, Union import torch +from huggingface_hub import hf_hub_download from .models.attention_processor import LoRAAttnProcessor from .utils import ( @@ -431,6 +433,7 @@ def load_textual_inversion( Example: To load a textual inversion embedding vector in `diffusers` format: + ```py from diffusers import StableDiffusionPipeline import torch @@ -463,6 +466,7 @@ image = pipe(prompt, num_inference_steps=50).images[0] image.save("character.png") ``` + """ if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer): raise ValueError( @@ -1051,3 +1055,197 @@ def save_function(weights, filename): save_function(state_dict, os.path.join(save_directory, weight_name)) logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") + + +class FromCkptMixin: + """This helper class allows loading `.ckpt` Stable Diffusion checkpoints directly + into the respective pipeline classes.""" + + @classmethod + def from_ckpt(cls, pretrained_model_link_or_path, **kwargs): + r""" + Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights saved in the original .ckpt format. + + The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). + + Parameters: + pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + - A link to the .ckpt file on the Hub. Should be in the format + `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>"` + - A path to a *file* containing all pipeline weights. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. 
+ resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + use_safetensors (`bool`, *optional*): + If set to `True`, the pipeline will be loaded from `safetensors` weights. If set to `None` (the + default), the pipeline will load from `safetensors` weights if they are available *and* if + `safetensors` is installed. If set to `False`, the pipeline will *not* use `safetensors`. + extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for + checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults + to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for + inference. Non-EMA weights are usually better to continue fine-tuning. + upcast_attention (`bool`, *optional*, defaults to `None`): + Whether the attention computation should always be upcasted. This is necessary when running Stable + Diffusion 2.1. + image_size (`int`, *optional*, defaults to 512): + The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 + Base. Use 768 for Stable Diffusion v2. + prediction_type (`str`, *optional*): + The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable + Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2. + num_in_channels (`int`, *optional*, defaults to `None`): + The number of input channels. If `None`, it will be automatically inferred. + scheduler_type (`str`, *optional*, defaults to 'pndm'): + Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", + "ddim"]`. + load_safety_checker (`bool`, *optional*, defaults to `True`): + Whether to load the safety checker or not. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load- and saveable variables, *i.e.* the pipeline components of the + specific pipeline class. The overwritten components are then directly passed to the pipeline's + `__init__` method. See example below for more information. + + Examples: + + ```py + >>> from diffusers import StableDiffusionPipeline + >>> import torch + + >>> # Download pipeline from huggingface.co and cache. + >>> pipeline = StableDiffusionPipeline.from_ckpt( + ... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors" + ... 
) + + >>> # Download pipeline from local file + >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt + >>> pipeline = StableDiffusionPipeline.from_ckpt("./v1-5-pruned-emaonly.ckpt") + + >>> # Enable float16 and move to GPU + >>> pipeline = StableDiffusionPipeline.from_ckpt( + ... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt", + ... torch_dtype=torch.float16, + ... ) + >>> pipeline.to("cuda") + ``` + """ + # import here to avoid circular dependency + from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt + + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + extract_ema = kwargs.pop("extract_ema", False) + image_size = kwargs.pop("image_size", 512) + scheduler_type = kwargs.pop("scheduler_type", "pndm") + num_in_channels = kwargs.pop("num_in_channels", None) + upcast_attention = kwargs.pop("upcast_attention", None) + load_safety_checker = kwargs.pop("load_safety_checker", True) + prediction_type = kwargs.pop("prediction_type", None) + + torch_dtype = kwargs.pop("torch_dtype", None) + + use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) + + pipeline_name = cls.__name__ + file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] + from_safetensors = file_extension == "safetensors" + + if from_safetensors and use_safetensors is False: + raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") + + # TODO: For now we only support stable diffusion + stable_unclip = None + controlnet = False + + if pipeline_name == "StableDiffusionControlNetPipeline": + model_type = "FrozenCLIPEmbedder" + controlnet = True + elif "StableDiffusion" in pipeline_name: + model_type = "FrozenCLIPEmbedder" + elif pipeline_name == "StableUnCLIPPipeline": + model_type = "FrozenOpenCLIPEmbedder" + stable_unclip = "txt2img" + elif pipeline_name == "StableUnCLIPImg2ImgPipeline": + model_type = "FrozenOpenCLIPEmbedder" + stable_unclip = "img2img" + elif pipeline_name == "PaintByExamplePipeline": + model_type = "PaintByExample" + elif pipeline_name == "LDMTextToImagePipeline": + model_type = "LDMTextToImage" + else: + raise ValueError(f"Unhandled pipeline class: {pipeline_name}") + + # remove huggingface url + for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]: + if pretrained_model_link_or_path.startswith(prefix): + pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :] + + # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained + ckpt_path = Path(pretrained_model_link_or_path) + if not ckpt_path.is_file(): + # get repo_id and (potentially nested) file path of ckpt in repo + repo_id = str(Path().joinpath(*ckpt_path.parts[:2])) + file_path = str(Path().joinpath(*ckpt_path.parts[2:])) + + if file_path.startswith("blob/"): + file_path = file_path[len("blob/") :] + + if file_path.startswith("main/"): + file_path = file_path[len("main/") :] + + pretrained_model_link_or_path = hf_hub_download( + repo_id, + filename=file_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, 
use_auth_token=use_auth_token, + revision=revision, + force_download=force_download, + ) + + pipe = download_from_original_stable_diffusion_ckpt( + pretrained_model_link_or_path, + pipeline_class=cls, + model_type=model_type, + stable_unclip=stable_unclip, + controlnet=controlnet, + from_safetensors=from_safetensors, + extract_ema=extract_ema, + image_size=image_size, + scheduler_type=scheduler_type, + num_in_channels=num_in_channels, + upcast_attention=upcast_attention, + load_safety_checker=load_safety_checker, + prediction_type=prediction_type, + ) + + if torch_dtype is not None: + pipe.to(torch_dtype=torch_dtype) + + return pipe diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index bf314b91116e..ff9474ffd43a 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -57,6 +57,14 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin): This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 86fc47f424e9..dee4a91924f7 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -96,6 +96,14 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
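Taken together, the loader mixins referenced in the docstrings above give a single pipeline object a uniform loading surface. A minimal sketch of how the three loaders combine — assuming a CUDA GPU; the checkpoint link is the one from the `from_ckpt` docstring, while the embedding and LoRA repo ids are illustrative placeholders, not part of this PR:

```py
import torch
from diffusers import StableDiffusionPipeline

# Build the pipeline straight from an original-format checkpoint link.
pipe = StableDiffusionPipeline.from_ckpt(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
    torch_dtype=torch.float16,
)
pipe.to("cuda")

# Layer adapter weights on top of the converted pipeline.
pipe.load_textual_inversion("sd-concepts-library/cat-toy")  # placeholder embedding repo
pipe.load_lora_weights("sayakpaul/sd-model-finetuned-lora-t4")  # placeholder LoRA repo

image = pipe("a <cat-toy> in a snowy field", num_inference_steps=30).images[0]
image.save("cat_toy.png")
```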
diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index dbc1b27e88be..5961636dd197 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -31,35 +31,30 @@ CLIPVisionModelWithProjection, ) -from diffusers import ( +from ...models import ( AutoencoderKL, ControlNetModel, + PriorTransformer, + UNet2DConditionModel, +) +from ...schedulers import ( DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler, - LDMTextToImagePipeline, LMSDiscreteScheduler, PNDMScheduler, - PriorTransformer, - StableDiffusionControlNetPipeline, - StableDiffusionImg2ImgPipeline, - StableDiffusionInpaintPipeline, - StableDiffusionPipeline, - StableUnCLIPImg2ImgPipeline, - StableUnCLIPPipeline, UnCLIPScheduler, - UNet2DConditionModel, ) -from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel -from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline -from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker -from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer - from ...utils import is_omegaconf_available, is_safetensors_available, logging from ...utils.import_utils import BACKENDS_MAPPING +from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel +from ..paint_by_example import PaintByExampleImageEncoder +from ..pipeline_utils import DiffusionPipeline +from .safety_checker import StableDiffusionSafetyChecker +from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -981,7 +976,6 @@ def download_from_original_stable_diffusion_ckpt( image_size: int = 512, prediction_type: str = None, model_type: str = None, - is_img2img: bool = False, extract_ema: bool = False, scheduler_type: str = "pndm", num_in_channels: Optional[int] = None, @@ -993,7 +987,8 @@ clip_stats_path: Optional[str] = None, controlnet: Optional[bool] = None, load_safety_checker: bool = True, -) -> StableDiffusionPipeline: + pipeline_class: DiffusionPipeline = None, +) -> DiffusionPipeline: """ Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` config file. @@ -1031,12 +1026,29 @@ Whether the attention computation should always be upcasted. This is necessary when running stable diffusion 2.1. device (`str`, *optional*, defaults to `None`): - The device to use. Pass `None` to determine automatically. :param from_safetensors: If `checkpoint_path` is - in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A - StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. + The device to use. Pass `None` to determine automatically. + from_safetensors (`bool`, *optional*, defaults to `False`): + If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. load_safety_checker (`bool`, *optional*, defaults to `True`): Whether to load the safety checker or not. Defaults to `True`. + pipeline_class (`DiffusionPipeline`, *optional*, defaults to `None`): + The pipeline class to use. 
Pass `None` to determine automatically. + Returns: A `DiffusionPipeline` object representing the passed-in `.ckpt`/`.safetensors` file. """ + + # import pipelines here to avoid circular import error when using from_ckpt method + from diffusers import ( + LDMTextToImagePipeline, + PaintByExamplePipeline, + StableDiffusionControlNetPipeline, + StableDiffusionPipeline, + StableUnCLIPImg2ImgPipeline, + StableUnCLIPPipeline, + ) + + if pipeline_class is None: + pipeline_class = StableDiffusionPipeline + if prediction_type == "v-prediction": prediction_type = "v_prediction" @@ -1198,44 +1210,16 @@ requires_safety_checker=False, ) else: - if ( - hasattr(original_config, "model") - and hasattr(original_config.model, "target") - and "LatentInpaintDiffusion" in original_config.model.target - ): - pipe = StableDiffusionInpaintPipeline( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, - ) - else: - if is_img2img: - pipe = StableDiffusionImg2ImgPipeline( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, - ) - else: - pipe = StableDiffusionPipeline( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, - ) + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) else: image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( original_config, clip_stats_path=clip_stats_path, device=device ) @@ -1326,41 +1310,15 @@ feature_extractor=feature_extractor, ) else: - if ( - hasattr(original_config, "model") - and hasattr(original_config.model, "target") - and "LatentInpaintDiffusion" in original_config.model.target - ): - pipe = StableDiffusionInpaintPipeline( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - else: - if is_img2img: - pipe = StableDiffusionImg2ImgPipeline( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - else: - pipe = StableDiffusionPipeline( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) else: text_config = create_ldm_bert_config(original_config) text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) @@ -1379,7 +1337,7 @@ def download_controlnet_from_original_ckpt( upcast_attention: Optional[bool] = None, device: str = None, from_safetensors: bool = False, -) -> StableDiffusionPipeline: +) -> DiffusionPipeline: if not is_omegaconf_available(): raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) diff --git 
a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 689febe3e891..7347d70c4023 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -20,7 +20,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict -from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -53,13 +53,21 @@ """ -class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): +class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin): r""" Pipeline for text-to-image generation using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index 3b8889d92b55..322f2232fc8a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -156,6 +156,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
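The Hub-link handling that `from_ckpt` performs reduces to a few `pathlib` operations. The sketch below restates the logic added in `loaders.py` earlier in this diff as a standalone helper (`parse_ckpt_link` is a hypothetical name; the prefix list and `blob/`/`main/` stripping mirror the diff):

```py
from pathlib import Path


def parse_ckpt_link(link_or_path: str):
    """Split a Hub ckpt link into (repo_id, file_path); local files pass through as (None, path)."""
    for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
        if link_or_path.startswith(prefix):
            link_or_path = link_or_path[len(prefix) :]

    ckpt_path = Path(link_or_path)
    if ckpt_path.is_file():
        return None, str(ckpt_path)

    # The first two path parts form the repo id; the rest is the (potentially nested) file path.
    repo_id = "/".join(ckpt_path.parts[:2])
    file_path = "/".join(ckpt_path.parts[2:])
    for marker in ("blob/", "main/"):
        if file_path.startswith(marker):
            file_path = file_path[len(marker) :]
    return repo_id, file_path


assert parse_ckpt_link(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt"
) == ("runwayml/stable-diffusion-v1-5", "v1-5-pruned-emaonly.ckpt")
```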
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 4fe117ba120b..c4f9ae59a4e9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -23,7 +23,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation from ...configuration_utils import FrozenDict -from ...loaders import TextualInversionLoaderMixin +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor @@ -55,13 +55,20 @@ def preprocess(image): return image -class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin): +class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-guided image to image generation using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 5860a53ad528..c26ddf06cadc 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -23,7 +23,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import VaeImageProcessor -from ...loaders import TextualInversionLoaderMixin +from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -92,13 +92,21 @@ def preprocess(image): return image -class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin): +class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin): r""" Pipeline for text-guided image to image generation using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
+ In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 8e0ea5a8d079..fb2e5dc424e3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -22,7 +22,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict -from ...loaders import TextualInversionLoaderMixin +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor @@ -138,13 +138,20 @@ def prepare_mask_and_masked_image(image, mask): return mask, masked_image -class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin): +class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
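With `LoraLoaderMixin` mixed into the inpainting pipeline, the same adapter hooks apply there as well. A minimal sketch, assuming a CUDA GPU; `"path/to/embedding"` and `"path/to/lora"` are placeholders for a textual-inversion embedding and a `diffusers`-format LoRA checkpoint trained against the base model:

```py
import torch
from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
)
pipe.to("cuda")

# Both loaders come from the mixins added in this PR.
pipe.load_textual_inversion("path/to/embedding")  # placeholder
pipe.load_lora_weights("path/to/lora")  # placeholder
```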
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 6d9cbaf67a07..1c8377c7e54e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -22,7 +22,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict -from ...loaders import TextualInversionLoaderMixin +from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -82,13 +82,23 @@ def preprocess_mask(mask, scale_factor=8): return mask -class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLoaderMixin): +class StableDiffusionInpaintPipelineLegacy( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin +): r""" Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index f7999a08dc9b..49944cdcd636 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -20,7 +20,7 @@ import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer -from ...loaders import TextualInversionLoaderMixin +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -61,13 +61,20 @@ def preprocess(image): return image -class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin): +class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
+ In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 8a521457f2e3..bda56d2ae8ae 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2,21 +2,6 @@ from ..utils import DummyObject, requires_backends -class TextualInversionLoaderMixin(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class AltDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 14421a64b9e8..fcfcd84c5d48 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -36,6 +36,7 @@ UNet2DConditionModel, logging, ) +from diffusers.models.attention_processor import AttnProcessor from diffusers.utils import load_numpy, nightly, slow, torch_device from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu @@ -865,6 +866,62 @@ def test_stable_diffusion_textual_inversion(self): assert max_diff < 5e-2 +@slow +@require_torch_gpu +class StableDiffusionPipelineCkptTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_download_from_hub(self): + ckpt_paths = [ + "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt", + "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix_base.ckpt", + ] + + for ckpt_path in ckpt_paths: + pipe = StableDiffusionPipeline.from_ckpt(ckpt_path, torch_dtype=torch.float16) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.to("cuda") + + image_out = pipe("test", num_inference_steps=1, output_type="np").images[0] + + assert image_out.shape == (512, 512, 3) + + def test_download_local(self): + filename = hf_hub_download("runwayml/stable-diffusion-v1-5", filename="v1-5-pruned-emaonly.ckpt") + + pipe = StableDiffusionPipeline.from_ckpt(filename, torch_dtype=torch.float16) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.to("cuda") + + image_out = pipe("test", num_inference_steps=1, output_type="np").images[0] + + assert image_out.shape == (512, 512, 3) + + def test_download_ckpt_diff_format_is_same(self): + ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt" + + pipe = StableDiffusionPipeline.from_ckpt(ckpt_path) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.unet.set_attn_processor(AttnProcessor()) + pipe.to("cuda") + + generator = 
torch.Generator(device="cpu").manual_seed(0) + image_ckpt = pipe("a turtle", num_inference_steps=5, generator=generator, output_type="np").images[0] + + pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.unet.set_attn_processor(AttnProcessor()) + pipe.to("cuda") + + generator = torch.Generator(device="cpu").manual_seed(0) + image = pipe("a turtle", num_inference_steps=5, generator=generator, output_type="np").images[0] + + assert np.max(np.abs(image - image_ckpt)) < 1e-4 + + @nightly @require_torch_gpu class StableDiffusionPipelineNightlyTests(unittest.TestCase):
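The ckpt-equivalence check above can also be reproduced outside the test suite. A standalone sketch using the same checkpoint, scheduler, seed, and tolerance as `test_download_ckpt_diff_format_is_same` (assumes a CUDA GPU):

```py
import numpy as np
import torch

from diffusers import DDIMScheduler, StableDiffusionPipeline
from diffusers.models.attention_processor import AttnProcessor


def render(pipe):
    # Normalize scheduler and attention processor so both pipelines are numerically comparable.
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.unet.set_attn_processor(AttnProcessor())
    pipe.to("cuda")
    generator = torch.Generator(device="cpu").manual_seed(0)
    return pipe("a turtle", num_inference_steps=5, generator=generator, output_type="np").images[0]


ckpt_link = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt"
image_ckpt = render(StableDiffusionPipeline.from_ckpt(ckpt_link))
image = render(StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5"))

print("max abs diff:", np.max(np.abs(image - image_ckpt)))  # the test asserts this stays below 1e-4
```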