From 136c9f3d0cfd6f7782dff95b1914dd0809b5eb3f Mon Sep 17 00:00:00 2001 From: William Berman Date: Tue, 7 Feb 2023 22:47:42 -0800 Subject: [PATCH 01/14] pipeline_variant --- ..._original_stable_diffusion_to_diffusers.py | 23 + src/diffusers/__init__.py | 3 + src/diffusers/models/__init__.py | 1 + src/diffusers/models/noise_augmentor.py | 167 ++++ src/diffusers/models/unet_2d_condition.py | 20 +- src/diffusers/pipelines/__init__.py | 2 + .../pipelines/stable_diffusion/__init__.py | 2 + .../stable_diffusion/convert_from_ckpt.py | 187 +++- .../pipeline_stable_unclip.py | 814 ++++++++++++++++++ .../pipeline_stable_unclip_img2img.py | 727 ++++++++++++++++ .../versatile_diffusion/modeling_text_unet.py | 20 +- src/diffusers/schedulers/scheduling_ddpm.py | 5 +- src/diffusers/utils/dummy_pt_objects.py | 15 + .../dummy_torch_and_transformers_objects.py | 30 + tests/pipelines/stable_unclip/__init__.py | 0 .../stable_unclip/test_stable_unclip.py | 219 +++++ .../test_stable_unclip_img2img.py | 244 ++++++ tests/test_pipelines_common.py | 10 +- 18 files changed, 2471 insertions(+), 18 deletions(-) create mode 100644 src/diffusers/models/noise_augmentor.py create mode 100644 src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py create mode 100644 src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py create mode 100644 tests/pipelines/stable_unclip/__init__.py create mode 100644 tests/pipelines/stable_unclip/test_stable_unclip.py create mode 100644 tests/pipelines/stable_unclip/test_stable_unclip_img2img.py diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index d449f283d95e..255ce9411394 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -100,6 +100,26 @@ ) parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + parser.add_argument( + "--stable_unclip", + type=str, + default=None, + required=False, + help="Set if this is a stable unCLIP model. One of 'txt2img' or 'img2img'.", + ) + parser.add_argument( + "--stable_unclip_prior", + type=str, + default="karlo", + required=False, + help="Set if this is a stable unCLIP txt2img model. Selects which prior to use. Defaults to 'karlo'.", + ) + parser.add_argument( + "--clip_stats_path", + type=str, + help="Path to the clip stats file. 
Only required with certain stable unclip models.", + required=False, + ) args = parser.parse_args() pipe = load_pipeline_from_original_stable_diffusion_ckpt( @@ -114,5 +134,8 @@ upcast_attention=args.upcast_attention, from_safetensors=args.from_safetensors, device=args.device, + stable_unclip=args.stable_unclip, + stable_unclip_prior=args.stable_unclip_prior, + clip_stats_path=args.clip_stats_path, ) pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index bc6057eaf2da..80c0caef97af 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -35,6 +35,7 @@ from .models import ( AutoencoderKL, ModelMixin, + NoiseAugmentor, PriorTransformer, Transformer2DModel, UNet1DModel, @@ -119,6 +120,8 @@ StableDiffusionPipeline, StableDiffusionPipelineSafe, StableDiffusionUpscalePipeline, + StableUnCLIPImg2ImgPipeline, + StableUnCLIPPipeline, UnCLIPImageVariationPipeline, UnCLIPPipeline, VersatileDiffusionDualGuidedPipeline,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 474b8412560e..29c64652c00d 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -19,6 +19,7 @@ from .autoencoder_kl import AutoencoderKL from .dual_transformer_2d import DualTransformer2DModel from .modeling_utils import ModelMixin + from .noise_augmentor import NoiseAugmentor from .prior_transformer import PriorTransformer from .transformer_2d import Transformer2DModel from .unet_1d import UNet1DModel
diff --git a/src/diffusers/models/noise_augmentor.py b/src/diffusers/models/noise_augmentor.py new file mode 100644 index 000000000000..0db7217070e3 --- /dev/null +++ b/src/diffusers/models/noise_augmentor.py @@ -0,0 +1,167 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Union + +import numpy as np +import torch +from torch import nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import randn_tensor +from .embeddings import Timesteps +from .modeling_utils import ModelMixin + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities.
+ + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class NoiseAugmentor(ModelMixin, ConfigMixin): + """ + NoiseAugmentor is used in the context of stable unclip to add noise to the image embeddings. The amount of noise is + controlled by a `noise_level` input. A higher `noise_level` increases the variance in the final un-noised images. + + The noise is applied in two ways: + 1. A noise schedule is applied directly to the embeddings. + 2. A vector of sinusoidal time embeddings is appended to the output. + + In both cases, the amount of noise is controlled by the same `noise_level`. + + The embeddings are normalized before the noise is applied and un-normalized after the noise is applied, so the + layer requires parameters for the mean and standard deviation of the embeddings. + """ + + @register_to_config + def __init__( + self, + embedding_dim: int = 768, + max_noise_level: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + set_alpha_to_one: bool = True, + ): + super().__init__() + self.max_noise_level = max_noise_level + + self.mean = nn.Parameter(torch.zeros(1, embedding_dim)) + self.std = nn.Parameter(torch.ones(1, embedding_dim)) + + self.time_embed = Timesteps(embedding_dim, True, 0) + + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, max_noise_level, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, max_noise_level, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(max_noise_level) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def forward( + self, + embeds: torch.Tensor, + noise_level: int, + noise: Optional[torch.FloatTensor] = None, + generator: Optional[torch.Generator] = None, + ): + if noise is None: + noise = randn_tensor(embeds.shape, generator=generator, device=self.device, dtype=embeds.dtype) + + noise_level = torch.tensor([noise_level] * embeds.shape[0], device=embeds.device) + + # scale + embeds = (embeds - self.mean) * 1.0 / self.std + + embeds = self.add_noise(original_samples=embeds, timesteps=noise_level, noise=noise) + + # unscale + embeds = (embeds * self.std) + self.mean + + noise_level = self.time_embed(noise_level) + + # `self.time_embed` does not contain any weights and will always return f32 tensors, + # but we might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + noise_level = noise_level.to(self.dtype) + + combined = torch.cat((embeds, noise_level), 1) + + return combined diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index ba2c09b297b9..ddda5f9d64f6 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -91,7 +91,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately - summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`. + summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`. num_class_embeds (`int`, *optional*, defaults to None): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. @@ -102,7 +102,9 @@ class conditioning with `class_embed_type` equal to `None`. time_cond_proj_dim (`int`, *optional*, default to `None`): The dimension of `cond_proj` layer in timestep embedding. conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. 
- conv_out_kernel (`int`, *optional*, default to `3`): the Kernel size of `conv_out` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. """ _supports_gradient_checkpointing = True @@ -145,6 +147,7 @@ def __init__( time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, ): super().__init__() @@ -190,6 +193,19 @@ def __init__( self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index dfb2fd83cb71..0f3114f3c066 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -55,6 +55,8 @@ StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline, StableDiffusionUpscalePipeline, + StableUnCLIPImg2ImgPipeline, + StableUnCLIPPipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index bf07127cde5b..7cbf646939db 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -45,6 +45,8 @@ class StableDiffusionPipelineOutput(BaseOutput): from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline + from .pipeline_stable_unclip import StableUnCLIPPipeline + from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline from .safety_checker import StableDiffusionSafetyChecker try: diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 033c0a23a98e..9c620bcbf959 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -17,22 +17,38 @@ import os import re import tempfile +from typing import Optional import requests import torch -from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, 
CLIPVisionConfig +from transformers import ( + AutoFeatureExtractor, + BertTokenizerFast, + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) from diffusers import ( AutoencoderKL, DDIMScheduler, + DDPMScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler, LDMTextToImagePipeline, LMSDiscreteScheduler, + NoiseAugmentor, PNDMScheduler, + PriorTransformer, StableDiffusionPipeline, + StableUnCLIPImg2ImgPipeline, + StableUnCLIPPipeline, + UnCLIPScheduler, UNet2DConditionModel, ) from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel @@ -242,6 +258,17 @@ def create_unet_diffusers_config(original_config, image_size: int): if head_dim is None: head_dim = [5, 10, 20, 20] + class_embed_type = None + projection_class_embeddings_input_dim = None + + if "num_classes" in unet_params: + if unet_params.num_classes == "sequential": + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params.adm_in_channels + else: + raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") + config = dict( sample_size=image_size // vae_scale_factor, in_channels=unet_params.in_channels, @@ -253,6 +280,8 @@ def create_unet_diffusers_config(original_config, image_size: int): cross_attention_dim=unet_params.context_dim, attention_head_dim=head_dim, use_linear_projection=use_linear_projection, + class_embed_type=class_embed_type, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, ) return config @@ -341,6 +370,17 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + if config["class_embed_type"] is None: + # No parameters to port + ... + elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] @@ -779,6 +819,81 @@ def convert_open_clip_checkpoint(checkpoint): return text_model +def stable_unclip_image_encoder(original_config): + """ + Returns the image processor and clip image encoder for the img2img unclip pipeline. + + We currently know of two types of stable unclip models which separately use the clip and the openclip image + encoders. 
+ """ + + image_embedder_config = original_config.model.params.embedder_config + + sd_clip_image_embedder_class = image_embedder_config.target + sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] + + if sd_clip_image_embedder_class == "ClipImageEmbedder": + clip_model_name = image_embedder_config.params.model + + if clip_model_name == "ViT-L/14": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + else: + raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") + + elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + else: + raise NotImplementedError( + f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" + ) + + return feature_extractor, image_encoder + + +def stable_unclip_noise_augmentor( + original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None +): + """ + Returns the noise augmentor for the img2img and txt2img unclip pipelines. + + If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. + """ + noise_aug_config = original_config.model.params.noise_aug_config + noise_aug_class = noise_aug_config.target + noise_aug_class = noise_aug_class.split(".")[-1] + + if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": + noise_aug_config = noise_aug_config.params + embedding_dim = noise_aug_config.timestep_dim + max_noise_level = noise_aug_config.noise_schedule_config.timesteps + beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule + + noise_augmentor = NoiseAugmentor( + embedding_dim=embedding_dim, max_noise_level=max_noise_level, beta_schedule=beta_schedule + ) + + if "clip_stats_path" in noise_aug_config: + if clip_stats_path is None: + raise ValueError("This stable unclip config requires a `clip_stats_path`") + + clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) + clip_mean = clip_mean[None, :] + clip_std = clip_std[None, :] + + clip_stats_state_dict = { + "mean": clip_mean, + "std": clip_std, + } + + noise_augmentor.load_state_dict(clip_stats_state_dict) + else: + raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") + + return noise_augmentor + + def load_pipeline_from_original_stable_diffusion_ckpt( checkpoint_path: str, original_config_file: str = None, @@ -791,6 +906,9 @@ def load_pipeline_from_original_stable_diffusion_ckpt( upcast_attention: bool = None, device: str = None, from_safetensors: bool = False, + stable_unclip: Optional[str] = None, + stable_unclip_prior: str = "karlo", + clip_stats_path: Optional[str] = None, ) -> StableDiffusionPipeline: """ Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` @@ -967,16 +1085,63 @@ def load_pipeline_from_original_stable_diffusion_ckpt( if model_type == "FrozenOpenCLIPEmbedder": text_model = convert_open_clip_checkpoint(checkpoint) tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer") - pipe = StableDiffusionPipeline( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, - ) + + if stable_unclip is None: + pipe 
= StableDiffusionPipeline( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + else: + noise_augmentor = stable_unclip_noise_augmentor( + original_config, clip_stats_path=clip_stats_path, device=device + ) + + if stable_unclip == "img2img": + feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) + + pipe = StableUnCLIPImg2ImgPipeline( + vae=vae, + unet=unet, + text_encoder=text_model, + tokenizer=tokenizer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + noise_augmentor=noise_augmentor, + ) + elif stable_unclip == "txt2img": + if stable_unclip_prior == "karlo": + karlo_model = "kakaobrain/karlo-v1-alpha" + prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior") + + prior_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + prior_text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + + prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler") + prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) + else: + raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}") + + pipe = StableUnCLIPPipeline( + vae=vae, + unet=unet, + tokenizer=tokenizer, + text_encoder=text_model, + prior_tokenizer=prior_tokenizer, + prior_text_encoder=prior_text_model, + scheduler=scheduler, + noise_augmentor=noise_augmentor, + prior=prior, + prior_scheduler=prior_scheduler, + ) + else: + raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}") elif model_type == "PaintByExample": vision_model = convert_paint_by_example_checkpoint(checkpoint) tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py new file mode 100644 index 000000000000..9365938b439b --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -0,0 +1,814 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers.models.clip.modeling_clip import CLIPTextModelOutput + +from ...models import AutoencoderKL, NoiseAugmentor, PriorTransformer, UNet2DConditionModel +from ...schedulers import DDIMScheduler, DDPMScheduler +from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableUnCLIPPipeline + + >>> pipe = StableUnCLIPPipeline.from_pretrained( + ... "fusing/stable-unclip-2-1-l", torch_dtype=torch.float16 + ... ) # TODO update model path + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> images = pipe(prompt).images + >>> images[0].save("astronaut_horse.png") + ``` +""" + + +class StableUnCLIPPipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using stable unCLIP. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior_tokenizer ([`CLIPTokenizer`]): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + prior_text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + prior_scheduler ([`DDPMScheduler`]): + Scheduler used in the prior denoising process. + noise_augmentor ([`NoiseAugmentor`]): + Layer for adding noise to the predicted image embeddings. The amount of noise to add is determined by + `noise_level` in `StableUnCLIPPipeline.__call__`. See `NoiseAugmentor` for more details on how noise is + added. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ """ + + # prior components + prior_tokenizer: CLIPTokenizer + prior_text_encoder: CLIPTextModelWithProjection + prior: PriorTransformer + prior_scheduler: DDPMScheduler + + # regular denoising components + noise_augmentor: NoiseAugmentor + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModel + unet: UNet2DConditionModel + scheduler: DDIMScheduler + + vae: AutoencoderKL + + def __init__( + self, + # prior components + prior_tokenizer: CLIPTokenizer, + prior_text_encoder: CLIPTextModelWithProjection, + prior: PriorTransformer, + prior_scheduler: DDPMScheduler, + # regular denoising components + noise_augmentor: NoiseAugmentor, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + # vae + vae: AutoencoderKL, + ): + super().__init__() + + self.register_modules( + prior_tokenizer=prior_tokenizer, + prior_text_encoder=prior_text_encoder, + prior=prior, + prior_scheduler=prior_scheduler, + noise_augmentor=noise_augmentor, + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + vae=vae, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list + models = [ + self.prior_text_encoder, + self.noise_augmentor, + self.text_encoder, + self.unet, + self.vae, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. 
+ """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt with _encode_prompt->_encode_prior_prompt, tokenizer->prior_tokenizer, text_encoder->prior_text_encoder + def _encode_prior_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, + text_attention_mask: Optional[torch.Tensor] = None, + ): + if text_model_output is None: + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.prior_tokenizer( + prompt, + padding="max_length", + max_length=self.prior_tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.prior_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.prior_tokenizer.batch_decode( + untruncated_ids[:, self.prior_tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.prior_tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.prior_tokenizer.model_max_length] + + prior_text_encoder_output = self.prior_text_encoder(text_input_ids.to(device)) + + prompt_embeds = prior_text_encoder_output.text_embeds + prior_text_encoder_hidden_states = prior_text_encoder_output.last_hidden_state + + else: + batch_size = text_model_output[0].shape[0] + prompt_embeds, prior_text_encoder_hidden_states = text_model_output[0], text_model_output[1] + text_mask = text_attention_mask + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + + uncond_input = self.prior_tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.prior_tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_prior_text_encoder_output = self.prior_text_encoder( + uncond_input.input_ids.to(device) + ) + + negative_prompt_embeds = negative_prompt_embeds_prior_text_encoder_output.text_embeds + uncond_prior_text_encoder_hidden_states = ( + negative_prompt_embeds_prior_text_encoder_output.last_hidden_state + ) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_prior_text_encoder_hidden_states.shape[1] + 
uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.repeat( + 1, num_images_per_prompt, 1 + ) + uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prior_text_encoder_hidden_states = torch.cat( + [uncond_prior_text_encoder_hidden_states, prior_text_encoder_hidden_states] + ) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, prior_text_encoder_hidden_states, text_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + noise_level, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two." 
+ ) + + if prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + "Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and `negative_prompt_embeds` undefined." + ) + + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if noise_level < 0 or noise_level >= self.noise_augmentor.max_noise_level: + raise ValueError( + f"`noise_level` must be between 0 and {self.noise_augmentor.max_noise_level - 1}, inclusive." + ) + + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + # regular denoising process args + prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 20, + guidance_scale: float = 10.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + noise_level: int = 0, + # prior args + prior_num_inference_steps: int = 25, + prior_guidance_scale: float = 4.0, + prior_latents: Optional[torch.FloatTensor] = None, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. 
+ num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 
+ noise_level (`int`, *optional*, defaults to `0`): + The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in + the final un-noised images. See `NoiseAugmentor` for how the noise is added. + prior_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps in the prior denoising process. More denoising steps usually lead to a + higher quality image at the expense of slower inference. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale for the prior denoising process as defined in [Classifier-Free Diffusion + Guidance](https://arxiv.org/abs/2207.12598). `prior_guidance_scale` is defined as `w` of equation 2. of + [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + prior_latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + embedding generation in the prior denoising process. Can be used to tweak the same generation with + different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied + random `generator`. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput `] if `return_dict` + is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_steps=callback_steps, + noise_level=noise_level, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + prior_do_classifier_free_guidance = prior_guidance_scale > 1.0 + + # 3. Encode input prompt + prior_prompt_embeds, prior_text_encoder_hidden_states, prior_text_mask = self._encode_prior_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=prior_do_classifier_free_guidance, + ) + + # 4. Prepare prior timesteps + self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) + prior_timesteps_tensor = self.prior_scheduler.timesteps + + # 5. Prepare prior latent variables + embedding_dim = self.prior.config.embedding_dim + prior_latents = self.prepare_latents( + (batch_size, embedding_dim), + prior_prompt_embeds.dtype, + device, + generator, + prior_latents, + self.prior_scheduler, + ) + + # 6. 
Prior denoising loop + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([prior_latents] * 2) if prior_do_classifier_free_guidance else prior_latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prior_prompt_embeds, + encoder_hidden_states=prior_text_encoder_hidden_states, + attention_mask=prior_text_mask, + ).predicted_image_embedding + + if prior_do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + prior_latents = self.prior_scheduler.step( + predicted_image_embedding, + timestep=t, + sample=prior_latents, + generator=generator, + ).prev_sample + + if callback is not None and i % callback_steps == 0: + callback(i, t, prior_latents) + + prior_latents = self.prior.post_process_latents(prior_latents) + + image_embeds = prior_latents + + # done prior + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 7. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 8. Prepare image embeddings + image_embeds = self.noise_augmentor(image_embeds, noise_level=noise_level, generator=generator) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(image_embeds) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeds = torch.cat([negative_prompt_embeds, image_embeds]) + + # 9. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 10. Prepare latent variables + num_channels_latents = self.unet.in_channels + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + latents = self.prepare_latents( + shape=shape, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + scheduler=self.scheduler, + ) + + # 11. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 12. 
Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=image_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 13. Post-processing + image = self.decode_latents(latents) + + # 14. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py new file mode 100644 index 000000000000..f2e368eb4e9d --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -0,0 +1,727 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL +import torch +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from diffusers.utils.import_utils import is_accelerate_available + +from ...models import AutoencoderKL, NoiseAugmentor, UNet2DConditionModel +from ...schedulers import DDIMScheduler +from ...utils import logging, randn_tensor, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import requests + >>> import torch + >>> from PIL import Image + >>> from io import BytesIO + + >>> from diffusers import StableUnCLIPImg2ImgPipeline + + >>> pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( + ... "fusing/stable-unclip-2-1-l-img2img", torch_dtype=torch.float16 + ... 
) # TODO update model path + >>> pipe = pipe.to("cuda") + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + + >>> response = requests.get(url) + >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_image = init_image.resize((768, 512)) + + >>> prompt = "A fantasy landscape, trending on artstation" + + >>> images = pipe(prompt, init_image).images + >>> images[0].save("fantasy_landscape.png") + ``` +""" + + +class StableUnCLIPImg2ImgPipeline(DiffusionPipeline): + """ + Pipeline for text-guided image to image generation using stable unCLIP. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + feature_extractor ([`CLIPFeatureExtractor`]): + Feature extractor for image pre-processing before being encoded. + image_encoder ([`CLIPVisionModelWithProjection`]): + CLIP vision model for encoding images. + noise_augmentor ([`NoiseAugmentor`]): + Layer for adding noise to the predicted image embeddings. The amount of noise to add is determined by + `noise_level` in `StableUnCLIPPipeline.__call__`. See `NoiseAugmentor` for more details on how noise is + added. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + """ + + # image encoding components + feature_extractor: CLIPFeatureExtractor + image_encoder: CLIPVisionModelWithProjection + + # regular denoising components + noise_augmentor: NoiseAugmentor + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModel + unet: UNet2DConditionModel + scheduler: DDIMScheduler + + vae: AutoencoderKL + + def __init__( + self, + # image encoding components + feature_extractor: CLIPFeatureExtractor, + image_encoder: CLIPVisionModelWithProjection, + # regular denoising components + noise_augmentor: NoiseAugmentor, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + # vae + vae: AutoencoderKL, + ): + super().__init__() + + self.register_modules( + feature_extractor=feature_extractor, + image_encoder=image_encoder, + noise_augmentor=noise_augmentor, + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + vae=vae, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. 
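As a side note on the `vae_scale_factor` computed in `__init__` above: every VAE down block after the first halves the spatial resolution, so the scale factor is a power of two and fixes the latent grid that `prepare_latents` allocates. A rough sketch, assuming the four-level `block_out_channels` of a typical Stable Diffusion VAE (the exact config is not part of this patch):

```py
# Hypothetical VAE config values, for illustration only.
block_out_channels = (128, 256, 512, 512)

vae_scale_factor = 2 ** (len(block_out_channels) - 1)
print(vae_scale_factor)  # 8

# A 768x768 output image is therefore denoised on a 96x96 latent grid.
height, width = 768, 768
print(height // vae_scale_factor, width // vae_scale_factor)  # 96 96
```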
+ """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [ + self.image_encoder, + self.noise_augmentor, + self.text_encoder, + self.unet, + self.vae, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def _encode_image( + self, + image, + device, + batch_size, + num_images_per_prompt, + do_classifier_free_guidance, + noise_level, + generator, + image_embeds, + ): + dtype = next(self.image_encoder.parameters()).dtype + + if isinstance(image, PIL.Image.Image): + # the image embedding should repeated so it matches the total batch size of the prompt + repeat_by = batch_size + else: + # assume the image input is already properly batched and just needs to be repeated so + # it matches the num_images_per_prompt. + # + # NOTE(will) this is probably missing a few number of side cases. I.e. batched/non-batched + # `image_embeds`. If those happen to be common use cases, let's think harder about + # what the expected dimensions of inputs should be and how we handle the encoding. + repeat_by = num_images_per_prompt + + if not image_embeds: + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + + image_embeds = self.noise_augmentor(image_embeds, noise_level=noise_level, generator=generator) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + image_embeds = image_embeds.unsqueeze(1) + bs_embed, seq_len, _ = image_embeds.shape + image_embeds = image_embeds.repeat(1, repeat_by, 1) + image_embeds = image_embeds.view(bs_embed * repeat_by, seq_len, -1) + image_embeds = image_embeds.squeeze(1) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(image_embeds) + + # For classifier free guidance, we need to do two forward passes. 
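The "mps friendly" duplication used for both the text and image embeddings above tiles along the sequence dimension and then folds the copies back into the batch dimension, which is equivalent to a `repeat_interleave` over the batch. A small standalone sketch with arbitrary shapes:

```py
import torch

batch, seq_len, dim = 2, 3, 4
num_images_per_prompt = 2

embeds = torch.randn(batch, seq_len, dim)

# Tile along the sequence dimension, then reshape so that the copies of each
# prompt end up adjacent to one another in the batch dimension.
tiled = embeds.repeat(1, num_images_per_prompt, 1)                   # (2, 6, 4)
duplicated = tiled.view(batch * num_images_per_prompt, seq_len, -1)  # (4, 3, 4)

# Same result as repeating along the batch dimension directly.
assert torch.equal(duplicated, embeds.repeat_interleave(num_images_per_prompt, dim=0))
```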
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeds = torch.cat([negative_prompt_embeds, image_embeds]) + + return image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + height, + width, + callback_steps, + noise_level, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two." + ) + + if prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + "Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and `negative_prompt_embeds` undefined." + ) + + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." 
+ ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if noise_level < 0 or noise_level >= self.noise_augmentor.max_noise_level: + raise ValueError( + f"`noise_level` must be between 0 and {self.noise_augmentor.max_noise_level - 1}, inclusive." + ) + + if image is not None and image_embeds is not None: + raise ValueError( + "Provide either `image` or `image_embeds`. Please make sure to define only one of the two." + ) + + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + + if image is not None: + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 20, + guidance_scale: float = 10, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + noise_level: int = 0, + image_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which + the unet will be conditioned on. 
Note that the image is _not_ encoded by the vae and then used as the
+                latents in the denoising process such as in the standard stable diffusion text guided image variation
+                process.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 20):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 10.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2 of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`, usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+                `guidance_scale` is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + noise_level (`int`, *optional*, defaults to `0`): + The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in + the final un-noised images. See `NoiseAugmentor` for how the noise is added. + image_embeds (`torch.FloatTensor`, *optional*): + Pre-generated CLIP embeddings to condition the unet on. Note that these are not latents to be used in + the denoising process. If you want to provide pre-generated latents, pass them to `__call__` as + `latents`. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput `] if `return_dict` + is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + image=image, + height=height, + width=width, + callback_steps=callback_steps, + noise_level=noise_level, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + image_embeds=image_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Encoder input image + noise_level = torch.tensor([noise_level], device=device) + image_embeds = self._encode_image( + image=image, + device=device, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + noise_level=noise_level, + generator=generator, + image_embeds=image_embeds, + ) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=image_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 806875a9a507..83b1e3f20378 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -172,7 +172,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately - summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`. + summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`. num_class_embeds (`int`, *optional*, defaults to None): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. @@ -183,7 +183,9 @@ class conditioning with `class_embed_type` equal to `None`. time_cond_proj_dim (`int`, *optional*, default to `None`): The dimension of `cond_proj` layer in timestep embedding. conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. - conv_out_kernel (`int`, *optional*, default to `3`): the Kernel size of `conv_out` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. 
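To make the new `projection` class embedding concrete: instead of a class index turned into a sinusoidal embedding, the UNet receives an arbitrary float vector as `class_labels` (here, the noise-augmented CLIP image embedding concatenated with its noise-level embedding) and projects it to the time embedding width. A simplified stand-in for the real `TimestepEmbedding`, with illustrative dimensions:

```py
import torch
from torch import nn

# Illustrative sizes: a 768-dim image embedding concatenated with a 768-dim
# noise-level embedding, projected to a hypothetical time_embed_dim of 1280.
projection_class_embeddings_input_dim = 768 * 2
time_embed_dim = 1280

# Simplified stand-in for diffusers' TimestepEmbedding (linear -> SiLU -> linear).
class_embedding = nn.Sequential(
    nn.Linear(projection_class_embeddings_input_dim, time_embed_dim),
    nn.SiLU(),
    nn.Linear(time_embed_dim, time_embed_dim),
)

class_labels = torch.randn(2, projection_class_embeddings_input_dim)  # noise-augmented image embeds
class_emb = class_embedding(class_labels)
print(class_emb.shape)  # torch.Size([2, 1280]); summed with the timestep embedding inside the UNet
```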
""" _supports_gradient_checkpointing = True @@ -231,6 +233,7 @@ def __init__( time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, ): super().__init__() @@ -276,6 +279,19 @@ def __init__( self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 9d8aa6fa5b2f..cd3884844bd0 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -119,6 +119,7 @@ def __init__( variance_type: str = "fixed_small", clip_sample: bool = True, prediction_type: str = "epsilon", + clip_sample_range: Optional[float] = 1.0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -284,7 +285,9 @@ def step( # 3. Clip "predicted x_0" if self.config.clip_sample: - pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + pred_original_sample = torch.clamp( + pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range + ) # 4. 
Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 546992bc436e..1f67ad597766 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -32,6 +32,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class NoiseAugmentor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class PriorTransformer(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 79755c27e6fe..3aa791283d20 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -227,6 +227,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableUnCLIPImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableUnCLIPPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class UnCLIPImageVariationPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/stable_unclip/__init__.py b/tests/pipelines/stable_unclip/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/stable_unclip/test_stable_unclip.py b/tests/pipelines/stable_unclip/test_stable_unclip.py new file mode 100644 index 000000000000..74099d646407 --- /dev/null +++ b/tests/pipelines/stable_unclip/test_stable_unclip.py @@ -0,0 +1,219 @@ +import gc +import unittest + +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DDPMScheduler, + NoiseAugmentor, + PriorTransformer, + StableUnCLIPPipeline, + UNet2DConditionModel, +) +from diffusers.utils.testing_utils import load_numpy, require_torch_gpu, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference + + +class StableUnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableUnCLIPPipeline + + # TODO(will) Expected attn_bias.stride(1) == 0 to be true, but got false + test_xformers_attention = False + + def get_dummy_components(self): + embedder_hidden_size = 32 + embedder_projection_dim = embedder_hidden_size + + # prior components + + torch.manual_seed(0) + prior_tokenizer = 
CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + torch.manual_seed(0) + prior_text_encoder = CLIPTextModelWithProjection( + CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=embedder_hidden_size, + projection_dim=embedder_projection_dim, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + ) + + torch.manual_seed(0) + prior = PriorTransformer( + num_attention_heads=2, + attention_head_dim=12, + embedding_dim=embedder_projection_dim, + num_layers=1, + ) + + torch.manual_seed(0) + prior_scheduler = DDPMScheduler( + variance_type="fixed_small_log", + prediction_type="sample", + num_train_timesteps=1000, + clip_sample=True, + clip_sample_range=5.0, + beta_schedule="squaredcos_cap_v2", + ) + + # regular denoising components + + torch.manual_seed(0) + noise_augmentor = NoiseAugmentor(embedding_dim=embedder_hidden_size, beta_schedule="squaredcos_cap_v2") + + torch.manual_seed(0) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + torch.manual_seed(0) + text_encoder = CLIPTextModel( + CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=embedder_hidden_size, + projection_dim=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + ) + + torch.manual_seed(0) + unet = UNet2DConditionModel( + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"), + up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels=(32, 64), + attention_head_dim=(2, 4), + class_embed_type="projection", + # The class embeddings are the noise augmented image embeddings. + # I.e. the image embeddings concated with the noised embeddings of the same dimension + projection_class_embeddings_input_dim=embedder_projection_dim * 2, + cross_attention_dim=embedder_hidden_size, + layers_per_block=1, + upcast_attention=True, + use_linear_projection=True, + ) + + torch.manual_seed(0) + scheduler = DDIMScheduler( + beta_schedule="scaled_linear", + beta_start=0.00085, + beta_end=0.012, + prediction_type="v_prediction", + set_alpha_to_one=False, + steps_offset=1, + ) + + torch.manual_seed(0) + vae = AutoencoderKL() + + components = { + # prior components + "prior_tokenizer": prior_tokenizer, + "prior_text_encoder": prior_text_encoder, + "prior": prior, + "prior_scheduler": prior_scheduler, + # regular denoising components + "noise_augmentor": noise_augmentor, + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "unet": unet, + "scheduler": scheduler, + "vae": vae, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "prior_num_inference_steps": 2, + "output_type": "numpy", + } + return inputs + + # Overriding PipelineTesterMixin::test_attention_slicing_forward_pass + # because UnCLIP GPU undeterminism requires a looser check. 
+ def test_attention_slicing_forward_pass(self): + test_max_difference = torch_device == "cpu" + + self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference) + + # Overriding PipelineTesterMixin::test_inference_batch_single_identical + # because UnCLIP undeterminism requires a looser check. + def test_inference_batch_single_identical(self): + test_max_difference = torch_device in ["cpu", "mps"] + + self._test_inference_batch_single_identical(test_max_difference=test_max_difference) + + +@slow +@require_torch_gpu +class StableUnCLIPPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_unclip(self): + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/stable_unclip_2_1_l_anime_turtle_fp16.npy" + ) + + pipe = StableUnCLIPPipeline.from_pretrained("fusing/stable-unclip-2-1-l", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + output = pipe("anime turle", generator=generator, output_type="np") + + image = output.images[0] + + assert image.shape == (768, 768, 3) + + assert_mean_pixel_difference(image, expected_image) + + def test_stable_unclip_pipeline_with_sequential_cpu_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + pipe = StableUnCLIPPipeline.from_pretrained("fusing/stable-unclip-2-1-l", torch_dtype=torch.float16) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + pipe.enable_sequential_cpu_offload() + + _ = pipe( + "anime turtle", + prior_num_inference_steps=2, + num_inference_steps=2, + output_type="np", + ) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 7 GB is allocated + assert mem_bytes < 7 * 10**9 diff --git a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py new file mode 100644 index 000000000000..39c988cc371a --- /dev/null +++ b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py @@ -0,0 +1,244 @@ +import gc +import random +import unittest + +import torch +from transformers import ( + CLIPFeatureExtractor, + CLIPTextConfig, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) + +from diffusers import AutoencoderKL, DDIMScheduler, NoiseAugmentor, StableUnCLIPImg2ImgPipeline, UNet2DConditionModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import floats_tensor, load_image, load_numpy, require_torch_gpu, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference + + +class StableUnCLIPImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableUnCLIPImg2ImgPipeline + + def get_dummy_components(self): + embedder_hidden_size = 32 + embedder_projection_dim = embedder_hidden_size + + # image encoding components + + feature_extractor = CLIPFeatureExtractor(crop_size=32, size=32) + + image_encoder = CLIPVisionModelWithProjection( + CLIPVisionConfig( + hidden_size=embedder_hidden_size, + projection_dim=embedder_projection_dim, + num_hidden_layers=5, + 
num_attention_heads=4, + image_size=32, + intermediate_size=37, + patch_size=1, + ) + ) + + # regular denoising components + + torch.manual_seed(0) + noise_augmentor = NoiseAugmentor(embedding_dim=embedder_hidden_size, beta_schedule="squaredcos_cap_v2") + + torch.manual_seed(0) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + torch.manual_seed(0) + text_encoder = CLIPTextModel( + CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=embedder_hidden_size, + projection_dim=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + ) + + torch.manual_seed(0) + unet = UNet2DConditionModel( + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"), + up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels=(32, 64), + attention_head_dim=(2, 4), + class_embed_type="projection", + # The class embeddings are the noise augmented image embeddings. + # I.e. the image embeddings concated with the noised embeddings of the same dimension + projection_class_embeddings_input_dim=embedder_projection_dim * 2, + cross_attention_dim=embedder_hidden_size, + layers_per_block=1, + upcast_attention=True, + use_linear_projection=True, + ) + + torch.manual_seed(0) + scheduler = DDIMScheduler( + beta_schedule="scaled_linear", + beta_start=0.00085, + beta_end=0.012, + prediction_type="v_prediction", + set_alpha_to_one=False, + steps_offset=1, + ) + + torch.manual_seed(0) + vae = AutoencoderKL() + + components = { + # image encoding components + "feature_extractor": feature_extractor, + "image_encoder": image_encoder, + # regular denoising components + "noise_augmentor": noise_augmentor, + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "unet": unet, + "scheduler": scheduler, + "vae": vae, + } + + return components + + def get_dummy_inputs(self, device, seed=0, pil_image=True): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + input_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + + if pil_image: + input_image = input_image * 0.5 + 0.5 + input_image = input_image.clamp(0, 1) + input_image = input_image.cpu().permute(0, 2, 3, 1).float().numpy() + input_image = DiffusionPipeline.numpy_to_pil(input_image)[0] + + return { + "prompt": "An anime racoon running a marathon", + "image": input_image, + "generator": generator, + "num_inference_steps": 2, + "output_type": "np", + } + + # Overriding PipelineTesterMixin::test_attention_slicing_forward_pass + # because GPU undeterminism requires a looser check. + def test_attention_slicing_forward_pass(self): + test_max_difference = torch_device in ["cpu", "mps"] + + self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference) + + # Overriding PipelineTesterMixin::test_inference_batch_single_identical + # because undeterminism requires a looser check. 
+ def test_inference_batch_single_identical(self): + test_max_difference = torch_device in ["cpu", "mps"] + + self._test_inference_batch_single_identical(test_max_difference=test_max_difference) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(test_max_difference=False) + + +@slow +@require_torch_gpu +class StableUnCLIPImg2ImgPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_unclip_l_img2img(self): + input_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/turtle.png" + ) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/stable_unclip_2_1_l_img2img_anime_turtle_fp16.npy" + ) + + pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( + "fusing/stable-unclip-2-1-l-img2img", torch_dtype=torch.float16 + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + output = pipe("anime turle", image=input_image, generator=generator, output_type="np") + + image = output.images[0] + + assert image.shape == (768, 768, 3) + + assert_mean_pixel_difference(image, expected_image) + + def test_stable_unclip_h_img2img(self): + input_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/turtle.png" + ) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/stable_unclip_2_1_h_img2img_anime_turtle_fp16.npy" + ) + + pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( + "fusing/stable-unclip-2-1-h-img2img", torch_dtype=torch.float16 + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + output = pipe("anime turle", image=input_image, generator=generator, output_type="np") + + image = output.images[0] + + assert image.shape == (768, 768, 3) + + assert_mean_pixel_difference(image, expected_image) + + def test_stable_unclip_img2img_pipeline_with_sequential_cpu_offloading(self): + input_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/turtle.png" + ) + + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( + "fusing/stable-unclip-2-1-h-img2img", torch_dtype=torch.float16 + ) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + pipe.enable_sequential_cpu_offload() + + _ = pipe( + "anime turtle", + image=input_image, + num_inference_steps=2, + output_type="np", + ) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 7 GB is allocated + assert mem_bytes < 7 * 10**9 diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index a1d3122f875c..fef0fe6fa624 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -517,6 +517,9 @@ def test_cpu_offload_forward_pass(self): reason="XFormers attention is only available 
with CUDA and `xformers` installed", ) def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass() + + def _test_xformers_attention_forwardGenerator_pass(self, test_max_difference=True): if not self.test_xformers_attention: return @@ -532,8 +535,11 @@ def test_xformers_attention_forwardGenerator_pass(self): inputs = self.get_dummy_inputs(torch_device) output_with_offload = pipe(**inputs)[0] - max_diff = np.abs(output_with_offload - output_without_offload).max() - self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results") + if test_max_difference: + max_diff = np.abs(output_with_offload - output_without_offload).max() + self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results") + + assert_mean_pixel_difference(output_with_offload[0], output_without_offload[0]) def test_progress_bar(self): components = self.get_dummy_components() From f203e1abb6d719fb0fff9af00d326d0eebc271ec Mon Sep 17 00:00:00 2001 From: William Berman Date: Wed, 8 Feb 2023 19:31:24 -0800 Subject: [PATCH 02/14] Add docs for when clip_stats_path is specified --- scripts/convert_original_stable_diffusion_to_diffusers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index 255ce9411394..890ad870812d 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -117,7 +117,7 @@ parser.add_argument( "--clip_stats_path", type=str, - help="Path to the clip stats file. Only required with certain stable unclip models.", + help="Path to the clip stats file. Only required if the stable unclip model's config specifies `model.params.noise_aug_config.params.clip_stats_path`.", required=False, ) args = parser.parse_args() From b25971d0c1523abcbb6aadc3bb7988c7dfa495fe Mon Sep 17 00:00:00 2001 From: Will Berman Date: Thu, 9 Feb 2023 12:32:04 -0800 Subject: [PATCH 03/14] Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py Co-authored-by: Patrick von Platen --- .../pipelines/stable_diffusion/pipeline_stable_unclip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index 9365938b439b..25c54c90f002 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -636,7 +636,7 @@ def __call__( Examples: Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput `] if `return_dict` + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ # 0. 
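On the looser comparison introduced above: a mean pixel difference tolerates the per-pixel nondeterminism the test comments mention, whereas a strict max-difference check would fail on a handful of stray pixels. The snippet below is only a rough illustration of the idea, with a made-up tolerance, not the actual `assert_mean_pixel_difference` helper:

```py
import numpy as np

def mean_pixel_difference_ok(image, expected_image, tolerance=10.0):
    # Both inputs are (H, W, 3) arrays with values in [0, 255].
    avg_diff = np.abs(image.astype(np.float32) - expected_image.astype(np.float32)).mean()
    return avg_diff < tolerance

a = np.zeros((8, 8, 3), dtype=np.uint8)
b = a.copy()
b[0, 0] = 255  # one wildly different pixel barely moves the mean (~3.98 here)
print(mean_pixel_difference_ok(a, b))  # True, even though the max difference is 255
```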
Default height and width to unet From 665e8181a365dbba4dbfc271be2809a3ab718db2 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Thu, 9 Feb 2023 12:32:13 -0800 Subject: [PATCH 04/14] Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py Co-authored-by: Patrick von Platen --- .../pipelines/stable_diffusion/pipeline_stable_unclip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index 25c54c90f002..24e6569b7f55 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -637,7 +637,7 @@ def __call__( Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` - is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor From ef80c29b25af81b61e70d91649ba6f36efc7521d Mon Sep 17 00:00:00 2001 From: Will Berman Date: Thu, 9 Feb 2023 12:33:57 -0800 Subject: [PATCH 05/14] Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py Co-authored-by: Patrick von Platen --- .../stable_diffusion/pipeline_stable_unclip_img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index f2e368eb4e9d..105a183626a1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -607,7 +607,7 @@ def __call__( Examples: Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput `] if `return_dict` + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ # 0. Default height and width to unet From 338f299fd53b2af0c5771ebd111fbc596ca2f56f Mon Sep 17 00:00:00 2001 From: Will Berman Date: Thu, 9 Feb 2023 12:34:12 -0800 Subject: [PATCH 06/14] Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py Co-authored-by: Patrick von Platen --- .../stable_diffusion/pipeline_stable_unclip_img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index 105a183626a1..ce5ed0381738 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -608,7 +608,7 @@ def __call__( Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` - is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ # 0. 
Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor From a1df31ae25310f47e57bfb6b69d63df152af30bd Mon Sep 17 00:00:00 2001 From: William Berman Date: Thu, 9 Feb 2023 12:44:10 -0800 Subject: [PATCH 07/14] prepare_latents # Copied from re: @patrickvonplaten --- .../pipelines/stable_diffusion/pipeline_stable_unclip.py | 6 +++--- .../stable_diffusion/pipeline_stable_unclip_img2img.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index 24e6569b7f55..38a80a6e923d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -516,6 +516,7 @@ def check_inputs( f"`noise_level` must be between 0 and {self.noise_augmentor.max_noise_level - 1}, inclusive." ) + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) @@ -524,7 +525,6 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) - # scale the initial noise by the standard deviation required by the scheduler latents = latents * scheduler.init_noise_sigma return latents @@ -636,8 +636,8 @@ def __call__( Examples: Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` - is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index ce5ed0381738..59d50d34b114 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -607,8 +607,8 @@ def __call__( Examples: Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` - is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ # 0. 
Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor From ee5f35e33fbbc3aa4ecac20a9fc92f4a00d64418 Mon Sep 17 00:00:00 2001 From: William Berman Date: Thu, 9 Feb 2023 14:41:07 -0800 Subject: [PATCH 08/14] NoiseAugmentor->ImageNormalizer --- ..._original_stable_diffusion_to_diffusers.py | 2 +- src/diffusers/__init__.py | 1 - src/diffusers/models/__init__.py | 1 - src/diffusers/models/noise_augmentor.py | 167 ------------------ .../pipelines/stable_diffusion/__init__.py | 1 + .../stable_diffusion/convert_from_ckpt.py | 55 +++--- .../pipeline_stable_unclip.py | 86 +++++++-- .../pipeline_stable_unclip_img2img.py | 88 +++++++-- .../stable_unclip_image_normalizer.py | 46 +++++ src/diffusers/utils/dummy_pt_objects.py | 15 -- .../stable_unclip/test_stable_unclip.py | 9 +- .../test_stable_unclip_img2img.py | 10 +- 12 files changed, 241 insertions(+), 240 deletions(-) delete mode 100644 src/diffusers/models/noise_augmentor.py create mode 100644 src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index 890ad870812d..ed82d08c971e 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -135,7 +135,7 @@ from_safetensors=args.from_safetensors, device=args.device, stable_unclip=args.stable_unclip, - stable_unclip_prior=args.stabe_unclip_prior, + stable_unclip_prior=args.stable_unclip_prior, clip_stats_path=args.clip_stats_path, ) pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 80c0caef97af..76b03a6d019d 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -35,7 +35,6 @@ from .models import ( AutoencoderKL, ModelMixin, - NoiseAugmentor, PriorTransformer, Transformer2DModel, UNet1DModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 29c64652c00d..474b8412560e 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -19,7 +19,6 @@ from .autoencoder_kl import AutoencoderKL from .dual_transformer_2d import DualTransformer2DModel from .modeling_utils import ModelMixin - from .noise_augmentor import NoiseAugmentor from .prior_transformer import PriorTransformer from .transformer_2d import Transformer2DModel from .unet_1d import UNet1DModel diff --git a/src/diffusers/models/noise_augmentor.py b/src/diffusers/models/noise_augmentor.py deleted file mode 100644 index 0db7217070e3..000000000000 --- a/src/diffusers/models/noise_augmentor.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import math -from typing import List, Optional, Union - -import numpy as np -import torch -from torch import nn - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import randn_tensor -from .embeddings import Timesteps -from .modeling_utils import ModelMixin - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class NoiseAugmentor(ModelMixin, ConfigMixin): - """ - NoiseAugmentor is used in the context of stable unclip to add noise to the image embeddings. The amount of noise is - controlled by a `noise_level` input. A higher `noise_level` increases the variance in the final un-noised images. - - The noise is applied in two ways - 1. A noise schedule is applied directly to the embeddings - 2. A vector of sinusoidal time embeddings are appended to the output. - - In both cases, the amount of noise is controlled by the same `noise_level`. - - The embeddings are normalized before the noise is applied and un-normalized after the noise is applied, so the - layer requires parameters for the mean and standard deviation of the embeddings. - """ - - @register_to_config - def __init__( - self, - embedding_dim: int = 768, - max_noise_level: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - set_alpha_to_one: bool = True, - ): - super().__init__() - self.max_noise_level = max_noise_level - - self.mean = nn.Parameter(torch.zeros(1, embedding_dim)) - self.std = nn.Parameter(torch.ones(1, embedding_dim)) - - self.time_embed = Timesteps(embedding_dim, True, 0) - - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, max_noise_level, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. 
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, max_noise_level, dtype=torch.float32) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(max_noise_level) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - # At every step in ddim, we are looking into the previous alphas_cumprod - # For the final step, there is no previous alphas_cumprod because we are already at 0 - # `set_alpha_to_one` decides whether we set this parameter simply to one or - # whether we use the final alpha of the "non-previous" one. - self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] - - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.IntTensor, - ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples - - def forward( - self, - embeds: torch.Tensor, - noise_level: int, - noise: Optional[torch.FloatTensor] = None, - generator: Optional[torch.Generator] = None, - ): - if noise is None: - noise = randn_tensor(embeds.shape, generator=generator, device=self.device, dtype=embeds.dtype) - - noise_level = torch.tensor([noise_level] * embeds.shape[0], device=embeds.device) - - # scale - embeds = (embeds - self.mean) * 1.0 / self.std - - embeds = self.add_noise(original_samples=embeds, timesteps=noise_level, noise=noise) - - # unscale - embeds = (embeds * self.std) + self.mean - - noise_level = self.time_embed(noise_level) - - # `self.time_embed` does not contain any weights and will always return f32 tensors, - # but we might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. 
- noise_level = noise_level.to(self.dtype) - - combined = torch.cat((embeds, noise_level), 1) - - return combined diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 7cbf646939db..205b5386a7c8 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -48,6 +48,7 @@ class StableDiffusionPipelineOutput(BaseOutput): from .pipeline_stable_unclip import StableUnCLIPPipeline from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline from .safety_checker import StableDiffusionSafetyChecker + from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer try: if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 9c620bcbf959..2640f61d42e3 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -42,7 +42,6 @@ HeunDiscreteScheduler, LDMTextToImagePipeline, LMSDiscreteScheduler, - NoiseAugmentor, PNDMScheduler, PriorTransformer, StableDiffusionPipeline, @@ -54,6 +53,7 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker +from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer from ...utils import is_omegaconf_available, is_safetensors_available, logging from ...utils.import_utils import BACKENDS_MAPPING @@ -852,11 +852,15 @@ def stable_unclip_image_encoder(original_config): return feature_extractor, image_encoder -def stable_unclip_noise_augmentor( +def stable_unclip_image_noising_components( original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None ): """ - Returns the noise augmentor for the img2img and txt2img unclip pipelines. + Returns the noising components for the img2img and txt2img unclip pipelines. + + Converts the stability noise augmentor into + 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats + 2. a `DDPMScheduler` for holding the noise schedule If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. 
""" @@ -870,9 +874,8 @@ def stable_unclip_noise_augmentor( max_noise_level = noise_aug_config.noise_schedule_config.timesteps beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule - noise_augmentor = NoiseAugmentor( - embedding_dim=embedding_dim, max_noise_level=max_noise_level, beta_schedule=beta_schedule - ) + image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) + image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) if "clip_stats_path" in noise_aug_config: if clip_stats_path is None: @@ -887,11 +890,11 @@ def stable_unclip_noise_augmentor( "std": clip_std, } - noise_augmentor.load_state_dict(clip_stats_state_dict) + image_normalizer.load_state_dict(clip_stats_state_dict) else: raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") - return noise_augmentor + return image_normalizer, image_noising_scheduler def load_pipeline_from_original_stable_diffusion_ckpt( @@ -1098,7 +1101,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt( requires_safety_checker=False, ) else: - noise_augmentor = stable_unclip_noise_augmentor( + image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( original_config, clip_stats_path=clip_stats_path, device=device ) @@ -1106,14 +1109,19 @@ def load_pipeline_from_original_stable_diffusion_ckpt( feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) pipe = StableUnCLIPImg2ImgPipeline( - vae=vae, - unet=unet, - text_encoder=text_model, + # image encoding components + feature_extractor=feature_extractor, + image_encoder=image_encoder, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components tokenizer=tokenizer, + text_encoder=text_model, + unet=unet, scheduler=scheduler, - image_encoder=image_encoder, - feature_extractor=feature_extractor, - noise_augmentor=noise_augmentor, + # vae + vae=vae, ) elif stable_unclip == "txt2img": if stable_unclip_prior == "karlo": @@ -1129,16 +1137,21 @@ def load_pipeline_from_original_stable_diffusion_ckpt( raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}") pipe = StableUnCLIPPipeline( - vae=vae, - unet=unet, - tokenizer=tokenizer, - text_encoder=text_model, + # prior components prior_tokenizer=prior_tokenizer, prior_text_encoder=prior_text_model, - scheduler=scheduler, - noise_augmentor=noise_augmentor, prior=prior, prior_scheduler=prior_scheduler, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components + tokenizer=tokenizer, + text_encoder=text_model, + unet=unet, + scheduler=scheduler, + # vae + vae=vae, ) else: raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index 38a80a6e923d..7236587de820 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -19,10 +19,12 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers.models.clip.modeling_clip import CLIPTextModelOutput -from ...models import AutoencoderKL, NoiseAugmentor, PriorTransformer, UNet2DConditionModel +from ...models import AutoencoderKL, PriorTransformer, 
UNet2DConditionModel +from ...models.embeddings import get_timestep_embedding from ...schedulers import DDIMScheduler, DDPMScheduler from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -62,10 +64,12 @@ class StableUnCLIPPipeline(DiffusionPipeline): The canonincal unCLIP prior to approximate the image embedding from the text embedding. prior_scheduler ([`DDPMScheduler`]): Scheduler used in the prior denoising process. - noise_augmentor ([`NoiseAugmentor`]): - Layer for adding noise to the predicted image embeddings. The amount of noise to add is determined by - `noise_level` in `StableUnCLIPPipeline.__call__`. See `NoiseAugmentor` for more details on how noise is - added. + image_normalizer ([`StableUnCLIPImageNormalizer`]): + Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image + embeddings after the noise has been applied. + image_noising_scheduler ([`DDPMScheduler`]): + Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined + by `noise_level` in `StableUnCLIPPipeline.__call__`. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). @@ -84,8 +88,11 @@ class StableUnCLIPPipeline(DiffusionPipeline): prior: PriorTransformer prior_scheduler: DDPMScheduler + # image noising components + image_normalizer: StableUnCLIPImageNormalizer + image_noising_scheduler: DDPMScheduler + # regular denoising components - noise_augmentor: NoiseAugmentor tokenizer: CLIPTokenizer text_encoder: CLIPTextModel unet: UNet2DConditionModel @@ -100,8 +107,10 @@ def __init__( prior_text_encoder: CLIPTextModelWithProjection, prior: PriorTransformer, prior_scheduler: DDPMScheduler, + # image noising components + image_normalizer: StableUnCLIPImageNormalizer, + image_noising_scheduler: DDPMScheduler, # regular denoising components - noise_augmentor: NoiseAugmentor, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModelWithProjection, unet: UNet2DConditionModel, @@ -116,7 +125,8 @@ def __init__( prior_text_encoder=prior_text_encoder, prior=prior, prior_scheduler=prior_scheduler, - noise_augmentor=noise_augmentor, + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, @@ -157,10 +167,9 @@ def enable_sequential_cpu_offload(self, gpu_id=0): device = torch.device(f"cuda:{gpu_id}") - # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list + # TODO: self.prior.post_process_latents and self.image_noiser.{scale,unscale} are not covered by the offload hooks, so they fails if added to the list models = [ self.prior_text_encoder, - self.noise_augmentor, self.text_encoder, self.unet, self.vae, @@ -511,9 +520,9 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - if noise_level < 0 or noise_level >= self.noise_augmentor.max_noise_level: + if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: raise ValueError( - f"`noise_level` must be between 0 and {self.noise_augmentor.max_noise_level - 1}, inclusive." 
+ f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive." ) # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents @@ -528,6 +537,51 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): latents = latents * scheduler.init_noise_sigma return latents + def noise_image_embeddings( + self, + image_embeds: torch.Tensor, + noise_level: int, + noise: Optional[torch.FloatTensor] = None, + generator: Optional[torch.Generator] = None, + ): + """ + Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher + `noise_level` increases the variance in the final un-noised images. + + The noise is applied in two ways + 1. A noise schedule is applied directly to the embeddings + 2. A vector of sinusoidal time embeddings are appended to the output. + + In both cases, the amount of noise is controlled by the same `noise_level`. + + The embeddings are normalized before the noise is applied and un-normalized after the noise is applied. + """ + if noise is None: + noise = randn_tensor( + image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype + ) + + noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device) + + image_embeds = self.image_normalizer.scale(image_embeds) + + image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise) + + image_embeds = self.image_normalizer.unscale(image_embeds) + + noise_level = get_timestep_embedding( + timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0 + ) + + # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors, + # but we might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + noise_level = noise_level.to(image_embeds.dtype) + + image_embeds = torch.cat((image_embeds, noise_level), 1) + + return image_embeds + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -617,7 +671,7 @@ def __call__( [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). noise_level (`int`, *optional*, defaults to `0`): The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in - the final un-noised images. See `NoiseAugmentor` for how the noise is added. + the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details. prior_num_inference_steps (`int`, *optional*, defaults to 25): The number of denoising steps in the prior denoising process. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -747,7 +801,11 @@ def __call__( ) # 8. 
Prepare image embeddings - image_embeds = self.noise_augmentor(image_embeds, noise_level=noise_level, generator=generator) + image_embeds = self.noise_image_embeddings( + image_embeds=image_embeds, + noise_level=noise_level, + generator=generator, + ) if do_classifier_free_guidance: negative_prompt_embeds = torch.zeros_like(image_embeds) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index 59d50d34b114..1d7a10c914c7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -21,10 +21,12 @@ from diffusers.utils.import_utils import is_accelerate_available -from ...models import AutoencoderKL, NoiseAugmentor, UNet2DConditionModel -from ...schedulers import DDIMScheduler +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.embeddings import get_timestep_embedding +from ...schedulers import DDIMScheduler, DDPMScheduler from ...utils import logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -70,10 +72,12 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline): Feature extractor for image pre-processing before being encoded. image_encoder ([`CLIPVisionModelWithProjection`]): CLIP vision model for encoding images. - noise_augmentor ([`NoiseAugmentor`]): - Layer for adding noise to the predicted image embeddings. The amount of noise to add is determined by - `noise_level` in `StableUnCLIPPipeline.__call__`. See `NoiseAugmentor` for more details on how noise is - added. + image_normalizer ([`StableUnCLIPImageNormalizer`]): + Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image + embeddings after the noise has been applied. + image_noising_scheduler ([`DDPMScheduler`]): + Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined + by `noise_level` in `StableUnCLIPPipeline.__call__`. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 
@@ -90,8 +94,11 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline): feature_extractor: CLIPFeatureExtractor image_encoder: CLIPVisionModelWithProjection + # image noising components + image_normalizer: StableUnCLIPImageNormalizer + image_noising_scheduler: DDPMScheduler + # regular denoising components - noise_augmentor: NoiseAugmentor tokenizer: CLIPTokenizer text_encoder: CLIPTextModel unet: UNet2DConditionModel @@ -104,8 +111,10 @@ def __init__( # image encoding components feature_extractor: CLIPFeatureExtractor, image_encoder: CLIPVisionModelWithProjection, + # image noising components + image_normalizer: StableUnCLIPImageNormalizer, + image_noising_scheduler: DDPMScheduler, # regular denoising components - noise_augmentor: NoiseAugmentor, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, unet: UNet2DConditionModel, @@ -118,7 +127,8 @@ def __init__( self.register_modules( feature_extractor=feature_extractor, image_encoder=image_encoder, - noise_augmentor=noise_augmentor, + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, @@ -159,9 +169,9 @@ def enable_sequential_cpu_offload(self, gpu_id=0): device = torch.device(f"cuda:{gpu_id}") + # TODO: self.image_noiser.{scale,unscale} are not covered by the offload hooks, so they fails if added to the list models = [ self.image_encoder, - self.noise_augmentor, self.text_encoder, self.unet, self.vae, @@ -360,7 +370,11 @@ def _encode_image( image = image.to(device=device, dtype=dtype) image_embeds = self.image_encoder(image).image_embeds - image_embeds = self.noise_augmentor(image_embeds, noise_level=noise_level, generator=generator) + image_embeds = self.noise_image_embeddings( + image_embeds=image_embeds, + noise_level=noise_level, + generator=generator, + ) # duplicate image embeddings for each generation per prompt, using mps friendly method image_embeds = image_embeds.unsqueeze(1) @@ -463,9 +477,9 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - if noise_level < 0 or noise_level >= self.noise_augmentor.max_noise_level: + if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: raise ValueError( - f"`noise_level` must be between 0 and {self.noise_augmentor.max_noise_level - 1}, inclusive." + f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive." ) if image is not None and image_embeds is not None: @@ -507,6 +521,52 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings + def noise_image_embeddings( + self, + image_embeds: torch.Tensor, + noise_level: int, + noise: Optional[torch.FloatTensor] = None, + generator: Optional[torch.Generator] = None, + ): + """ + Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher + `noise_level` increases the variance in the final un-noised images. + + The noise is applied in two ways + 1. A noise schedule is applied directly to the embeddings + 2. A vector of sinusoidal time embeddings are appended to the output. + + In both cases, the amount of noise is controlled by the same `noise_level`. + + The embeddings are normalized before the noise is applied and un-normalized after the noise is applied. 
+ """ + if noise is None: + noise = randn_tensor( + image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype + ) + + noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device) + + image_embeds = self.image_normalizer.scale(image_embeds) + + image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise) + + image_embeds = self.image_normalizer.unscale(image_embeds) + + noise_level = get_timestep_embedding( + timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0 + ) + + # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors, + # but we might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + noise_level = noise_level.to(image_embeds.dtype) + + image_embeds = torch.cat((image_embeds, noise_level), 1) + + return image_embeds + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -598,7 +658,7 @@ def __call__( [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). noise_level (`int`, *optional*, defaults to `0`): The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in - the final un-noised images. See `NoiseAugmentor` for how the noise is added. + the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details. image_embeds (`torch.FloatTensor`, *optional*): Pre-generated CLIP embeddings to condition the unet on. Note that these are not latents to be used in the denoising process. If you want to provide pre-generated latents, pass them to `__call__` as diff --git a/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py b/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py new file mode 100644 index 000000000000..c7803da70df8 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py @@ -0,0 +1,46 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin): + """ + This class is used to hold the mean and standard deviation of the CLIP embedder used in stable unCLIP. + + It is used to normalize the image embeddings before the noise is applied and un-normalize the noised image + embeddings. 
+ """ + + @register_to_config + def __init__( + self, + embedding_dim: int = 768, + ): + super().__init__() + + self.mean = nn.Parameter(torch.zeros(1, embedding_dim)) + self.std = nn.Parameter(torch.ones(1, embedding_dim)) + + def scale(self, embeds): + embeds = (embeds - self.mean) * 1.0 / self.std + return embeds + + def unscale(self, embeds): + embeds = (embeds * self.std) + self.mean + return embeds diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 1f67ad597766..546992bc436e 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -32,21 +32,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class NoiseAugmentor(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - class PriorTransformer(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/pipelines/stable_unclip/test_stable_unclip.py b/tests/pipelines/stable_unclip/test_stable_unclip.py index 74099d646407..7bc351ad76a9 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip.py @@ -8,11 +8,11 @@ AutoencoderKL, DDIMScheduler, DDPMScheduler, - NoiseAugmentor, PriorTransformer, StableUnCLIPPipeline, UNet2DConditionModel, ) +from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer from diffusers.utils.testing_utils import load_numpy, require_torch_gpu, slow, torch_device from ...test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference @@ -70,7 +70,8 @@ def get_dummy_components(self): # regular denoising components torch.manual_seed(0) - noise_augmentor = NoiseAugmentor(embedding_dim=embedder_hidden_size, beta_schedule="squaredcos_cap_v2") + image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedder_hidden_size) + image_noising_scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2") torch.manual_seed(0) tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") @@ -129,8 +130,10 @@ def get_dummy_components(self): "prior_text_encoder": prior_text_encoder, "prior": prior, "prior_scheduler": prior_scheduler, + # image noising components + "image_normalizer": image_normalizer, + "image_noising_scheduler": image_noising_scheduler, # regular denoising components - "noise_augmentor": noise_augmentor, "tokenizer": tokenizer, "text_encoder": text_encoder, "unet": unet, diff --git a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py index 39c988cc371a..adbf3b272706 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py @@ -12,8 +12,9 @@ CLIPVisionModelWithProjection, ) -from diffusers import AutoencoderKL, DDIMScheduler, NoiseAugmentor, StableUnCLIPImg2ImgPipeline, UNet2DConditionModel +from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableUnCLIPImg2ImgPipeline, UNet2DConditionModel from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer from diffusers.utils.import_utils import 
is_xformers_available from diffusers.utils.testing_utils import floats_tensor, load_image, load_numpy, require_torch_gpu, slow, torch_device @@ -46,7 +47,8 @@ def get_dummy_components(self): # regular denoising components torch.manual_seed(0) - noise_augmentor = NoiseAugmentor(embedding_dim=embedder_hidden_size, beta_schedule="squaredcos_cap_v2") + image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedder_hidden_size) + image_noising_scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2") torch.manual_seed(0) tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") @@ -103,8 +105,10 @@ def get_dummy_components(self): # image encoding components "feature_extractor": feature_extractor, "image_encoder": image_encoder, + # image noising components + "image_normalizer": image_normalizer, + "image_noising_scheduler": image_noising_scheduler, # regular denoising components - "noise_augmentor": noise_augmentor, "tokenizer": tokenizer, "text_encoder": text_encoder, "unet": unet, From e3458b0b1a8b63e23cdf49e07efe483a7dd6cb91 Mon Sep 17 00:00:00 2001 From: William Berman Date: Mon, 13 Feb 2023 16:32:23 -0800 Subject: [PATCH 09/14] stable_unclip_prior default to None re: @patrickvonplaten --- scripts/convert_original_stable_diffusion_to_diffusers.py | 4 ++-- src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index ed82d08c971e..11e35211b242 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -110,9 +110,9 @@ parser.add_argument( "--stable_unclip_prior", type=str, - default="karlo", + default=None, required=False, - help="Set if this is a stable unCLIP txt2img model. Selects which prior to use. Defaults to 'karlo'.", + help="Set if this is a stable unCLIP txt2img model. Selects which prior to use. 
If `--stable_unclip` is set to `txt2img`, the karlo prior (https://huggingface.co/kakaobrain/karlo-v1-alpha/tree/main/prior) is selected by default.", ) parser.add_argument( "--clip_stats_path", diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index ec90ce45acf1..e7fd1f4d6828 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -910,7 +910,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt( device: str = None, from_safetensors: bool = False, stable_unclip: Optional[str] = None, - stable_unclip_prior: str = "karlo", + stable_unclip_prior: Optional[str] = None, clip_stats_path: Optional[str] = None, ) -> StableDiffusionPipeline: """ @@ -1132,7 +1132,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt( vae=vae, ) elif stable_unclip == "txt2img": - if stable_unclip_prior == "karlo": + if stable_unclip_prior is None or stable_unclip_prior == "karlo": karlo_model = "kakaobrain/karlo-v1-alpha" prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior") From 4d62cf0a28e5477c170d3e1b6d2af7689b8c7e04 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Tue, 14 Feb 2023 07:28:50 +0000 Subject: [PATCH 10/14] prepare_prior_extra_step_kwargs --- .../pipeline_stable_unclip.py | 41 ++++++++++++++----- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index 7236587de820..01bddc3e07e2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -447,6 +447,24 @@ def decode_latents(self, latents): image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with prepare_extra_step_kwargs->prepare_prior_extra_step_kwargs, scheduler->prior_scheduler + def prepare_prior_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the prior_scheduler step, since not all prior_schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other prior_schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.prior_scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the prior_scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.prior_scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -749,7 +767,10 @@ def __call__( self.prior_scheduler, ) - # 6. Prior denoising loop + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + prior_extra_step_kwargs = self.prepare_prior_extra_step_kwargs(generator, eta) + + # 7. 
Prior denoising loop for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([prior_latents] * 2) if prior_do_classifier_free_guidance else prior_latents @@ -772,7 +793,7 @@ def __call__( predicted_image_embedding, timestep=t, sample=prior_latents, - generator=generator, + **prior_extra_step_kwargs, ).prev_sample if callback is not None and i % callback_steps == 0: @@ -789,7 +810,7 @@ def __call__( # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - # 7. Encode input prompt + # 8. Encode input prompt prompt_embeds = self._encode_prompt( prompt=prompt, device=device, @@ -800,7 +821,7 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) - # 8. Prepare image embeddings + # 9. Prepare image embeddings image_embeds = self.noise_image_embeddings( image_embeds=image_embeds, noise_level=noise_level, @@ -815,11 +836,11 @@ def __call__( # to avoid doing two forward passes image_embeds = torch.cat([negative_prompt_embeds, image_embeds]) - # 9. Prepare timesteps + # 10. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps - # 10. Prepare latent variables + # 11. Prepare latent variables num_channels_latents = self.unet.in_channels shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) latents = self.prepare_latents( @@ -831,10 +852,10 @@ def __call__( scheduler=self.scheduler, ) - # 11. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + # 12. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 12. Denoising loop + # 13. Denoising loop for i, t in enumerate(self.progress_bar(timesteps)): latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -859,10 +880,10 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 13. Post-processing + # 14. Post-processing image = self.decode_latents(latents) - # 14. Convert to PIL + # 15. 
Convert to PIL if output_type == "pil": image = self.numpy_to_pil(image) From f35b2e8d7492ba28cb1dcddc83044ca9b73c7e2e Mon Sep 17 00:00:00 2001 From: Will Berman Date: Tue, 14 Feb 2023 08:03:05 +0000 Subject: [PATCH 11/14] prior denoising scale model input --- .../pipelines/stable_diffusion/pipeline_stable_unclip.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index 01bddc3e07e2..179ce1d05db2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -774,6 +774,7 @@ def __call__( for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([prior_latents] * 2) if prior_do_classifier_free_guidance else prior_latents + latent_model_input = self.prior_scheduler.scale_model_input(latent_model_input, t) predicted_image_embedding = self.prior( latent_model_input, From c872cbc65812d89babf448d8b2cde2ecdc36708b Mon Sep 17 00:00:00 2001 From: William Berman Date: Tue, 14 Feb 2023 00:22:53 -0800 Subject: [PATCH 12/14] {DDIM,DDPM}Scheduler -> KarrasDiffusionSchedulers re: @patrickvonplaten --- .../pipeline_stable_unclip.py | 20 +++++++++---------- .../pipeline_stable_unclip_img2img.py | 14 ++++++------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index 179ce1d05db2..954f088958a5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -21,7 +21,7 @@ from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding -from ...schedulers import DDIMScheduler, DDPMScheduler +from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer @@ -62,12 +62,12 @@ class StableUnCLIPPipeline(DiffusionPipeline): Frozen text-encoder. prior ([`PriorTransformer`]): The canonincal unCLIP prior to approximate the image embedding from the text embedding. - prior_scheduler ([`DDPMScheduler`]): + prior_scheduler ([`KarrasDiffusionSchedulers`]): Scheduler used in the prior denoising process. image_normalizer ([`StableUnCLIPImageNormalizer`]): Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image embeddings after the noise has been applied. - image_noising_scheduler ([`DDPMScheduler`]): + image_noising_scheduler ([`KarrasDiffusionSchedulers`]): Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined by `noise_level` in `StableUnCLIPPipeline.__call__`. tokenizer (`CLIPTokenizer`): @@ -76,7 +76,7 @@ class StableUnCLIPPipeline(DiffusionPipeline): text_encoder ([`CLIPTextModel`]): Frozen text-encoder. unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`DDIMScheduler`]): + scheduler ([`KarrasDiffusionSchedulers`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. @@ -86,17 +86,17 @@ class StableUnCLIPPipeline(DiffusionPipeline): prior_tokenizer: CLIPTokenizer prior_text_encoder: CLIPTextModelWithProjection prior: PriorTransformer - prior_scheduler: DDPMScheduler + prior_scheduler: KarrasDiffusionSchedulers # image noising components image_normalizer: StableUnCLIPImageNormalizer - image_noising_scheduler: DDPMScheduler + image_noising_scheduler: KarrasDiffusionSchedulers # regular denoising components tokenizer: CLIPTokenizer text_encoder: CLIPTextModel unet: UNet2DConditionModel - scheduler: DDIMScheduler + scheduler: KarrasDiffusionSchedulers vae: AutoencoderKL @@ -106,15 +106,15 @@ def __init__( prior_tokenizer: CLIPTokenizer, prior_text_encoder: CLIPTextModelWithProjection, prior: PriorTransformer, - prior_scheduler: DDPMScheduler, + prior_scheduler: KarrasDiffusionSchedulers, # image noising components image_normalizer: StableUnCLIPImageNormalizer, - image_noising_scheduler: DDPMScheduler, + image_noising_scheduler: KarrasDiffusionSchedulers, # regular denoising components tokenizer: CLIPTokenizer, text_encoder: CLIPTextModelWithProjection, unet: UNet2DConditionModel, - scheduler: DDIMScheduler, + scheduler: KarrasDiffusionSchedulers, # vae vae: AutoencoderKL, ): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index 1d7a10c914c7..2910595e0472 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -23,7 +23,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding -from ...schedulers import DDIMScheduler, DDPMScheduler +from ...schedulers import KarrasDiffusionSchedulers from ...utils import logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer @@ -75,7 +75,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline): image_normalizer ([`StableUnCLIPImageNormalizer`]): Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image embeddings after the noise has been applied. - image_noising_scheduler ([`DDPMScheduler`]): + image_noising_scheduler ([`KarrasDiffusionSchedulers`]): Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined by `noise_level` in `StableUnCLIPPipeline.__call__`. tokenizer (`CLIPTokenizer`): @@ -84,7 +84,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline): text_encoder ([`CLIPTextModel`]): Frozen text-encoder. unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`DDIMScheduler`]): + scheduler ([`KarrasDiffusionSchedulers`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
@@ -96,13 +96,13 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline): # image noising components image_normalizer: StableUnCLIPImageNormalizer - image_noising_scheduler: DDPMScheduler + image_noising_scheduler: KarrasDiffusionSchedulers # regular denoising components tokenizer: CLIPTokenizer text_encoder: CLIPTextModel unet: UNet2DConditionModel - scheduler: DDIMScheduler + scheduler: KarrasDiffusionSchedulers vae: AutoencoderKL @@ -113,12 +113,12 @@ def __init__( image_encoder: CLIPVisionModelWithProjection, # image noising components image_normalizer: StableUnCLIPImageNormalizer, - image_noising_scheduler: DDPMScheduler, + image_noising_scheduler: KarrasDiffusionSchedulers, # regular denoising components tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, unet: UNet2DConditionModel, - scheduler: DDIMScheduler, + scheduler: KarrasDiffusionSchedulers, # vae vae: AutoencoderKL, ): From 19a2bbff42867f37d8dd8cc3ecfe0756c3a17a8d Mon Sep 17 00:00:00 2001 From: William Berman Date: Tue, 14 Feb 2023 10:30:45 -0800 Subject: [PATCH 13/14] docs --- docs/source/en/_toctree.yml | 2 + docs/source/en/api/pipelines/overview.mdx | 2 + .../pipelines/stable_diffusion/text2img.mdx | 2 +- .../source/en/api/pipelines/stable_unclip.mdx | 97 +++++++++++++++++++ docs/source/en/index.mdx | 2 + 5 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 docs/source/en/api/pipelines/stable_unclip.mdx diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index f3175e9b7f8a..2f8cf19fea83 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -154,6 +154,8 @@ title: Stable Diffusion - local: api/pipelines/stable_diffusion_2 title: Stable Diffusion 2 + - local: api/pipelines/stable_unclip + title: Stable unCLIP - local: api/pipelines/stochastic_karras_ve title: Stochastic Karras VE - local: api/pipelines/unclip diff --git a/docs/source/en/api/pipelines/overview.mdx b/docs/source/en/api/pipelines/overview.mdx index fa2968351345..56b4abbca3dc 100644 --- a/docs/source/en/api/pipelines/overview.mdx +++ b/docs/source/en/api/pipelines/overview.mdx @@ -64,6 +64,8 @@ available a colab notebook to directly try them out. 
| [stable_diffusion_2](./stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting | | [stable_diffusion_2](./stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image | | [stable_diffusion_safe](./stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) +| [stable_unclip](./stable_unclip) | **Stable unCLIP** | Text-to-Image Generation | +| [stable_unclip](./stable_unclip) | **Stable unCLIP** | Image-to-Image Text-Guided Generation | | [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | | [unclip](./unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125) | Text-to-Image Generation | | [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation | diff --git a/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx b/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx index 952ad24808b8..18657a2c0c15 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx @@ -17,7 +17,7 @@ specific language governing permissions and limitations under the License. The Stable Diffusion model was created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/), [runway](https://github.com/runwayml), and [LAION](https://laion.ai/). The [`StableDiffusionPipeline`] is capable of generating photo-realistic images given any text input using Stable Diffusion. The original codebase can be found here: -- *Stable Diffusion V1*: [CampVis/stable-diffusion](https://github.com/CompVis/stable-diffusion) +- *Stable Diffusion V1*: [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion) - *Stable Diffusion v2*: [Stability-AI/stablediffusion](https://github.com/Stability-AI/stablediffusion) Available Checkpoints are: diff --git a/docs/source/en/api/pipelines/stable_unclip.mdx b/docs/source/en/api/pipelines/stable_unclip.mdx new file mode 100644 index 000000000000..04cad31b0175 --- /dev/null +++ b/docs/source/en/api/pipelines/stable_unclip.mdx @@ -0,0 +1,97 @@ + + +# Stable unCLIP + +The stable unCLIP model is [stable diffusion 2.1](./stable_diffusion_2) finetuned to condition on CLIP image embeddings. +Stable unCLIP also still conditions on text embeddings. Given the two separate conditionings, stable unCLIP can be used +for text guided image variation. When combined with an unCLIP prior, it can also be used for full text to image generation. + +## Tips + +Stable unCLIP takes a `noise_level` as input during inference. `noise_level` determines how much noise is added +to the image embeddings. A higher `noise_level` increases variation in the final un-noised images. By default, +we do not add any additional noise to the image embeddings i.e. `noise_level = 0`. 
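For example, a minimal sketch of generating with a non-zero `noise_level` (reusing the placeholder checkpoint path from the examples below; the value `100` is arbitrary):

```python
import torch
from diffusers import StableUnCLIPPipeline

# placeholder checkpoint path, same as in the text-to-image example below (still marked TODO)
pipe = StableUnCLIPPipeline.from_pretrained("fusing/stable-unclip-2-1-l", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# noise_level=100 adds a moderate amount of noise to the image embeddings,
# trading some fidelity for more variation across generations
images = pipe("a photo of an astronaut riding a horse on mars", noise_level=100).images
images[0].save("astronaut_horse_noisy.png")
```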
+ +### Available checkpoints: + +TODO + +### Text-to-Image Generation + +```python +import torch +from diffusers import StableUnCLIPPipeline + +pipe = StableUnCLIPPipeline.from_pretrained( + "fusing/stable-unclip-2-1-l", torch_dtype=torch.float16 +) # TODO update model path +pipe = pipe.to("cuda") + +prompt = "a photo of an astronaut riding a horse on mars" +images = pipe(prompt).images +images[0].save("astronaut_horse.png") +``` + + +### Text guided Image-to-Image Variation + +```python +import requests +import torch +from PIL import Image +from io import BytesIO + +from diffusers import StableUnCLIPImg2ImgPipeline + +pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( + "fusing/stable-unclip-2-1-l-img2img", torch_dtype=torch.float16 +) # TODO update model path +pipe = pipe.to("cuda") + +url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + +response = requests.get(url) +init_image = Image.open(BytesIO(response.content)).convert("RGB") +init_image = init_image.resize((768, 512)) + +prompt = "A fantasy landscape, trending on artstation" + +images = pipe(prompt, init_image).images +images[0].save("fantasy_landscape.png") +``` + +### StableUnCLIPPipeline + +[[autodoc]] StableUnCLIPPipeline + - all + - __call__ + - enable_attention_slicing + - disable_attention_slicing + - enable_vae_slicing + - disable_vae_slicing + - enable_xformers_memory_efficient_attention + - disable_xformers_memory_efficient_attention + + +### StableUnCLIPImg2ImgPipeline + +[[autodoc]] StableUnCLIPImg2ImgPipeline + - all + - __call__ + - enable_attention_slicing + - disable_attention_slicing + - enable_vae_slicing + - disable_vae_slicing + - enable_xformers_memory_efficient_attention + - disable_xformers_memory_efficient_attention + \ No newline at end of file diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 148ee53f411f..c116c7f5bb6e 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -54,6 +54,8 @@ available a colab notebook to directly try them out. 
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting | | [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image | | [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) +| [stable_unclip](./stable_unclip) | **Stable unCLIP** | Text-to-Image Generation | +| [stable_unclip](./stable_unclip) | **Stable unCLIP** | Image-to-Image Text-Guided Generation | | [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | | [unclip](./api/pipelines/unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125) | Text-to-Image Generation | | [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation | From b91ca1384c72e5e5261e58e959603159dafaec0a Mon Sep 17 00:00:00 2001 From: Will Berman Date: Tue, 14 Feb 2023 11:12:48 -0800 Subject: [PATCH 14/14] Update docs/source/en/api/pipelines/stable_unclip.mdx Co-authored-by: Patrick von Platen --- docs/source/en/api/pipelines/stable_unclip.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/stable_unclip.mdx b/docs/source/en/api/pipelines/stable_unclip.mdx index 04cad31b0175..5b6bec5ecbb5 100644 --- a/docs/source/en/api/pipelines/stable_unclip.mdx +++ b/docs/source/en/api/pipelines/stable_unclip.mdx @@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License. # Stable unCLIP -The stable unCLIP model is [stable diffusion 2.1](./stable_diffusion_2) finetuned to condition on CLIP image embeddings. +Stable unCLIP checkpoints are finetuned from [stable diffusion 2.1](./stable_diffusion_2) checkpoints to condition on CLIP image embeddings. Stable unCLIP also still conditions on text embeddings. Given the two separate conditionings, stable unCLIP can be used for text guided image variation. When combined with an unCLIP prior, it can also be used for full text to image generation.
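Taken together, these patches replace the monolithic `NoiseAugmentor` with a `StableUnCLIPImageNormalizer` plus a `DDPMScheduler`, exposed through the pipelines' `noise_image_embeddings` method. A minimal standalone sketch of that noising step follows; the `embedding_dim`, `noise_level`, and random embeddings are arbitrary placeholders, not values from this PR:

```python
import torch

from diffusers import DDPMScheduler
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer

embedding_dim = 32  # arbitrary small value for illustration; the PR's default is 768
image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
image_noising_scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2")

image_embeds = torch.randn(2, embedding_dim)  # stand-in for CLIP image embeddings
noise_level = torch.tensor([100] * image_embeds.shape[0])

# 1. normalize with the CLIP stats, add scheduler noise at `noise_level`, un-normalize
image_embeds = image_normalizer.scale(image_embeds)
image_embeds = image_noising_scheduler.add_noise(
    image_embeds, noise=torch.randn_like(image_embeds), timesteps=noise_level
)
image_embeds = image_normalizer.unscale(image_embeds)

# 2. append a sinusoidal embedding of the noise level so the denoiser is also
#    conditioned on how much noise was added
noise_level_emb = get_timestep_embedding(
    timesteps=noise_level, embedding_dim=embedding_dim, flip_sin_to_cos=True, downscale_freq_shift=0
)
conditioning = torch.cat([image_embeds, noise_level_emb.to(image_embeds.dtype)], dim=1)
print(conditioning.shape)  # torch.Size([2, 64])
```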