From b218062fed08d6cc164206d6cb852b2b7b00847a Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Sat, 25 Mar 2023 18:12:22 -0700 Subject: [PATCH 01/41] Update Pix2PixZero Auto-correlation Loss --- .../pipeline_stable_diffusion_pix2pix_zero.py | 25 ++++++++----------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 4c2dbe6ff85d..72336f72a1e7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -750,23 +750,18 @@ def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep ) def auto_corr_loss(self, hidden_states, generator=None): - batch_size, channel, height, width = hidden_states.shape - if batch_size > 1: - raise ValueError("Only batch_size 1 is supported for now") - - hidden_states = hidden_states.squeeze(0) - # hidden_states must be shape [C,H,W] now reg_loss = 0.0 for i in range(hidden_states.shape[0]): - noise = hidden_states[i][None, None, :, :] - while True: - roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() - reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 - reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 - - if noise.shape[2] <= 8: - break - noise = F.avg_pool2d(noise, kernel_size=2) + for j in range(hidden_states.shape[1]): + noise = hidden_states[i : i + 1, j : j + 1, :, :] + while True: + roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 + + if noise.shape[2] <= 8: + break + noise = F.avg_pool2d(noise, kernel_size=2) return reg_loss def kl_divergence(self, hidden_states): From f953739b9da751323b9a4095aa1435b0deb2ac3b Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Sun, 26 Mar 2023 16:55:15 -0700 Subject: [PATCH 02/41] Add Stable Diffusion DiffEdit pipeline --- .../pipeline_stable_diffusion_diffedit.py | 1523 +++++++++++++++++ 1 file changed, 1523 insertions(+) create mode 100644 src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py new file mode 100644 index 000000000000..110822e9fb1b --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -0,0 +1,1523 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import DDIMInverseScheduler, KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + BaseOutput, + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, +) +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class DiffEditInversionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + latents (`torch.FloatTensor`) + inverted latents tensor + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `num_timesteps * batch_size` or numpy array of shape `(num_timesteps, + batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the + diffusion pipeline. + """ + + latents: torch.FloatTensor + images: Union[List[PIL.Image.Image], np.ndarray] + + +def auto_corr_loss(hidden_states, generator=None): + reg_loss = 0.0 + for i in range(hidden_states.shape[0]): + for j in range(hidden_states.shape[1]): + noise = hidden_states[i : i + 1, j : j + 1, :, :] + while True: + roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 + + if noise.shape[2] <= 8: + break + noise = torch.nn.functional.avg_pool2d(noise, kernel_size=2) + return reg_loss + + +def kl_divergence(hidden_states): + return hidden_states.var() + hidden_states.mean() ** 2 - 1 - torch.log(hidden_states.var() + 1e-7) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +def preprocess_image_and_mask(image, mask): + """ + Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (Union[np.array, PIL.Image, torch.Tensor], *optional*): + The mask to apply to the image, i.e. regions to inpaint. It can be a ``PIL.Image``, or a ``height x width`` + ``np.array`` or a ``1 x height x width`` ``torch.Tensor`` or a ``batch x 1 x height x width`` + ``torch.Tensor``. + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + return mask, image + + +class StableDiffusionDiffEditPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion using DiffEdit. *This is an experimental feature*. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + inverse_scheduler (`Union[]`): + A scheduler to be used in combination with `unet` to fill in the unmasked part of the input latents + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + inverse_scheduler: DDIMInverseScheduler, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + inverse_scheduler=inverse_scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs( + self, + prompt, + height, + width, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (strength is None) or (strength is not None and (strength < 0 or strength > 1)): + raise ValueError( + f"The value of `strength` should in [0.0, 1.0] but is, but is {strength} of type {type(strength)}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def check_mask_inputs( + self, + mask_prompt=None, + mask_negative_prompt=None, + mask_prompt_embeds=None, + mask_negative_prompt_embeds=None, + ): + if mask_prompt is not None and mask_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `mask_prompt`: {mask_prompt} and `mask_prompt_embeds`: {mask_prompt_embeds}." + " Please make sure to only forward one of the two." + ) + elif mask_prompt is None and mask_prompt_embeds is None: + raise ValueError( + "Provide either `mask_image` or `mask_prompt_embeds`. Cannot leave all both of the arguments undefined." + ) + elif mask_prompt is not None and (not isinstance(mask_prompt, str) and not isinstance(mask_prompt, list)): + raise ValueError(f"`mask_prompt` has to be of type `str` or `list` but is {type(mask_prompt)}") + + if mask_negative_prompt is not None and mask_negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `mask_negative_prompt`: {mask_negative_prompt} and `mask_negative_prompt_embeds`:" + f" {mask_negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if mask_prompt_embeds is not None and mask_negative_prompt_embeds is not None: + if mask_prompt_embeds.shape != mask_negative_prompt_embeds.shape: + raise ValueError( + "`mask_prompt_embeds` and `mask_negative_prompt_embeds` must have the same shape when passed" + f" directly, but got: `mask_prompt_embeds` {mask_prompt_embeds.shape} !=" + f" `mask_negative_prompt_embeds` {mask_negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)] + latents = torch.cat(latents, dim=0) + else: + latents = self.vae.encode(image).latent_dist.sample(generator) + + latents = self.vae.config.scaling_factor * latents + + if batch_size > latents.shape[0] and batch_size % latents.shape[0] == 0: + additional_latents_per_image = batch_size // latents.shape[0] + latents = torch.cat([latents] * additional_latents_per_image, dim=0) + elif batch_size > latents.shape[0] and batch_size % latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts." + ) + else: + latents = torch.cat([latents], dim=0) + + return latents + + def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int): + pred_type = self.inverse_scheduler.config.prediction_type + alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep] + + beta_prod_t = 1 - alpha_prod_t + + if pred_type == "epsilon": + return model_output + elif pred_type == "sample": + return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5) + elif pred_type == "v_prediction": + return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`" + ) + + @torch.no_grad() + def compute_mask( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + mask_prompt: Optional[Union[str, List[str]]] = None, + mask_negative_prompt: Optional[Union[str, List[str]]] = None, + mask_prompt_embeds: Optional[torch.FloatTensor] = None, + mask_negative_prompt_embeds: Optional[torch.FloatTensor] = None, + num_maps_per_mask: Optional[int] = 10, + mask_encode_strength: Optional[float] = 0.5, + mask_thresholding_ratio: Optional[float] = 3.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "np", + ): + r""" + Function used to generate a latent mask given a mask prompt, a target prompt, and an image. + + Args: + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be used for computing the mask. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the semantic mask generation. If not defined, one has to pass + `prompt_embeds`. instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + mask_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the semantic mask generation using the method in [DiffEdit: + Diffusion-Based Semantic Image Editing with Mask Guidance](https://arxiv.org/pdf/2210.11427.pdf). If + not defined, one has to pass `mask_prompt_embeds` or `mask_image` instead. + mask_negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the semantic mask generation away from using the method in [DiffEdit: + Diffusion-Based Semantic Image Editing with Mask Guidance](https://arxiv.org/pdf/2210.11427.pdf). If + not defined, one has to pass `mask_negative_prompt_embeds` or `mask_image` instead. + mask_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings to guide the semantic mask generation. Can be used to easily tweak text + inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `mask_prompt` + input argument. + mask_negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings to negatively guide the semantic mask generation. Can be used to easily + tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from + `mask_negative_prompt` input argument. + num_maps_per_mask (`int`, *optional*, defaults to 10): + The number of noise maps sampled to generate the semantic mask using the method in [DiffEdit: + Diffusion-Based Semantic Image Editing with Mask Guidance](https://arxiv.org/pdf/2210.11427.pdf). + mask_encode_strength (`float`, *optional*, defaults to 0.5): + Conceptually, the strength of the noise maps sampled to generate the semantic mask using the method in + [DiffEdit: Diffusion-Based Semantic Image Editing with Mask Guidance]( + https://arxiv.org/pdf/2210.11427.pdf). Must be between 0 and 1. + mask_thresholding_ratio (`float`, *optional*, defaults to 3.0): + AAAAAA. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + + Examples: + + Returns: + `List[PIL.Image.Image]` or `np.array`: `List[PIL.Image.Image]` if `output_type` is `"pil"`, otherwise a + `np.array`. When returning a `List[PIL.Image.Image]`, the list will consist of a batch of single-channel + binary image with dimensions `(height // self.vae_scale_factor, width // self.vae_scale_factor)`, otherwise + the `np.array` will have shape `(batch_size, 1, height // self.vae_scale_factor, width // + self.vae_scale_factor)`. + """ + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + mask_encode_strength, + 1, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + self.check_mask_inputs( + mask_prompt, + mask_negative_prompt, + mask_prompt_embeds, + mask_negative_prompt_embeds, + ) + + if (num_maps_per_mask is None) or ( + num_maps_per_mask is not None and (not isinstance(num_maps_per_mask, int) or num_maps_per_mask <= 0) + ): + raise ValueError( + f"`num_maps_per_mask` has to be a positive integer but is {num_maps_per_mask} of type" + f" {type(num_maps_per_mask)}." + ) + + if mask_thresholding_ratio is None or mask_thresholding_ratio <= 0: + raise ValueError( + f"`mask_thresholding_ratio` has to be positive but is {mask_thresholding_ratio} of type" + f" {type(mask_thresholding_ratio)}." + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompts + prompt_embeds = self._encode_prompt( + prompt, + device, + num_maps_per_mask, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + mask_prompt_embeds = self._encode_prompt( + mask_prompt, + device, + num_maps_per_mask, + do_classifier_free_guidance, + mask_negative_prompt, + prompt_embeds=mask_prompt_embeds, + negative_prompt_embeds=mask_negative_prompt_embeds, + ) + + # 4. Preprocess image + image = preprocess(image) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare image latents and add noise with specified strength + image_latents = self.prepare_image_latents( + image, batch_size * num_maps_per_mask, self.vae.dtype, device, generator + ) + noise = torch.randn_like(image_latents) + encode_timestep = torch.tensor(timesteps[int(mask_encode_strength * num_inference_steps)]) + latent_model_input = self.scheduler.add_noise(image_latents, noise, encode_timestep) + + latent_model_input = torch.cat([image_latents] * (4 if do_classifier_free_guidance else 2)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, encode_timestep) + + # 7. Predict the noise residual + prompt_embeds = torch.cat([mask_prompt_embeds, prompt_embeds]) + noise_pred = self.unet(latent_model_input, encode_timestep, encoder_hidden_states=prompt_embeds).sample + + if do_classifier_free_guidance: + noise_pred_neg_mask, noise_pred_mask, noise_pred_uncond, noise_pred_target = noise_pred.chunk(4) + noise_pred_mask = noise_pred_neg_mask + guidance_scale * (noise_pred_mask - noise_pred_neg_mask) + noise_pred_target = noise_pred_uncond + guidance_scale * (noise_pred_target - noise_pred_uncond) + else: + noise_pred_mask, noise_pred_target = noise_pred.chunk(2) + + # 8. Compute the mask from the absolute difference of predicted noise residuals + # TODO: Consider smoothing mask guidance map + mask_guidance_map = ( + torch.abs(noise_pred_mask - noise_pred_target) + .reshape(batch_size, num_maps_per_mask, noise_pred_target.shape[:-3]) + .mean(1, 2) + ) + clamp_magnitude = mask_guidance_map.mean() * mask_thresholding_ratio + semantic_mask_image = mask_guidance_map.clamp(0, clamp_magnitude) / clamp_magnitude + semantic_mask_image = torch.where(semantic_mask_image <= 0.5, 0, 1).unsqueeze(1) + mask_image = semantic_mask_image.detach().clone().numpy() + + # 9. Convert to Numpy array or PIL. + if output_type == "pil": + mask_image = self.numpy_to_pil(mask_image) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + return mask_image + + @torch.no_grad() + def invert( + self, + prompt: Optional[str] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + num_inference_steps: int = 50, + inpaint_strength: float = 0.8, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + lambda_auto_corr: float = 20.0, + lambda_kl: float = 20.0, + num_reg_steps: int = 0, + num_auto_corr_rolls: int = 5, + ): + r""" + Function used to generate inverted latents given a prompt and image. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + inpaint_strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more noise to + that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + lambda_auto_corr (`float`, *optional*, defaults to 20.0): + Lambda parameter to control auto correction + lambda_kl (`float`, *optional*, defaults to 20.0): + Lambda parameter to control Kullback–Leibler divergence output + num_reg_steps (`int`, *optional*, defaults to 0): + Number of regularization loss steps + num_auto_corr_rolls (`int`, *optional*, defaults to 5): + Number of auto correction roll steps + + Examples: + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionDiffEditPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" + + >>> init_image = download_image(img_url).resize((512, 512)) + + >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.enable_model_cpu_offload() + + >>> prompt = "A bowl of fruits" + + >>> inverted_latents = pipe.invert(image=invert_image, prompt=mask_prompt).latents + ``` + + Returns: + [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] or + `tuple`: [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] + if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is the inverted + latents tensors ordered by increasing noise, and then second is the corresponding decoded images. + """ + # 1. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Preprocess image + image = preprocess(image) + + # 4. Prepare latent variables + num_images_per_prompt = 1 + latents = self.prepare_image_latents( + image, batch_size * num_images_per_prompt, self.vae.dtype, device, generator + ) + + # 5. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 6. Prepare timesteps + self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, inpaint_strength, device) + + # 7. Noising loop where we obtain the intermediate noised latent image for each timestep. + num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order + inverted_latents = [latents.detach().clone()] + with self.progress_bar(total=num_inference_steps - 1) as progress_bar: + for i, t in enumerate(timesteps[:-1]): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # regularization of the noise prediction (not in original code or paper but borrowed from Pix2PixZero) + with torch.enable_grad(): + for _ in range(num_reg_steps): + if lambda_auto_corr > 0: + for _ in range(num_auto_corr_rolls): + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_ac = self.auto_corr_loss(var_epsilon, generator=generator) + l_ac.backward() + + grad = var.grad.detach() / num_auto_corr_rolls + noise_pred = noise_pred - lambda_auto_corr * grad + + if lambda_kl > 0: + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_kld = self.kl_divergence(var_epsilon) + l_kld.backward() + + grad = var.grad.detach() + noise_pred = noise_pred - lambda_kl * grad + + noise_pred = noise_pred.detach() + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample + inverted_latents.append(latents.detach().clone()) + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + assert len(inverted_latents) == len(timesteps) + latents = torch.stack(inverted_latents, 0)[::-1] + + # 8. Post-processing + image = self.decode_latents(latents.detach()) + + # 9. Convert to PIL. + if output_type == "pil": + image = self.numpy_to_pil(image) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (latents, image) + + return DiffEditInversionPipelineOutput(latents=latents, images=image) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + mask_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, + inverted_latents: Optional[torch.FloatTensor] = None, + num_maps_per_mask: Optional[int] = 10, + inpaint_strength: Optional[float] = 0.8, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + lambda_auto_corr: float = 20.0, + lambda_kl: float = 20.0, + num_reg_steps: int = 0, + num_auto_corr_rolls: int = 5, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, 1, H, W)`. + inverted_latents (`torch.FloatTensor`): + Pre-generated partially noised latents inverted from `image` to be used as inputs for image generation. + inpaint_strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more noise to + that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + lambda_auto_corr (`float`, *optional*, defaults to 20.0): + Lambda parameter to control auto correction + lambda_kl (`float`, *optional*, defaults to 20.0): + Lambda parameter to control Kullback–Leibler divergence output + num_reg_steps (`int`, *optional*, defaults to 0): + Number of regularization loss steps + num_auto_corr_rolls (`int`, *optional*, defaults to 5): + Number of auto correction roll steps + + Examples: + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionDiffEditPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" + + >>> init_image = download_image(img_url).resize((512, 512)) + + >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.enable_model_cpu_offload() + + >>> mask_prompt = "A bowl of fruits" + >>> prompt = "A bowl of pears" + + >>> mask_image = pipe.compute_mask(image=init_image, prompt=prompt, mask_prompt=mask_prompt) + >>> inverted_latents = pipe.invert(image=invert_image, prompt=mask_prompt).latents + >>> image = pipe( + ... prompt=prompt, image=init_image, mask_image=mask_image, inverted_latents=inverted_latents + ... ).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + inpaint_strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + if mask_image is None: + raise ValueError( + "`mask_image` input cannot be undefined. Use `compute_mask()` to compute `mask_image` from text prompts." + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Preprocess mask and image + image, mask_image = preprocess_image_and_mask(image, mask_image) + + mask_image = torch.cat([mask_image] * num_images_per_prompt) + vae_latent_size = (height // self.vae_scale_factor, width // self.vae_scale_factor) + mask_image = torch.nn.functional.interpolate(mask_image, size=vae_latent_size) + mask_image = mask_image.to(device=device, dtype=prompt_embeds.dtype) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, inpaint_strength, device) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Preprocess inverted latents + inverted_latents_shape = (len(timesteps), batch_size, num_channels_latents, *vae_latent_size) + if inverted_latents is None: + raise ValueError( + "`inverted_latents` input cannot be undefined. Use `invert()` to compute `inverted_latents`." + ) + elif inverted_latents.shape != inverted_latents_shape: + raise ValueError( + f"`inverted_latents` must have shape {inverted_latents_shape}, but has shape {inverted_latents.shape}" + ) + if isinstance(inverted_latents, np.ndarray): + inverted_latents = torch.from_numpy(inverted_latents) + inverted_latents = inverted_latents.to(device=device, dtype=prompt_embeds.dtype) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # mask with inverted latents from appropriate timestep - use original image latent for last step + latents = latents * mask_image + inverted_latents[i] * (1 - mask_image) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 10. Post-processing + image = self.decode_latents(latents) + + # 11. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 12. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) From a37b6e061e5015ccd974d533f96b5a966c66594c Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Mon, 27 Mar 2023 17:21:48 -0700 Subject: [PATCH 03/41] Add draft documentation and import code --- docs/source/en/_toctree.yml | 2 + .../pipelines/stable_diffusion/diffedit.mdx | 256 ++++++++++++++++++ src/diffusers/__init__.py | 1 + src/diffusers/pipelines/__init__.py | 1 + .../pipelines/stable_diffusion/__init__.py | 2 + .../pipeline_stable_diffusion_diffedit.py | 2 +- 6 files changed, 263 insertions(+), 1 deletion(-) create mode 100644 docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index e6ec96c3a3d9..4bc1283e076a 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -191,6 +191,8 @@ title: MultiDiffusion Panorama - local: api/pipelines/stable_diffusion/controlnet title: Text-to-Image Generation with ControlNet Conditioning + - local: api/pipelines/stable_diffusion/diffedit + title: DiffEdit title: Stable Diffusion - local: api/pipelines/stable_diffusion_2 title: Stable Diffusion 2 diff --git a/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx b/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx new file mode 100644 index 000000000000..95ee0bfc02c0 --- /dev/null +++ b/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx @@ -0,0 +1,256 @@ + + +# Zero-shot Diffusion-based Semantic Image Editing with Mask Guidance + +## Overview + +[DiffEdit: Diffusion-based semantic image editing with mask guidance](https://arxiv.org/abs/2210.11427). + +The abstract of the paper is the following: + +*Image generation has recently seen tremendous advances, with diffusion models allowing to synthesize convincing images for a large variety of text prompts. In this article, we propose DiffEdit, a method to take advantage of text-conditioned diffusion models for the task of semantic image editing, where the goal is to edit an image based on a text query. Semantic image editing is an extension of image generation, with the additional constraint that the generated image should be as similar as possible to a given input image. Current editing methods based on diffusion models usually require to provide a mask, making the task much easier by treating it as a conditional inpainting task. In contrast, our main contribution is able to automatically generate a mask highlighting regions of the input image that need to be edited, by contrasting predictions of a diffusion model conditioned on different text prompts. Moreover, we rely on latent inference to preserve content in those regions of interest and show excellent synergies with mask-based diffusion. DiffEdit achieves state-of-the-art editing performance on ImageNet. In addition, we evaluate semantic image editing in more challenging settings, using images from the COCO dataset as well as text-based generated images.* + +Resources: + +* [Project Page](https://pix2pixzero.github.io/). +* [Paper](https://arxiv.org/abs/2210.11427). +* [Blog Post with Demo](https://blog.problemsolversguild.com/technical/research/2022/11/02/DiffEdit-Implementation.html). +* [Implementation on Github](https://github.com/Xiang-cd/DiffEdit-stable-diffusion/blob/main/assets/origin.png). + +## Tips + +* The pipeline can be conditioned on real input images. Check out the code examples below to know more. +* The pipeline exposes two arguments namely `source_embeds` and `target_embeds` +that let you control the direction of the semantic edits in the final image to be generated. Let's say, +you wanted to translate from "cat" to "dog". In this case, the edit direction will be "cat -> dog". To reflect +this in the pipeline, you simply have to set the embeddings related to the phrases including "cat" to +`source_embeds` and "dog" to `target_embeds`. Refer to the code example below for more details. +* When you're using this pipeline from a prompt, specify the _source_ concept in the prompt. Taking +the above example, a valid input prompt would be: "a high resolution painting of a **cat** in the style of van gough". +* If you wanted to reverse the direction in the example above, i.e., "dog -> cat", then it's recommended to: + * Swap the `source_embeds` and `target_embeds`. + * Change the input prompt to include "dog". +* To learn more about how the source and target embeddings are generated, refer to the [original +paper](https://arxiv.org/abs/2302.03027). Below, we also provide some directions on how to generate the embeddings. +* Note that the quality of the outputs generated with this pipeline is dependent on how good the `source_embeds` and `target_embeds` are. Please, refer to [this discussion](#generating-source-and-target-embeddings) for some suggestions on the topic. + +## Available Pipelines: + +| Pipeline | Tasks +|---|---| +| [StableDiffusionDiffEditPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py) | *Text-Based Image Editing* + + + +## Usage example + +### Based on an input image + +When the pipeline is conditioned on an input image, we first obtain an inverted +noise from it using a `DDIMInverseScheduler` with the help of a generated caption. Then +the inverted noise is used to start the generation process. + +First, let's load our pipeline: + +```py +import torch +from transformers import BlipForConditionalGeneration, BlipProcessor +from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline + +captioner_id = "Salesforce/blip-image-captioning-base" +processor = BlipProcessor.from_pretrained(captioner_id) +model = BlipForConditionalGeneration.from_pretrained(captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True) + +sd_model_ckpt = "CompVis/stable-diffusion-2-1" +pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained( + sd_model_ckpt, + caption_generator=model, + caption_processor=processor, + torch_dtype=torch.float16, + safety_checker=None, +) +pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) +pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) +pipeline.enable_model_cpu_offload() +``` + +Then, we load an input image for conditioning and obtain a suitable caption for it: + +```py +import requests +from PIL import Image + +img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" +raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512)) +caption = pipeline.generate_caption(raw_image) +``` + +Then we employ the generated caption and the input image to get the inverted latents: + +```py +generator = torch.manual_seed(0) +inv_latents = pipeline.invert(caption, image=raw_image, generator=generator).latents +``` + +Then we employ the source and target prompts to generate the editing mask: + +```py +# See the "Generating source and target embeddings" section below to +# automate the generation of these captions with a pre-trained model like Flan-T5 as explained below. +source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"] +target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"] + +source_embeds = pipeline.get_embeds(source_prompts, batch_size=2) +target_embeds = pipeline.get_embeds(target_prompts, batch_size=2) +mask_image = pipeline.compute_mask( + image=raw_image, + prompt_embeds=target_embeds, + mask_prompt_embeds=source_embeds, + generator=generator, +) +``` + +Now, generate the image with the inverted latents and semantically generated mask: + +```py +image = pipeline( + prompt_embeds=target_embeds, + num_inference_steps=50, + generator=generator, + inverted_latents=inv_latents, + mask_image=mask_image, + negative_prompt=caption, +).images[0] +image.save("edited_image.png") +``` + +## Generating source and target embeddings + +The authors originally used the [GPT-3 API](https://openai.com/api/) to generate the source and target captions for discovering +edit directions. However, we can also leverage open source and public models for the same purpose. +Below, we provide an end-to-end example with the [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) model +for generating captions and [CLIP](https://huggingface.co/docs/transformers/model_doc/clip) for +computing embeddings on the generated captions. + +**1. Load the generation model**: + +```py +import torch +from transformers import AutoTokenizer, T5ForConditionalGeneration + +tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl") +model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_map="auto", torch_dtype=torch.float16) +``` + +**2. Construct a starting prompt**: + +```py +source_concept = "cat" +target_concept = "dog" + +source_text = f"Provide a caption for images containing a {source_concept}. " +"The captions should be in English and should be no longer than 150 characters." + +target_text = f"Provide a caption for images containing a {target_concept}. " +"The captions should be in English and should be no longer than 150 characters." +``` + +Here, we're interested in the "cat -> dog" direction. + +**3. Generate captions**: + +We can use a utility like so for this purpose. + +```py +def generate_captions(input_prompt): + input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.to("cuda") + + outputs = model.generate( + input_ids, temperature=0.8, num_return_sequences=16, do_sample=True, max_new_tokens=128, top_k=10 + ) + return tokenizer.batch_decode(outputs, skip_special_tokens=True) +``` + +And then we just call it to generate our captions: + +```py +source_captions = generate_captions(source_text) +target_captions = generate_captions(target_concept) +``` + +We encourage you to play around with the different parameters supported by the +`generate()` method ([documentation](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.generation_tf_utils.TFGenerationMixin.generate)) for the generation quality you are looking for. + +**4. Load the embedding model**: + +Here, we need to use the same text encoder model used by the subsequent Stable Diffusion model. + +```py +from diffusers import StableDiffusionPix2PixZeroPipeline + +pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 +) +pipeline = pipeline.to("cuda") +tokenizer = pipeline.tokenizer +text_encoder = pipeline.text_encoder +``` + +**5. Compute embeddings**: + +```py +import torch + +def embed_captions(sentences, tokenizer, text_encoder, device="cuda"): + with torch.no_grad(): + embeddings = [] + for sent in sentences: + text_inputs = tokenizer( + sent, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0] + embeddings.append(prompt_embeds) + return torch.concatenate(embeddings, dim=0).mean(dim=0).unsqueeze(0) + +source_embeddings = embed_captions(source_captions, tokenizer, text_encoder) +target_embeddings = embed_captions(target_captions, tokenizer, text_encoder) +``` + +And you're done! [Here](https://colab.research.google.com/drive/1tz2C1EdfZYAPlzXXbTnf-5PRBiR8_R1F?usp=sharing) is a Colab Notebook that you can use to interact with the entire process. + +Now, you can use these embeddings directly while calling the pipeline: + +```py +from diffusers import DDIMScheduler + +pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + +images = pipeline( + prompt, + source_embeds=source_embeddings, + target_embeds=target_embeddings, + num_inference_steps=50, + cross_attention_guidance_amount=0.15, +).images +images[0].save("edited_image_dog.png") +``` + +## StableDiffusionDiffEditPipeline +[[autodoc]] StableDiffusionDiffEditPipeline + - __call__ + - all diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f0597f9d61c8..1190fafa6bb2 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -120,6 +120,7 @@ StableDiffusionAttendAndExcitePipeline, StableDiffusionControlNetPipeline, StableDiffusionDepth2ImgPipeline, + StableDiffusionDiffEditPipeline, StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 31d748ced8e8..70da670630bd 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -53,6 +53,7 @@ StableDiffusionAttendAndExcitePipeline, StableDiffusionControlNetPipeline, StableDiffusionDepth2ImgPipeline, + StableDiffusionDiffEditPipeline, StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index b386ab04c167..a32733a37117 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -74,10 +74,12 @@ class StableDiffusionPipelineOutput(BaseOutput): except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import ( StableDiffusionDepth2ImgPipeline, + StableDiffusionDiffEditPipeline, StableDiffusionPix2PixZeroPipeline, ) else: from .pipeline_stable_diffusion_depth2img import StableDiffusionDepth2ImgPipeline + from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 110822e9fb1b..91b1a681fd3d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -232,7 +232,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline): feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "inverse_scheduler"] def __init__( self, From 1b8753b508a9777466e27a11c0299c91832c7312 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Mon, 27 Mar 2023 00:14:02 -0700 Subject: [PATCH 04/41] Bugfixes and refactoring --- .../pipeline_stable_diffusion_diffedit.py | 436 ++++++++---------- 1 file changed, 185 insertions(+), 251 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 91b1a681fd3d..56ec589b44c0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -81,7 +81,7 @@ def kl_divergence(hidden_states): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess -def preprocess(image): +def preprocess_image(image): if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -102,93 +102,10 @@ def preprocess(image): return image -def preprocess_image_and_mask(image, mask): - """ - Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be - converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the - ``image`` and ``1`` for the ``mask``. - - The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be - binarized (``mask > 0.5``) and cast to ``torch.float32`` too. - - Args: - image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. - It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` - ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. - mask (Union[np.array, PIL.Image, torch.Tensor], *optional*): - The mask to apply to the image, i.e. regions to inpaint. It can be a ``PIL.Image``, or a ``height x width`` - ``np.array`` or a ``1 x height x width`` ``torch.Tensor`` or a ``batch x 1 x height x width`` - ``torch.Tensor``. - - Raises: - ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask - should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. - TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not - (ot the other way around). - - Returns: - tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 - dimensions: ``batch x channels x height x width``. - """ - if isinstance(image, torch.Tensor): - if not isinstance(mask, torch.Tensor): - raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") - - # Batch single image - if image.ndim == 3: - assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" - image = image.unsqueeze(0) - - # Batch and add channel dim for single mask - if mask.ndim == 2: - mask = mask.unsqueeze(0).unsqueeze(0) - - # Batch single mask or add channel dim - if mask.ndim == 3: - # Single batched mask, no channel dim or single mask not batched but channel dim - if mask.shape[0] == 1: - mask = mask.unsqueeze(0) - - # Batched masks no channel dim - else: - mask = mask.unsqueeze(1) - - assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" - assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" - assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" - - # Check image is in [-1, 1] - if image.min() < -1 or image.max() > 1: - raise ValueError("Image should be in [-1, 1] range") - - # Check mask is in [0, 1] - if mask.min() < 0 or mask.max() > 1: - raise ValueError("Mask should be in [0, 1] range") - - # Binarize mask - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - - # Image as float32 - image = image.to(dtype=torch.float32) - elif isinstance(mask, torch.Tensor): - raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") - else: - # preprocess image - if isinstance(image, (PIL.Image.Image, np.ndarray)): - image = [image] - - if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): - image = [np.array(i.convert("RGB"))[None, :] for i in image] - image = np.concatenate(image, axis=0) - elif isinstance(image, list) and isinstance(image[0], np.ndarray): - image = np.concatenate([i[None, :] for i in image], axis=0) - - image = image.transpose(0, 3, 1, 2) - image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 - +def preprocess_mask(mask): + if not isinstance(mask, torch.Tensor): # preprocess mask - if isinstance(mask, (PIL.Image.Image, np.ndarray)): + if isinstance(mask, PIL.Image.Image) or (isinstance(mask, np.ndarray) and mask.ndim < 3): mask = [mask] if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): @@ -197,11 +114,31 @@ def preprocess_image_and_mask(image, mask): elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): mask = np.concatenate([m[None, None, :] for m in mask], axis=0) - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 mask = torch.from_numpy(mask) - return mask, image + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + return mask class StableDiffusionDiffEditPipeline(DiffusionPipeline): @@ -666,37 +603,39 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - def check_mask_inputs( + def check_source_inputs( self, - mask_prompt=None, - mask_negative_prompt=None, - mask_prompt_embeds=None, - mask_negative_prompt_embeds=None, + source_prompt=None, + source_negative_prompt=None, + source_prompt_embeds=None, + source_negative_prompt_embeds=None, ): - if mask_prompt is not None and mask_prompt_embeds is not None: + if source_prompt is not None and source_prompt_embeds is not None: raise ValueError( - f"Cannot forward both `mask_prompt`: {mask_prompt} and `mask_prompt_embeds`: {mask_prompt_embeds}." + f"Cannot forward both `source_prompt`: {source_prompt} and `source_prompt_embeds`: {source_prompt_embeds}." " Please make sure to only forward one of the two." ) - elif mask_prompt is None and mask_prompt_embeds is None: + elif source_prompt is None and source_prompt_embeds is None: raise ValueError( - "Provide either `mask_image` or `mask_prompt_embeds`. Cannot leave all both of the arguments undefined." + "Provide either `source_image` or `source_prompt_embeds`. Cannot leave all both of the arguments undefined." ) - elif mask_prompt is not None and (not isinstance(mask_prompt, str) and not isinstance(mask_prompt, list)): - raise ValueError(f"`mask_prompt` has to be of type `str` or `list` but is {type(mask_prompt)}") + elif source_prompt is not None and ( + not isinstance(source_prompt, str) and not isinstance(source_prompt, list) + ): + raise ValueError(f"`source_prompt` has to be of type `str` or `list` but is {type(source_prompt)}") - if mask_negative_prompt is not None and mask_negative_prompt_embeds is not None: + if source_negative_prompt is not None and source_negative_prompt_embeds is not None: raise ValueError( - f"Cannot forward both `mask_negative_prompt`: {mask_negative_prompt} and `mask_negative_prompt_embeds`:" - f" {mask_negative_prompt_embeds}. Please make sure to only forward one of the two." + f"Cannot forward both `source_negative_prompt`: {source_negative_prompt} and `source_negative_prompt_embeds`:" + f" {source_negative_prompt_embeds}. Please make sure to only forward one of the two." ) - if mask_prompt_embeds is not None and mask_negative_prompt_embeds is not None: - if mask_prompt_embeds.shape != mask_negative_prompt_embeds.shape: + if source_prompt_embeds is not None and source_negative_prompt_embeds is not None: + if source_prompt_embeds.shape != source_negative_prompt_embeds.shape: raise ValueError( - "`mask_prompt_embeds` and `mask_negative_prompt_embeds` must have the same shape when passed" - f" directly, but got: `mask_prompt_embeds` {mask_prompt_embeds.shape} !=" - f" `mask_negative_prompt_embeds` {mask_negative_prompt_embeds.shape}." + "`source_prompt_embeds` and `source_negative_prompt_embeds` must have the same shape when passed" + f" directly, but got: `source_prompt_embeds` {source_prompt_embeds.shape} !=" + f" `source_negative_prompt_embeds` {source_negative_prompt_embeds.shape}." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps @@ -709,6 +648,15 @@ def get_timesteps(self, num_inference_steps, strength, device): return timesteps, num_inference_steps - t_start + def get_inverse_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.inverse_scheduler.timesteps[:-t_start] + + return timesteps, num_inference_steps - t_start + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) @@ -781,15 +729,15 @@ def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep @torch.no_grad() def compute_mask( self, - prompt: Optional[Union[str, List[str]]] = None, image: Union[torch.FloatTensor, PIL.Image.Image] = None, - negative_prompt: Optional[Union[str, List[str]]] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - mask_prompt: Optional[Union[str, List[str]]] = None, - mask_negative_prompt: Optional[Union[str, List[str]]] = None, - mask_prompt_embeds: Optional[torch.FloatTensor] = None, - mask_negative_prompt_embeds: Optional[torch.FloatTensor] = None, + target_prompt: Optional[Union[str, List[str]]] = None, + target_negative_prompt: Optional[Union[str, List[str]]] = None, + target_prompt_embeds: Optional[torch.FloatTensor] = None, + target_negative_prompt_embeds: Optional[torch.FloatTensor] = None, + source_prompt: Optional[Union[str, List[str]]] = None, + source_negative_prompt: Optional[Union[str, List[str]]] = None, + source_prompt_embeds: Optional[torch.FloatTensor] = None, + source_negative_prompt_embeds: Optional[torch.FloatTensor] = None, num_maps_per_mask: Optional[int] = 10, mask_encode_strength: Optional[float] = 0.5, mask_thresholding_ratio: Optional[float] = 3.0, @@ -806,29 +754,36 @@ def compute_mask( Args: image (`PIL.Image.Image`): `Image`, or tensor representing an image batch which will be used for computing the mask. - prompt (`str` or `List[str]`, *optional*): + target_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the semantic mask generation. If not defined, one has to pass `prompt_embeds`. instead. - negative_prompt (`str` or `List[str]`, *optional*): + target_negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - mask_prompt (`str` or `List[str]`, *optional*): + target_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + target_negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + source_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the semantic mask generation using the method in [DiffEdit: Diffusion-Based Semantic Image Editing with Mask Guidance](https://arxiv.org/pdf/2210.11427.pdf). If - not defined, one has to pass `mask_prompt_embeds` or `mask_image` instead. - mask_negative_prompt (`str` or `List[str]`, *optional*): + not defined, one has to pass `source_prompt_embeds` or `source_image` instead. + source_negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the semantic mask generation away from using the method in [DiffEdit: Diffusion-Based Semantic Image Editing with Mask Guidance](https://arxiv.org/pdf/2210.11427.pdf). If - not defined, one has to pass `mask_negative_prompt_embeds` or `mask_image` instead. - mask_prompt_embeds (`torch.FloatTensor`, *optional*): + not defined, one has to pass `source_negative_prompt_embeds` or `source_image` instead. + source_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings to guide the semantic mask generation. Can be used to easily tweak text - inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `mask_prompt` - input argument. - mask_negative_prompt_embeds (`torch.FloatTensor`, *optional*): + inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from + `source_prompt` input argument. + source_negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings to negatively guide the semantic mask generation. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from - `mask_negative_prompt` input argument. + `source_negative_prompt` input argument. num_maps_per_mask (`int`, *optional*, defaults to 10): The number of noise maps sampled to generate the semantic mask using the method in [DiffEdit: Diffusion-Based Semantic Image Editing with Mask Guidance](https://arxiv.org/pdf/2210.11427.pdf). @@ -854,13 +809,6 @@ def compute_mask( generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -871,7 +819,7 @@ def compute_mask( `List[PIL.Image.Image]` or `np.array`: `List[PIL.Image.Image]` if `output_type` is `"pil"`, otherwise a `np.array`. When returning a `List[PIL.Image.Image]`, the list will consist of a batch of single-channel binary image with dimensions `(height // self.vae_scale_factor, width // self.vae_scale_factor)`, otherwise - the `np.array` will have shape `(batch_size, 1, height // self.vae_scale_factor, width // + the `np.array` will have shape `(batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor)`. """ @@ -881,21 +829,21 @@ def compute_mask( # 1. Check inputs self.check_inputs( - prompt, + target_prompt, height, width, mask_encode_strength, 1, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, + target_negative_prompt, + target_prompt_embeds, + target_negative_prompt_embeds, ) - self.check_mask_inputs( - mask_prompt, - mask_negative_prompt, - mask_prompt_embeds, - mask_negative_prompt_embeds, + self.check_source_inputs( + source_prompt, + source_negative_prompt, + source_prompt_embeds, + source_negative_prompt_embeds, ) if (num_maps_per_mask is None) or ( @@ -913,12 +861,12 @@ def compute_mask( ) # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): + if target_prompt is not None and isinstance(target_prompt, str): batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) + elif target_prompt is not None and isinstance(target_prompt, list): + batch_size = len(target_prompt) else: - batch_size = prompt_embeds.shape[0] + batch_size = target_prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) @@ -927,66 +875,66 @@ def compute_mask( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompts - prompt_embeds = self._encode_prompt( - prompt, + target_prompt_embeds = self._encode_prompt( + target_prompt, device, num_maps_per_mask, do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, + target_negative_prompt, + prompt_embeds=target_prompt_embeds, + negative_prompt_embeds=target_negative_prompt_embeds, ) - mask_prompt_embeds = self._encode_prompt( - mask_prompt, + source_prompt_embeds = self._encode_prompt( + source_prompt, device, num_maps_per_mask, do_classifier_free_guidance, - mask_negative_prompt, - prompt_embeds=mask_prompt_embeds, - negative_prompt_embeds=mask_negative_prompt_embeds, + source_negative_prompt, + prompt_embeds=source_prompt_embeds, + negative_prompt_embeds=source_negative_prompt_embeds, ) # 4. Preprocess image - image = preprocess(image) + image = preprocess_image(image) # 5. Set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps + timesteps, _ = self.get_timesteps(num_inference_steps, mask_encode_strength, device) + encode_timestep = timesteps[0] # 6. Prepare image latents and add noise with specified strength image_latents = self.prepare_image_latents( image, batch_size * num_maps_per_mask, self.vae.dtype, device, generator ) noise = torch.randn_like(image_latents) - encode_timestep = torch.tensor(timesteps[int(mask_encode_strength * num_inference_steps)]) - latent_model_input = self.scheduler.add_noise(image_latents, noise, encode_timestep) + image_latents = self.scheduler.add_noise(image_latents, noise, encode_timestep) latent_model_input = torch.cat([image_latents] * (4 if do_classifier_free_guidance else 2)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, encode_timestep) # 7. Predict the noise residual - prompt_embeds = torch.cat([mask_prompt_embeds, prompt_embeds]) + prompt_embeds = torch.cat([source_prompt_embeds, target_prompt_embeds]) noise_pred = self.unet(latent_model_input, encode_timestep, encoder_hidden_states=prompt_embeds).sample if do_classifier_free_guidance: - noise_pred_neg_mask, noise_pred_mask, noise_pred_uncond, noise_pred_target = noise_pred.chunk(4) - noise_pred_mask = noise_pred_neg_mask + guidance_scale * (noise_pred_mask - noise_pred_neg_mask) + noise_pred_neg_src, noise_pred_source, noise_pred_uncond, noise_pred_target = noise_pred.chunk(4) + noise_pred_source = noise_pred_neg_src + guidance_scale * (noise_pred_source - noise_pred_neg_src) noise_pred_target = noise_pred_uncond + guidance_scale * (noise_pred_target - noise_pred_uncond) else: - noise_pred_mask, noise_pred_target = noise_pred.chunk(2) + noise_pred_source, noise_pred_target = noise_pred.chunk(2) # 8. Compute the mask from the absolute difference of predicted noise residuals # TODO: Consider smoothing mask guidance map mask_guidance_map = ( - torch.abs(noise_pred_mask - noise_pred_target) - .reshape(batch_size, num_maps_per_mask, noise_pred_target.shape[:-3]) - .mean(1, 2) + torch.abs(noise_pred_target - noise_pred_source) + .reshape(batch_size, num_maps_per_mask, *noise_pred_target.shape[-3:]) + .mean([1, 2]) ) clamp_magnitude = mask_guidance_map.mean() * mask_thresholding_ratio semantic_mask_image = mask_guidance_map.clamp(0, clamp_magnitude) / clamp_magnitude - semantic_mask_image = torch.where(semantic_mask_image <= 0.5, 0, 1).unsqueeze(1) - mask_image = semantic_mask_image.detach().clone().numpy() + semantic_mask_image = torch.where(semantic_mask_image <= 0.5, 0, 1) + mask_image = semantic_mask_image.cpu().numpy() # 9. Convert to Numpy array or PIL. if output_type == "pil": @@ -1097,7 +1045,7 @@ def invert( >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" - >>> init_image = download_image(img_url).resize((512, 512)) + >>> init_image = download_image(img_url).resize((768, 768)) >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained( ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 @@ -1119,7 +1067,23 @@ def invert( if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is the inverted latents tensors ordered by increasing noise, and then second is the corresponding decoded images. """ - # 1. Define call parameters + + # 1. Check inputs + self.check_inputs( + prompt, + 0, + 0, + inpaint_strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -1134,7 +1098,7 @@ def invert( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Preprocess image - image = preprocess(image) + image = preprocess_image(image) # 4. Prepare latent variables num_images_per_prompt = 1 @@ -1155,7 +1119,7 @@ def invert( # 6. Prepare timesteps self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, inpaint_strength, device) + timesteps, num_inference_steps = self.get_inverse_timesteps(num_inference_steps, inpaint_strength, device) # 7. Noising loop where we obtain the intermediate noised latent image for each timestep. num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order @@ -1217,7 +1181,7 @@ def invert( callback(i, t, latents) assert len(inverted_latents) == len(timesteps) - latents = torch.stack(inverted_latents, 0)[::-1] + latents = torch.cat(list(reversed(inverted_latents))) # 8. Post-processing image = self.decode_latents(latents.detach()) @@ -1239,10 +1203,8 @@ def invert( def __call__( self, prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, - mask_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, - inverted_latents: Optional[torch.FloatTensor] = None, - num_maps_per_mask: Optional[int] = 10, + mask: Union[torch.FloatTensor, PIL.Image.Image] = None, + image_latents: torch.FloatTensor = None, inpaint_strength: Optional[float] = 0.8, height: Optional[int] = None, width: Optional[int] = None, @@ -1259,10 +1221,6 @@ def __call__( return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, - lambda_auto_corr: float = 20.0, - lambda_kl: float = 20.0, - num_reg_steps: int = 0, - num_auto_corr_rolls: int = 5, ): r""" Function invoked when calling the pipeline for generation. @@ -1271,21 +1229,18 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`PIL.Image.Image`): - `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will - be masked out with `mask_image` and repainted according to `prompt`. - mask_image (`PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be - repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + mask (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask the generated image. White pixels in the mask + will be repainted, while black pixels will be preserved. If `mask` is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be `(B, 1, H, W)`. - inverted_latents (`torch.FloatTensor`): - Pre-generated partially noised latents inverted from `image` to be used as inputs for image generation. + image_latents (`torch.FloatTensor`): + Partially noised image latents from the inversion process to be used as inputs for image generation. inpaint_strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` is 1, the denoising process will be run on the masked area for the full number of iterations specified - in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more noise to - that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + in `num_inference_steps`. `image_latents` will be used as a reference for the masked area, adding more + noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -1330,14 +1285,6 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. - lambda_auto_corr (`float`, *optional*, defaults to 20.0): - Lambda parameter to control auto correction - lambda_kl (`float`, *optional*, defaults to 20.0): - Lambda parameter to control Kullback–Leibler divergence output - num_reg_steps (`int`, *optional*, defaults to 0): - Number of regularization loss steps - num_auto_corr_rolls (`int`, *optional*, defaults to 5): - Number of auto correction roll steps Examples: @@ -1356,7 +1303,7 @@ def __call__( >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" - >>> init_image = download_image(img_url).resize((512, 512)) + >>> init_image = download_image(img_url).resize((768, 768)) >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained( ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 @@ -1370,11 +1317,9 @@ def __call__( >>> mask_prompt = "A bowl of fruits" >>> prompt = "A bowl of pears" - >>> mask_image = pipe.compute_mask(image=init_image, prompt=prompt, mask_prompt=mask_prompt) - >>> inverted_latents = pipe.invert(image=invert_image, prompt=mask_prompt).latents - >>> image = pipe( - ... prompt=prompt, image=init_image, mask_image=mask_image, inverted_latents=inverted_latents - ... ).images[0] + >>> mask_image = pipe.compute_mask(image=init_image, source_prompt=prompt, target_prompt=mask_prompt) + >>> image_latents = pipe.invert(image=init_image, prompt=mask_prompt).latents + >>> image = pipe(prompt=prompt, mask=mask_image, image_latents=image_latents).images[0] ``` Returns: @@ -1387,6 +1332,7 @@ def __call__( # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor + vae_latent_size = (height // self.vae_scale_factor, width // self.vae_scale_factor) # 1. Check inputs self.check_inputs( @@ -1400,12 +1346,12 @@ def __call__( negative_prompt_embeds, ) - if image is None: - raise ValueError("`image` input cannot be undefined.") - if mask_image is None: + if mask is None: raise ValueError( - "`mask_image` input cannot be undefined. Use `compute_mask()` to compute `mask_image` from text prompts." + "`mask` input cannot be undefined. Use `compute_mask()` to compute `mask` from text prompts." ) + if image_latents is None: + raise ValueError("`image_latents` input cannot be undefined. Use `invert()` to compute `image_latents` from input images.") # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1432,51 +1378,39 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) - # 4. Preprocess mask and image - image, mask_image = preprocess_image_and_mask(image, mask_image) - - mask_image = torch.cat([mask_image] * num_images_per_prompt) - vae_latent_size = (height // self.vae_scale_factor, width // self.vae_scale_factor) - mask_image = torch.nn.functional.interpolate(mask_image, size=vae_latent_size) - mask_image = mask_image.to(device=device, dtype=prompt_embeds.dtype) + # 4. Preprocess mask + mask = preprocess_mask(mask) + mask_shape = (batch_size, 1, *vae_latent_size) + if mask.shape != mask_shape: + raise ValueError(f"`mask` must have shape {mask_shape}, but has shape {mask.shape}") + mask = torch.cat([mask] * num_images_per_prompt) + mask = mask.to(device=device, dtype=prompt_embeds.dtype) # 5. Set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, inpaint_strength, device) - # 6. Prepare latent variables + # 6. Preprocess image latents num_channels_latents = self.vae.config.latent_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 7. Preprocess inverted latents - inverted_latents_shape = (len(timesteps), batch_size, num_channels_latents, *vae_latent_size) - if inverted_latents is None: - raise ValueError( - "`inverted_latents` input cannot be undefined. Use `invert()` to compute `inverted_latents`." - ) - elif inverted_latents.shape != inverted_latents_shape: + image_latents_shape = (len(timesteps) * batch_size, num_channels_latents, *vae_latent_size) + if image_latents.shape != image_latents_shape: raise ValueError( - f"`inverted_latents` must have shape {inverted_latents_shape}, but has shape {inverted_latents.shape}" + f"`image_latents` must have shape {image_latents_shape}, but has shape {image_latents.shape}" ) - if isinstance(inverted_latents, np.ndarray): - inverted_latents = torch.from_numpy(inverted_latents) - inverted_latents = inverted_latents.to(device=device, dtype=prompt_embeds.dtype) + if isinstance(image_latents, np.ndarray): + image_latents = torch.from_numpy(image_latents) + image_latents = torch.cat( + [image_latents.reshape(-1, batch_size, num_channels_latents, *vae_latent_size)] * num_images_per_prompt, + 1, + ) + image_latents = image_latents.to(device=device, dtype=prompt_embeds.dtype) - # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 9. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order + # 8. Denoising loop + latents = image_latents[0].detach().clone() + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance @@ -1495,7 +1429,7 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # mask with inverted latents from appropriate timestep - use original image latent for last step - latents = latents * mask_image + inverted_latents[i] * (1 - mask_image) + latents = latents * mask + image_latents[i] * (1 - mask) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -1503,13 +1437,13 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 10. Post-processing + # 9. Post-processing image = self.decode_latents(latents) - # 11. Run safety checker + # 10. Run safety checker image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - # 12. Convert to PIL + # 11. Convert to PIL if output_type == "pil": image = self.numpy_to_pil(image) From eda67c22a85572d0e361894c41cec3313d19aa74 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Mon, 27 Mar 2023 17:20:54 -0700 Subject: [PATCH 05/41] Add option to not decode latents in the inversion process --- .../pipeline_stable_diffusion_diffedit.py | 36 ++++++++++++------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 56ec589b44c0..3f613cb9d3bb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -653,6 +653,10 @@ def get_inverse_timesteps(self, num_inference_steps, strength, device): init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) + + # safety for t_start overflow to prevent empty timsteps slice + if t_start == num_inference_steps: + return self.inverse_scheduler.timesteps, num_inference_steps timesteps = self.inverse_scheduler.timesteps[:-t_start] return timesteps, num_inference_steps - t_start @@ -958,6 +962,7 @@ def invert( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + decode_latents: bool = False, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -975,13 +980,12 @@ def invert( The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`PIL.Image.Image`): - `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will - be masked out with `mask_image` and repainted according to `prompt`. + `Image`, or tensor representing an image batch to produce the inverted latents, guided by `prompt`. inpaint_strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` - is 1, the denoising process will be run on the masked area for the full number of iterations specified - in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more noise to - that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + Conceptually, indicates how far into the noising process to run latent inversion. Must be between 0 and + 1. When `strength` is 1, the inversion process will be run for the full number of iterations specified + in `num_inference_steps`. `image` will be used as a reference for the inversion process, adding more + noise the larger the `strength`. If `strength` is 0, no inpainting will occur. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -1007,11 +1011,14 @@ def invert( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + decode_latents (`bool`, *optional*, defaults to `False`): + Whether or not to decode the inverted latents into a generated image. Setting this argument to `True` + will decode all inverted latents for each timestep into a list of generated images. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + Whether or not to return a [`~pipelines.stable_diffusion.DiffEditInversionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be @@ -1064,8 +1071,9 @@ def invert( Returns: [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] - if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is the inverted - latents tensors ordered by increasing noise, and then second is the corresponding decoded images. + if `return_dict` is `True`, otherwise a `tuple`. When returning a tuple, the first element is the inverted + latents tensors ordered by increasing noise, and then second is the corresponding decoded images if + `decode_latents` is `True`, otherwise `None`. """ # 1. Check inputs @@ -1184,10 +1192,12 @@ def invert( latents = torch.cat(list(reversed(inverted_latents))) # 8. Post-processing - image = self.decode_latents(latents.detach()) + image = None + if decode_latents: + image = self.decode_latents(latents.detach()) # 9. Convert to PIL. - if output_type == "pil": + if decode_latents and output_type == "pil": image = self.numpy_to_pil(image) # Offload last model to CPU @@ -1351,7 +1361,9 @@ def __call__( "`mask` input cannot be undefined. Use `compute_mask()` to compute `mask` from text prompts." ) if image_latents is None: - raise ValueError("`image_latents` input cannot be undefined. Use `invert()` to compute `image_latents` from input images.") + raise ValueError( + "`image_latents` input cannot be undefined. Use `invert()` to compute `image_latents` from input images." + ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): From 219793cf77fe844b1c3e8ab82d6699bd81955588 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Mon, 27 Mar 2023 17:36:42 -0700 Subject: [PATCH 06/41] Harmonize preprocessing --- .../pipeline_stable_diffusion_diffedit.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 3f613cb9d3bb..93d1b0423fac 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -81,7 +81,7 @@ def kl_divergence(hidden_states): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess -def preprocess_image(image): +def preprocess(image): if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -89,7 +89,7 @@ def preprocess_image(image): if isinstance(image[0], PIL.Image.Image): w, h = image[0].size - w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = np.concatenate(image, axis=0) @@ -900,7 +900,7 @@ def compute_mask( ) # 4. Preprocess image - image = preprocess_image(image) + image = preprocess(image) # 5. Set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -1106,7 +1106,7 @@ def invert( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Preprocess image - image = preprocess_image(image) + image = preprocess(image) # 4. Prepare latent variables num_images_per_prompt = 1 From 44d187e264c318a05ed71e24f0892f0f9548bb80 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Mon, 27 Mar 2023 17:38:19 -0700 Subject: [PATCH 07/41] Revert "Update Pix2PixZero Auto-correlation Loss" This reverts commit b218062fed08d6cc164206d6cb852b2b7b00847a. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index be4920b89bce..bc072d4c73e8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -750,18 +750,23 @@ def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep ) def auto_corr_loss(self, hidden_states, generator=None): + batch_size, channel, height, width = hidden_states.shape + if batch_size > 1: + raise ValueError("Only batch_size 1 is supported for now") + + hidden_states = hidden_states.squeeze(0) + # hidden_states must be shape [C,H,W] now reg_loss = 0.0 for i in range(hidden_states.shape[0]): - for j in range(hidden_states.shape[1]): - noise = hidden_states[i : i + 1, j : j + 1, :, :] - while True: - roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() - reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 - reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 - - if noise.shape[2] <= 8: - break - noise = F.avg_pool2d(noise, kernel_size=2) + noise = hidden_states[i][None, None, :, :] + while True: + roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 + + if noise.shape[2] <= 8: + break + noise = F.avg_pool2d(noise, kernel_size=2) return reg_loss def kl_divergence(self, hidden_states): From 5fb209f946d8091ff702a7d1a349791187591117 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Mon, 27 Mar 2023 17:44:12 -0700 Subject: [PATCH 08/41] Update annotations --- .../pipeline_stable_diffusion_diffedit.py | 4 ++-- .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 93d1b0423fac..0ad6424080d5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -953,7 +953,7 @@ def compute_mask( @torch.no_grad() def invert( self, - prompt: Optional[str] = None, + prompt: Optional[Union[str, List[str]]] = None, image: Union[torch.FloatTensor, PIL.Image.Image] = None, num_inference_steps: int = 50, inpaint_strength: float = 0.8, @@ -1212,7 +1212,7 @@ def invert( @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]] = None, + prompt: Optional[Union[str, List[str]]] = None, mask: Union[torch.FloatTensor, PIL.Image.Image] = None, image_latents: torch.FloatTensor = None, inpaint_strength: Optional[float] = 0.8, diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index ab85566049d8..b64af923d71c 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -152,6 +152,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionDiffEditPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionImageVariationPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 3f74c41172b634e7f532a2367a296f315d4086bb Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Tue, 28 Mar 2023 13:42:05 -0700 Subject: [PATCH 09/41] rename `compute_mask` to `generate_mask` --- .../stable_diffusion/pipeline_stable_diffusion_diffedit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 0ad6424080d5..b388e3b1a453 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -731,7 +731,7 @@ def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep ) @torch.no_grad() - def compute_mask( + def generate_mask( self, image: Union[torch.FloatTensor, PIL.Image.Image] = None, target_prompt: Optional[Union[str, List[str]]] = None, @@ -1327,7 +1327,7 @@ def __call__( >>> mask_prompt = "A bowl of fruits" >>> prompt = "A bowl of pears" - >>> mask_image = pipe.compute_mask(image=init_image, source_prompt=prompt, target_prompt=mask_prompt) + >>> mask_image = pipe.generate_mask(image=init_image, source_prompt=prompt, target_prompt=mask_prompt) >>> image_latents = pipe.invert(image=init_image, prompt=mask_prompt).latents >>> image = pipe(prompt=prompt, mask=mask_image, image_latents=image_latents).images[0] ``` @@ -1358,7 +1358,7 @@ def __call__( if mask is None: raise ValueError( - "`mask` input cannot be undefined. Use `compute_mask()` to compute `mask` from text prompts." + "`mask` input cannot be undefined. Use `generate_mask()` to compute `mask` from text prompts." ) if image_latents is None: raise ValueError( From d459eb901202ba7c6e44a5b6b3ae1a974c55cb49 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Tue, 28 Mar 2023 16:19:56 -0700 Subject: [PATCH 10/41] Update documentation --- .../pipelines/stable_diffusion/diffedit.mdx | 197 ++++++++++++------ 1 file changed, 128 insertions(+), 69 deletions(-) diff --git a/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx b/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx index 95ee0bfc02c0..6d210b800307 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx @@ -22,27 +22,31 @@ The abstract of the paper is the following: Resources: -* [Project Page](https://pix2pixzero.github.io/). * [Paper](https://arxiv.org/abs/2210.11427). * [Blog Post with Demo](https://blog.problemsolversguild.com/technical/research/2022/11/02/DiffEdit-Implementation.html). -* [Implementation on Github](https://github.com/Xiang-cd/DiffEdit-stable-diffusion/blob/main/assets/origin.png). +* [Implementation on Github](https://github.com/Xiang-cd/DiffEdit-stable-diffusion/). ## Tips -* The pipeline can be conditioned on real input images. Check out the code examples below to know more. -* The pipeline exposes two arguments namely `source_embeds` and `target_embeds` -that let you control the direction of the semantic edits in the final image to be generated. Let's say, +* The pipeline can generate masks that can be fed into other inpainting pipelines. Check out the code examples below to know more. +* In order to generate an image using this pipeline, both an image mask (manually specified or generated using `generate_mask`) +and a set of partially inverted latents (generated using `invert`) _must_ be provided as arguments when calling the pipeline to generate the final edited image. +Refer to the code examples below for more details. +* The function `generate_mask` exposes two prompt arguments, `source_prompt` and `target_prompt`, +that let you control the locations of the semantic edits in the final image to be generated. Let's say, you wanted to translate from "cat" to "dog". In this case, the edit direction will be "cat -> dog". To reflect -this in the pipeline, you simply have to set the embeddings related to the phrases including "cat" to -`source_embeds` and "dog" to `target_embeds`. Refer to the code example below for more details. -* When you're using this pipeline from a prompt, specify the _source_ concept in the prompt. Taking -the above example, a valid input prompt would be: "a high resolution painting of a **cat** in the style of van gough". +this in the generated mask, you simply have to set the embeddings related to the phrases including "cat" to +`source_prompt_embeds` and "dog" to `target_prompt_embeds`. Refer to the code example below for more details. +* When generating partially inverted latents using `invert`, specify a caption describing the overall image for the `prompt` +argument. Taking the above example, a valid input prompt would be: "a high resolution painting of a cat in the style of van gough". +* Note that the caption provided for `invert` can be automatically generated. Please refer to [this code example](#generating-image-captions-for-inversion) for more details. +* When calling the pipeline to generate the final edited image, assign the source concept to `negative_prompt` and the target concept to `prompt`. Taking the above example, you simply have to set the embeddings related to the phrases including "cat" to +`negative_prompt_embeds` and "dog" to `prompt_embeds`. Refer to the code example below for more details. * If you wanted to reverse the direction in the example above, i.e., "dog -> cat", then it's recommended to: - * Swap the `source_embeds` and `target_embeds`. - * Change the input prompt to include "dog". -* To learn more about how the source and target embeddings are generated, refer to the [original -paper](https://arxiv.org/abs/2302.03027). Below, we also provide some directions on how to generate the embeddings. -* Note that the quality of the outputs generated with this pipeline is dependent on how good the `source_embeds` and `target_embeds` are. Please, refer to [this discussion](#generating-source-and-target-embeddings) for some suggestions on the topic. + * Swap the `source_prompt` and `target_prompt` in the arguments to `generate_mask`. + * Change the input prompt for `invert` to include "dog". + * Swap the `prompt` and `negative_prompt` in the arguments to call the pipeline to generate the final edited image. +* Note that the source and target prompts, or their corresponding embeddings, can also be automatically generated. Please, refer to [this discussion](#generating-source-and-target-embeddings) for more details. ## Available Pipelines: @@ -54,89 +58,128 @@ paper](https://arxiv.org/abs/2302.03027). Below, we also provide some directions ## Usage example -### Based on an input image +### Based on an input image with a caption -When the pipeline is conditioned on an input image, we first obtain an inverted -noise from it using a `DDIMInverseScheduler` with the help of a generated caption. Then -the inverted noise is used to start the generation process. +When the pipeline is conditioned on an input image, we first obtain partially inverted latents from the input image using a +`DDIMInverseScheduler` with the help of a caption. Then we generate an editing mask to identify relevant regions in the image using the source and target prompts. Finally, +the inverted noise and generated mask is used to start the generation process. First, let's load our pipeline: ```py import torch -from transformers import BlipForConditionalGeneration, BlipProcessor from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline -captioner_id = "Salesforce/blip-image-captioning-base" -processor = BlipProcessor.from_pretrained(captioner_id) -model = BlipForConditionalGeneration.from_pretrained(captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True) - -sd_model_ckpt = "CompVis/stable-diffusion-2-1" -pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained( +sd_model_ckpt = "stabilityai/stable-diffusion-2-1" +pipeline = StableDiffusionDiffEditPipeline.from_pretrained( sd_model_ckpt, - caption_generator=model, - caption_processor=processor, torch_dtype=torch.float16, safety_checker=None, ) pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) pipeline.enable_model_cpu_offload() +pipeline.enable_vae_slicing() ``` -Then, we load an input image for conditioning and obtain a suitable caption for it: +Then, we load an input image for conditioning and provide a suitable caption for it: ```py import requests -from PIL import Image - -img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" -raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512)) -caption = pipeline.generate_caption(raw_image) -``` +from diffusers.utils import load_image -Then we employ the generated caption and the input image to get the inverted latents: - -```py -generator = torch.manual_seed(0) -inv_latents = pipeline.invert(caption, image=raw_image, generator=generator).latents +img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" +raw_image = load_image(img_url).convert("RGB").resize((768, 768)) ``` -Then we employ the source and target prompts to generate the editing mask: +Then, we employ the source and target prompts to generate the editing mask: ```py # See the "Generating source and target embeddings" section below to # automate the generation of these captions with a pre-trained model like Flan-T5 as explained below. -source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"] -target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"] -source_embeds = pipeline.get_embeds(source_prompts, batch_size=2) -target_embeds = pipeline.get_embeds(target_prompts, batch_size=2) -mask_image = pipeline.compute_mask( +source_prompt = "a bowl of fruits" +target_prompt = "a basket of fruits" +mask_image = pipeline.generate_mask( image=raw_image, - prompt_embeds=target_embeds, - mask_prompt_embeds=source_embeds, + source_prompt=source_prompt, + target_prompt=target_prompt, generator=generator, ) ``` +Then, we employ the caption and the input image to get the inverted latents: + +```py +generator = torch.manual_seed(0) +inv_latents = pipeline.invert(prompt=source_prompt, image=raw_image, generator=generator).latents +``` + Now, generate the image with the inverted latents and semantically generated mask: ```py image = pipeline( - prompt_embeds=target_embeds, - num_inference_steps=50, + prompt=target_prompt, + mask=mask_image, + image_latents=inv_latents, + generator=generator, + negative_prompt=source_prompt, +).images[0] +image.save("edited_image.png") +``` + +## Generating image captions for inversion + +The authors originally used the source concept prompt as the caption for generating the partially inverted latents. However, we can also leverage open source and public image captioning models for the same purpose. +Below, we provide an end-to-end example with the [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) model +for generating captions. + +First, let's load our automatic image captioning model: + +```py +import torch +from transformers import BlipForConditionalGeneration, BlipProcessor + +captioner_id = "Salesforce/blip-image-captioning-base" +processor = BlipProcessor.from_pretrained(captioner_id) +model = BlipForConditionalGeneration.from_pretrained(captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True) +``` + +Then, we load an input image for conditioning and obtain a suitable caption for it: + +```py +import requests +from PIL import Image +from diffusers.utils import load_image + +img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" +raw_image = load_image(img_url).convert("RGB").resize((768, 768)) +caption = None +``` + +Then, we employ the generated caption and the input image to get the inverted latents: + +```py +generator = torch.manual_seed(0) +inv_latents = pipeline.invert(prompt=source_prompt, image=raw_image, generator=generator).latents +``` + +Now, generate the image with the inverted latents and semantically generated mask from the previous section: + +```py +image = pipeline( + prompt=target_prompt, + mask=mask_image, + image_latents=inv_latents, generator=generator, - inverted_latents=inv_latents, - mask_image=mask_image, - negative_prompt=caption, + negative_prompt=source_prompt, ).images[0] image.save("edited_image.png") ``` ## Generating source and target embeddings -The authors originally used the [GPT-3 API](https://openai.com/api/) to generate the source and target captions for discovering +The authors originally required the user to manually provide the source and target prompts for discovering edit directions. However, we can also leverage open source and public models for the same purpose. Below, we provide an end-to-end example with the [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) model for generating captions and [CLIP](https://huggingface.co/docs/transformers/model_doc/clip) for @@ -196,14 +239,14 @@ We encourage you to play around with the different parameters supported by the Here, we need to use the same text encoder model used by the subsequent Stable Diffusion model. ```py -from diffusers import StableDiffusionPix2PixZeroPipeline +from diffusers import StableDiffusionDiffEditPipeline -pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 +pipeline = StableDiffusionDiffEditPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 ) pipeline = pipeline.to("cuda") -tokenizer = pipeline.tokenizer -text_encoder = pipeline.text_encoder +pipeline.enable_model_cpu_offload() +pipeline.enable_vae_slicing() ``` **5. Compute embeddings**: @@ -227,25 +270,39 @@ def embed_captions(sentences, tokenizer, text_encoder, device="cuda"): embeddings.append(prompt_embeds) return torch.concatenate(embeddings, dim=0).mean(dim=0).unsqueeze(0) -source_embeddings = embed_captions(source_captions, tokenizer, text_encoder) -target_embeddings = embed_captions(target_captions, tokenizer, text_encoder) +source_embeddings = embed_captions(source_captions, pipeline.tokenizer, pipeline.text_encoder) +target_embeddings = embed_captions(target_captions, pipeline.tokenizer, pipeline.text_encoder) ``` -And you're done! [Here](https://colab.research.google.com/drive/1tz2C1EdfZYAPlzXXbTnf-5PRBiR8_R1F?usp=sharing) is a Colab Notebook that you can use to interact with the entire process. - -Now, you can use these embeddings directly while calling the pipeline: +And you're done! Now, you can use these embeddings directly while calling the pipeline: ```py -from diffusers import DDIMScheduler +from diffusers import DDIMInverseScheduler, DDIMScheduler +from diffusers.utils import load_image pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) +pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + +img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" +raw_image = load_image(img_url).convert("RGB").resize((768, 768)) + + +mask = pipeline.generate_mask( + image=raw_image, + source_prompt_embeds=source_embeds, + target_prompt_embeds=target_embeds, +) + +inv_latents = pipeline.invert( + prompt=caption, + image=raw_image, +).latents images = pipeline( - prompt, - source_embeds=source_embeddings, - target_embeds=target_embeddings, - num_inference_steps=50, - cross_attention_guidance_amount=0.15, + mask=mask, + image_latents=inv_latents, + prompt_embeds=target_embeddings, + negative_prompt_embeds=source_embeddings, ).images images[0].save("edited_image_dog.png") ``` @@ -253,4 +310,6 @@ images[0].save("edited_image_dog.png") ## StableDiffusionDiffEditPipeline [[autodoc]] StableDiffusionDiffEditPipeline - __call__ + - generate_mask + - invert - all From e64fcbadb1ce6a2510ffedc95c9ba098a6ed2e4f Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Tue, 28 Mar 2023 16:43:24 -0700 Subject: [PATCH 11/41] Update docs --- .../pipelines/stable_diffusion/diffedit.mdx | 23 +++++++++++-------- .../pipeline_stable_diffusion_diffedit.py | 6 +++-- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx b/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx index 6d210b800307..61de07da5f53 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx @@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License. ## Overview -[DiffEdit: Diffusion-based semantic image editing with mask guidance](https://arxiv.org/abs/2210.11427). +[DiffEdit: Diffusion-based semantic image editing with mask guidance](https://arxiv.org/abs/2210.11427) by Guillaume Couairon, Jakob Verbeek, Holger Schwenk, and Matthieu Cord. The abstract of the paper is the following: @@ -37,14 +37,17 @@ that let you control the locations of the semantic edits in the final image to b you wanted to translate from "cat" to "dog". In this case, the edit direction will be "cat -> dog". To reflect this in the generated mask, you simply have to set the embeddings related to the phrases including "cat" to `source_prompt_embeds` and "dog" to `target_prompt_embeds`. Refer to the code example below for more details. -* When generating partially inverted latents using `invert`, specify a caption describing the overall image for the `prompt` -argument. Taking the above example, a valid input prompt would be: "a high resolution painting of a cat in the style of van gough". -* Note that the caption provided for `invert` can be automatically generated. Please refer to [this code example](#generating-image-captions-for-inversion) for more details. -* When calling the pipeline to generate the final edited image, assign the source concept to `negative_prompt` and the target concept to `prompt`. Taking the above example, you simply have to set the embeddings related to the phrases including "cat" to -`negative_prompt_embeds` and "dog" to `prompt_embeds`. Refer to the code example below for more details. +* When generating partially inverted latents using `invert`, assign a caption or text embedding describing the +overall image to the `prompt` argument to help guide the inverse latent sampling process. In most cases, the +source concept is sufficently descriptive to yield good results, but feel free to explore alternatives. +Please refer to [this code example](#generating-image-captions-for-inversion) for more details. +* When calling the pipeline to generate the final edited image, assign the source concept to `negative_prompt` +and the target concept to `prompt`. Taking the above example, you simply have to set the embeddings related to +the phrases including "cat" to `negative_prompt_embeds` and "dog" to `prompt_embeds`. Refer to the code example +below for more details. * If you wanted to reverse the direction in the example above, i.e., "dog -> cat", then it's recommended to: * Swap the `source_prompt` and `target_prompt` in the arguments to `generate_mask`. - * Change the input prompt for `invert` to include "dog". + * Change the input prompt for `invert` to include "dog". * Swap the `prompt` and `negative_prompt` in the arguments to call the pipeline to generate the final edited image. * Note that the source and target prompts, or their corresponding embeddings, can also be automatically generated. Please, refer to [this discussion](#generating-source-and-target-embeddings) for more details. @@ -228,7 +231,7 @@ And then we just call it to generate our captions: ```py source_captions = generate_captions(source_text) -target_captions = generate_captions(target_concept) +target_captions = generate_captions(target_text) ``` We encourage you to play around with the different parameters supported by the @@ -309,7 +312,7 @@ images[0].save("edited_image_dog.png") ## StableDiffusionDiffEditPipeline [[autodoc]] StableDiffusionDiffEditPipeline - - __call__ + - all - generate_mask - invert - - all + - __call__ \ No newline at end of file diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index b388e3b1a453..761785fac043 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -161,7 +161,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline): unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. - inverse_scheduler (`Union[]`): + inverse_scheduler (`[DDIMInverseScheduler]`): A scheduler to be used in combination with `unet` to fill in the unmasked part of the input latents safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. @@ -1065,7 +1065,7 @@ def invert( >>> prompt = "A bowl of fruits" - >>> inverted_latents = pipe.invert(image=invert_image, prompt=mask_prompt).latents + >>> inverted_latents = pipe.invert(image=init_image, prompt=prompt).latents ``` Returns: @@ -1077,6 +1077,8 @@ def invert( """ # 1. Check inputs + # provide dummy height and width arguments to check_inputs, as the spatial dimensions of the inverted latents + # will be determined by the spatial dimensions of the input image. self.check_inputs( prompt, 0, From 1328b6acde580807726a38941adf79c62e27b1d4 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Wed, 29 Mar 2023 11:56:08 -0700 Subject: [PATCH 12/41] Update Docs --- .../pipelines/stable_diffusion/diffedit.mdx | 101 +++++++++++------- 1 file changed, 64 insertions(+), 37 deletions(-) diff --git a/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx b/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx index 61de07da5f53..9c2e1977e7ad 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx @@ -85,7 +85,7 @@ pipeline.enable_model_cpu_offload() pipeline.enable_vae_slicing() ``` -Then, we load an input image for conditioning and provide a suitable caption for it: +Then, we load an input image to edit using our method: ```py import requests @@ -134,7 +134,7 @@ image.save("edited_image.png") ## Generating image captions for inversion The authors originally used the source concept prompt as the caption for generating the partially inverted latents. However, we can also leverage open source and public image captioning models for the same purpose. -Below, we provide an end-to-end example with the [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) model +Below, we provide an end-to-end example with the [BLIP](https://huggingface.co/docs/transformers/model_doc/blip) model for generating captions. First, let's load our automatic image captioning model: @@ -148,28 +148,55 @@ processor = BlipProcessor.from_pretrained(captioner_id) model = BlipForConditionalGeneration.from_pretrained(captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True) ``` +Then, we define a utility to generate captions from an input image using the model: + +```py +@torch.no_grad() +def generate_caption(images, caption_generator, caption_processor): + text = "a photograph of" + + inputs = caption_processor(images, text, return_tensors="pt").to( + device="cuda", dtype=caption_generator.dtype + ) + caption_generator.to("cuda") + outputs = caption_generator.generate(**inputs, max_new_tokens=128) + + # offload caption generator + caption_generator.to("cpu") + + caption = caption_processor.batch_decode(outputs, skip_special_tokens=True)[0] + return caption +``` + Then, we load an input image for conditioning and obtain a suitable caption for it: ```py -import requests -from PIL import Image from diffusers.utils import load_image -img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" +img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" raw_image = load_image(img_url).convert("RGB").resize((768, 768)) -caption = None +caption = generate_caption(raw_image, model, processor) ``` Then, we employ the generated caption and the input image to get the inverted latents: ```py generator = torch.manual_seed(0) -inv_latents = pipeline.invert(prompt=source_prompt, image=raw_image, generator=generator).latents +inv_latents = pipeline.invert(prompt=caption, image=raw_image, generator=generator).latents ``` -Now, generate the image with the inverted latents and semantically generated mask from the previous section: +Now, generate the image with the inverted latents and semantically generated mask from our source and target prompts: ```py +source_prompt = "a bowl of fruits" +target_prompt = "a basket of fruits" + +mask_image = pipeline.generate_mask( + image=raw_image, + source_prompt=source_prompt, + target_prompt=target_prompt, +) + image = pipeline( prompt=target_prompt, mask=mask_image, @@ -185,8 +212,7 @@ image.save("edited_image.png") The authors originally required the user to manually provide the source and target prompts for discovering edit directions. However, we can also leverage open source and public models for the same purpose. Below, we provide an end-to-end example with the [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) model -for generating captions and [CLIP](https://huggingface.co/docs/transformers/model_doc/clip) for -computing embeddings on the generated captions. +for generating source an target embeddings. **1. Load the generation model**: @@ -201,8 +227,8 @@ model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_m **2. Construct a starting prompt**: ```py -source_concept = "cat" -target_concept = "dog" +source_concept = "bowl" +target_concept = "basket" source_text = f"Provide a caption for images containing a {source_concept}. " "The captions should be in English and should be no longer than 150 characters." @@ -211,14 +237,15 @@ target_text = f"Provide a caption for images containing a {target_concept}. " "The captions should be in English and should be no longer than 150 characters." ``` -Here, we're interested in the "cat -> dog" direction. +Here, we're interested in the "bowl -> basket" direction. -**3. Generate captions**: +**3. Generate prompts**: We can use a utility like so for this purpose. ```py -def generate_captions(input_prompt): +@torch.no_grad +def generate_prompts(input_prompt): input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.to("cuda") outputs = model.generate( @@ -227,11 +254,11 @@ def generate_captions(input_prompt): return tokenizer.batch_decode(outputs, skip_special_tokens=True) ``` -And then we just call it to generate our captions: +And then we just call it to generate our prompts: ```py -source_captions = generate_captions(source_text) -target_captions = generate_captions(target_text) +source_prompts = generate_prompts(source_text) +target_prompts = generate_prompts(target_text) ``` We encourage you to play around with the different parameters supported by the @@ -257,24 +284,24 @@ pipeline.enable_vae_slicing() ```py import torch -def embed_captions(sentences, tokenizer, text_encoder, device="cuda"): - with torch.no_grad(): - embeddings = [] - for sent in sentences: - text_inputs = tokenizer( - sent, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0] - embeddings.append(prompt_embeds) +@torch.no_grad() +def embed_prompts(sentences, tokenizer, text_encoder, device="cuda"): + embeddings = [] + for sent in sentences: + text_inputs = tokenizer( + sent, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0] + embeddings.append(prompt_embeds) return torch.concatenate(embeddings, dim=0).mean(dim=0).unsqueeze(0) -source_embeddings = embed_captions(source_captions, pipeline.tokenizer, pipeline.text_encoder) -target_embeddings = embed_captions(target_captions, pipeline.tokenizer, pipeline.text_encoder) +source_embeddings = embed_prompts(source_prompts, pipeline.tokenizer, pipeline.text_encoder) +target_embeddings = embed_prompts(target_captions, pipeline.tokenizer, pipeline.text_encoder) ``` And you're done! Now, you can use these embeddings directly while calling the pipeline: @@ -286,7 +313,7 @@ from diffusers.utils import load_image pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) -img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" +img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" raw_image = load_image(img_url).convert("RGB").resize((768, 768)) @@ -297,7 +324,7 @@ mask = pipeline.generate_mask( ) inv_latents = pipeline.invert( - prompt=caption, + prompt_embeds=source_embeds, image=raw_image, ).latents @@ -307,7 +334,7 @@ images = pipeline( prompt_embeds=target_embeddings, negative_prompt_embeds=source_embeddings, ).images -images[0].save("edited_image_dog.png") +images[0].save("edited_image.png") ``` ## StableDiffusionDiffEditPipeline From f7a12c8fe6c9cf47373a15ed0ddd4ad3fc24fa09 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Wed, 29 Mar 2023 14:05:01 -0700 Subject: [PATCH 13/41] Fix copy --- .../stable_diffusion/pipeline_stable_diffusion_diffedit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 6955b060f9a2..6b275d5d9b0b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -398,8 +398,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. From fcf26d4762ee318037660f6b4f7bd9abe8804da0 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Wed, 29 Mar 2023 18:33:44 -0700 Subject: [PATCH 14/41] Change shape of output latents to batch first --- .../pipeline_stable_diffusion_diffedit.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 6b275d5d9b0b..96b1ff2e36a8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -1192,7 +1192,7 @@ def invert( callback(i, t, latents) assert len(inverted_latents) == len(timesteps) - latents = torch.cat(list(reversed(inverted_latents))) + latents = torch.stack(list(reversed(inverted_latents)), 1) # 8. Post-processing image = None @@ -1407,17 +1407,14 @@ def __call__( # 6. Preprocess image latents num_channels_latents = self.vae.config.latent_channels - image_latents_shape = (len(timesteps) * batch_size, num_channels_latents, *vae_latent_size) + image_latents_shape = (batch_size, len(timesteps), num_channels_latents, *vae_latent_size) if image_latents.shape != image_latents_shape: raise ValueError( f"`image_latents` must have shape {image_latents_shape}, but has shape {image_latents.shape}" ) if isinstance(image_latents, np.ndarray): image_latents = torch.from_numpy(image_latents) - image_latents = torch.cat( - [image_latents.reshape(-1, batch_size, num_channels_latents, *vae_latent_size)] * num_images_per_prompt, - 1, - ) + image_latents = torch.cat([image_latents.transpose(0, 1)] * num_images_per_prompt, 1) image_latents = image_latents.to(device=device, dtype=prompt_embeds.dtype) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline From 2821aa4d79e4064eab0471c10c1eb0e0b5cba65b Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Wed, 29 Mar 2023 18:33:51 -0700 Subject: [PATCH 15/41] Update docs --- .../pipelines/stable_diffusion/diffedit.mdx | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx b/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx index 9c2e1977e7ad..919d02b23300 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx @@ -83,12 +83,12 @@ pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) pipeline.enable_model_cpu_offload() pipeline.enable_vae_slicing() +generator = torch.manual_seed(0) ``` Then, we load an input image to edit using our method: ```py -import requests from diffusers.utils import load_image img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" @@ -114,7 +114,6 @@ mask_image = pipeline.generate_mask( Then, we employ the caption and the input image to get the inverted latents: ```py -generator = torch.manual_seed(0) inv_latents = pipeline.invert(prompt=source_prompt, image=raw_image, generator=generator).latents ``` @@ -180,7 +179,19 @@ caption = generate_caption(raw_image, model, processor) Then, we employ the generated caption and the input image to get the inverted latents: -```py +```py +from diffusers import DDIMInverseScheduler, DDIMScheduler + +pipeline = StableDiffusionDiffEditPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 +) +pipeline = pipeline.to("cuda") +pipeline.enable_model_cpu_offload() +pipeline.enable_vae_slicing() + +pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) +pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + generator = torch.manual_seed(0) inv_latents = pipeline.invert(prompt=caption, image=raw_image, generator=generator).latents ``` @@ -195,6 +206,7 @@ mask_image = pipeline.generate_mask( image=raw_image, source_prompt=source_prompt, target_prompt=target_prompt, + generator=generator, ) image = pipeline( @@ -277,6 +289,8 @@ pipeline = StableDiffusionDiffEditPipeline.from_pretrained( pipeline = pipeline.to("cuda") pipeline.enable_model_cpu_offload() pipeline.enable_vae_slicing() + +generator = torch.manual_seed(0) ``` **5. Compute embeddings**: @@ -321,11 +335,13 @@ mask = pipeline.generate_mask( image=raw_image, source_prompt_embeds=source_embeds, target_prompt_embeds=target_embeds, + generator=generator, ) inv_latents = pipeline.invert( prompt_embeds=source_embeds, image=raw_image, + generator=generator, ).latents images = pipeline( @@ -333,6 +349,7 @@ images = pipeline( image_latents=inv_latents, prompt_embeds=target_embeddings, negative_prompt_embeds=source_embeddings, + generator=generator, ).images images[0].save("edited_image.png") ``` From dabd82f89e29858c3759a18ff84f06190d458170 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Wed, 29 Mar 2023 18:41:52 -0700 Subject: [PATCH 16/41] Add first draft for tests --- .../test_stable_diffusion_diffedit.py | 259 ++++++++++++++++++ 1 file changed, 259 insertions(+) create mode 100644 tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py new file mode 100644 index 000000000000..73ae21d5302a --- /dev/null +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py @@ -0,0 +1,259 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch +from PIL.Image import Image +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + DDIMInverseScheduler, + DDIMScheduler, + StableDiffusionDiffEditPipeline, + UNet2DConditionModel, +) +from diffusers.utils import load_image, load_numpy, skip_mps, slow +from diffusers.utils.testing_utils import floats_tensor, require_torch_gpu + +from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS +from ...test_pipelines_common import PipelineTesterMixin + + +@skip_mps +class StableDiffusionDiffEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableDiffusionDiffEditPipeline + test_attention_slicing = False + params = TEXT_TO_IMAGE_PARAMS + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"token_indices"}) + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + # SD2-specific config below + attention_head_dim=(2, 4), + use_linear_projection=True, + ) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + inverse_scheduler = DDIMInverseScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_zero=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=512, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "scheduler": scheduler, + "inverse_scheduler": inverse_scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + image = Image.fromarray(np.uint8(image)).convert("RGB") + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + mask_inputs = { + "image": image, + "source_prompt": "a cat and a frog", + "target_prompt": "a dog and a newt", + "generator": generator, + "num_inference_steps": 2, + "num_maps_per_mask": 2, + "mask_encode_strength": 1.0, + "guidance_scale": 6.0, + "output_type": "numpy", + } + + invert_inputs = { + "image": image, + "prompt": "a cat and a frog", + "generator": generator, + "num_inference_steps": 2, + "inpaint_strength": 1.0, + "guidance_scale": 6.0, + "decode_latents": True, + "output_type": "numpy", + } + return mask_inputs, invert_inputs + + def test_mask(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs, _ = self.get_dummy_inputs(device) + mask = pipe.generate_mask(**inputs) + mask_slice = mask[0, -3:, -3:] + + self.assertEqual(mask.shape, (1, 96, 96)) + expected_slice = np.array([0] * 9) + max_diff = np.abs(mask_slice.flatten() - expected_slice).max() + self.assertLessEqual(max_diff, 1e-3) + + def test_inversion(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + _, inputs = self.get_dummy_inputs(device) + image = pipe.invert(**inputs).images + image_slice = image[0, :, -1, -3:, -3:] + + self.assertEqual(image.shape, (1, 2, 4, 96, 96)) + expected_slice = np.array( + [ + 0.5644937, + 0.60543084, + 0.48239064, + 0.5206757, + 0.55623394, + 0.46045133, + 0.5100435, + 0.48919064, + 0.4759359, + 0.5644937, + 0.60543084, + 0.48239064, + 0.5206757, + 0.55623394, + 0.46045133, + 0.5100435, + 0.48919064, + 0.4759359, + ], + ) + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + self.assertLessEqual(max_diff, 1e-3) + + +@require_torch_gpu +@slow +class StableDiffusionDiffEditPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @classmethod + def setUpClass(cls): + raw_image = load_image( + "https://raw.githubusercontent.com/Xiang-cd/DiffEdit-stable-diffusion/main/assets/origin.png" + ) + + raw_image = raw_image.convert("RGB").resize((768, 768)) + + cls.raw_image = raw_image + + def test_stable_diffusion_diffedit_full(self): + generator = torch.manual_seed(0) + + pipe = StableDiffusionDiffEditPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16 + ) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + source_prompt = "a bowl of fruit" + target_prompt = "a bowl of pears" + + mask_image = pipe.generate_mask( + image=self.raw_image, + source_prompt=source_prompt, + target_prompt=target_prompt, + inpaint_strength=0.6, + generator=generator, + ) + + inv_latents = pipe.invert(prompt=source_prompt, image=self.raw_image, generator=generator).latents + + image = pipe( + prompt=target_prompt, + mask=mask_image, + image_latents=inv_latents, + generator=generator, + negative_prompt=source_prompt, + inpaint_strength=0.6, + output_type="numpy", + ).images[0] + + expected_image = load_numpy( + "https://raw.githubusercontent.com/Xiang-cd/DiffEdit-stable-diffusion/main/assets/target.png" + ) + assert np.abs((expected_image - image).max()) < 1e-1 From 18ee76a2b312b39470294b4b6da3ac87a4cb71fc Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Thu, 30 Mar 2023 13:55:24 -0700 Subject: [PATCH 17/41] Bugfix and update tests --- .../pipeline_stable_diffusion_diffedit.py | 4 +- .../test_stable_diffusion_diffedit.py | 70 +++++++++++-------- 2 files changed, 43 insertions(+), 31 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 96b1ff2e36a8..6a576ee5e4e8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -655,7 +655,7 @@ def get_inverse_timesteps(self, num_inference_steps, strength, device): t_start = max(num_inference_steps - init_timestep, 0) # safety for t_start overflow to prevent empty timsteps slice - if t_start == num_inference_steps: + if t_start == 0: return self.inverse_scheduler.timesteps, num_inference_steps timesteps = self.inverse_scheduler.timesteps[:-t_start] @@ -1197,7 +1197,7 @@ def invert( # 8. Post-processing image = None if decode_latents: - image = self.decode_latents(latents.detach()) + image = self.decode_latents(latents.flatten(0, 1).detach()) # 9. Convert to PIL. if decode_latents and output_type == "pil": diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py index 73ae21d5302a..faf774583a94 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py @@ -19,7 +19,7 @@ import numpy as np import torch -from PIL.Image import Image +from PIL import Image from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import ( @@ -114,6 +114,26 @@ def get_dummy_components(self): return components def get_dummy_inputs(self, device, seed=0): + mask = floats_tensor((1, 16, 16), rng=random.Random(seed)).to(device) + latents = floats_tensor((2, 4, 16, 16), rng=random.Random(seed)).to(device) + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "a dog and a newt", + "mask": mask, + "image_latents": latents, + "generator": generator, + "num_inference_steps": 2, + "inpaint_strength": 1.0, + "guidance_scale": 6.0, + "output_type": "numpy", + } + + return inputs + + def get_dummy_mask_inputs(self, device, seed=0): image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) image = image.cpu().permute(0, 2, 3, 1)[0] image = Image.fromarray(np.uint8(image)).convert("RGB") @@ -121,7 +141,7 @@ def get_dummy_inputs(self, device, seed=0): generator = torch.manual_seed(seed) else: generator = torch.Generator(device=device).manual_seed(seed) - mask_inputs = { + inputs = { "image": image, "source_prompt": "a cat and a frog", "target_prompt": "a dog and a newt", @@ -133,7 +153,17 @@ def get_dummy_inputs(self, device, seed=0): "output_type": "numpy", } - invert_inputs = { + return inputs + + def get_dummy_inversion_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + image = Image.fromarray(np.uint8(image)).convert("RGB") + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { "image": image, "prompt": "a cat and a frog", "generator": generator, @@ -143,7 +173,7 @@ def get_dummy_inputs(self, device, seed=0): "decode_latents": True, "output_type": "numpy", } - return mask_inputs, invert_inputs + return inputs def test_mask(self): device = "cpu" @@ -153,14 +183,15 @@ def test_mask(self): pipe.to(device) pipe.set_progress_bar_config(disable=None) - inputs, _ = self.get_dummy_inputs(device) + inputs = self.get_dummy_mask_inputs(device) mask = pipe.generate_mask(**inputs) mask_slice = mask[0, -3:, -3:] - self.assertEqual(mask.shape, (1, 96, 96)) + self.assertEqual(mask.shape, (1, 16, 16)) expected_slice = np.array([0] * 9) max_diff = np.abs(mask_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) + self.assertEqual(mask[0, -3, -4], 1) def test_inversion(self): device = "cpu" @@ -170,32 +201,13 @@ def test_inversion(self): pipe.to(device) pipe.set_progress_bar_config(disable=None) - _, inputs = self.get_dummy_inputs(device) + inputs = self.get_dummy_inversion_inputs(device) image = pipe.invert(**inputs).images - image_slice = image[0, :, -1, -3:, -3:] + image_slice = image[0, -1, -3:, -3:] - self.assertEqual(image.shape, (1, 2, 4, 96, 96)) + self.assertEqual(image.shape, (2, 32, 32, 3)) expected_slice = np.array( - [ - 0.5644937, - 0.60543084, - 0.48239064, - 0.5206757, - 0.55623394, - 0.46045133, - 0.5100435, - 0.48919064, - 0.4759359, - 0.5644937, - 0.60543084, - 0.48239064, - 0.5206757, - 0.55623394, - 0.46045133, - 0.5100435, - 0.48919064, - 0.4759359, - ], + [0.5588859, 0.535619, 0.52224344, 0.55604255, 0.48608556, 0.51105076, 0.50301707, 0.44348782, 0.48488846], ) max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) From f31b7b85722fc2d91098d376ddf3ae7edaeaecba Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Fri, 31 Mar 2023 13:05:11 -0700 Subject: [PATCH 18/41] Add `cross_attention_kwargs` support for all pipeline methods --- .../pipeline_stable_diffusion_diffedit.py | 44 +++++++++++++++++-- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 6a576ee5e4e8..876c3424557a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -14,7 +14,7 @@ import inspect from dataclasses import dataclass -from typing import Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL @@ -751,6 +751,7 @@ def generate_mask( guidance_scale: float = 7.5, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "np", + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function used to generate a latent mask given a mask prompt, a target prompt, and an image. @@ -817,6 +818,10 @@ def generate_mask( output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: @@ -872,6 +877,8 @@ def generate_mask( batch_size = len(target_prompt) else: batch_size = target_prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) @@ -920,7 +927,12 @@ def generate_mask( # 7. Predict the noise residual prompt_embeds = torch.cat([source_prompt_embeds, target_prompt_embeds]) - noise_pred = self.unet(latent_model_input, encode_timestep, encoder_hidden_states=prompt_embeds).sample + noise_pred = self.unet( + latent_model_input, + encode_timestep, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample if do_classifier_free_guidance: noise_pred_neg_src, noise_pred_source, noise_pred_uncond, noise_pred_target = noise_pred.chunk(4) @@ -968,6 +980,7 @@ def invert( return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, lambda_auto_corr: float = 20.0, lambda_kl: float = 20.0, num_reg_steps: int = 0, @@ -1027,6 +1040,10 @@ def invert( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). lambda_auto_corr (`float`, *optional*, defaults to 20.0): Lambda parameter to control auto correction lambda_kl (`float`, *optional*, defaults to 20.0): @@ -1101,6 +1118,8 @@ def invert( batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) @@ -1142,7 +1161,12 @@ def invert( latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample # perform guidance if do_classifier_free_guidance: @@ -1234,6 +1258,7 @@ def __call__( return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -1298,6 +1323,10 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: @@ -1375,6 +1404,8 @@ def __call__( batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) @@ -1430,7 +1461,12 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample # perform guidance if do_classifier_free_guidance: From ab79c4f8ef286fdd05f594f3b27ecad772734f03 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Fri, 31 Mar 2023 13:10:29 -0700 Subject: [PATCH 19/41] Fix Copies --- .../pipeline_stable_diffusion_diffedit.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 876c3424557a..0e73c8d11410 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -23,6 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDIMInverseScheduler, KarrasDiffusionSchedulers from ...utils import ( @@ -416,6 +417,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -476,6 +481,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, From a08f620eea74c753d306f965bf5650033f6a2986 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Sun, 2 Apr 2023 18:37:29 -0700 Subject: [PATCH 20/41] Add support for PIL image latents Add support for mask broadcasting Update docs and tests Align `mask` argument to `mask_image` Remove height and width arguments --- .../pipelines/stable_diffusion/diffedit.mdx | 8 +- .../pipeline_stable_diffusion_diffedit.py | 111 ++++++++---------- .../test_stable_diffusion_diffedit.py | 11 +- 3 files changed, 59 insertions(+), 71 deletions(-) diff --git a/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx b/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx index 919d02b23300..88fa692c1fc7 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx @@ -122,7 +122,7 @@ Now, generate the image with the inverted latents and semantically generated mas ```py image = pipeline( prompt=target_prompt, - mask=mask_image, + mask_image=mask_image, image_latents=inv_latents, generator=generator, negative_prompt=source_prompt, @@ -211,7 +211,7 @@ mask_image = pipeline.generate_mask( image = pipeline( prompt=target_prompt, - mask=mask_image, + mask_image=mask_image, image_latents=inv_latents, generator=generator, negative_prompt=source_prompt, @@ -331,7 +331,7 @@ img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets raw_image = load_image(img_url).convert("RGB").resize((768, 768)) -mask = pipeline.generate_mask( +mask_image = pipeline.generate_mask( image=raw_image, source_prompt_embeds=source_embeds, target_prompt_embeds=target_embeds, @@ -345,7 +345,7 @@ inv_latents = pipeline.invert( ).latents images = pipeline( - mask=mask, + mask_image=mask_image, image_latents=inv_latents, prompt_embeds=target_embeddings, negative_prompt_embeds=source_embeddings, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 0e73c8d11410..83ab0eae81a2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -103,19 +103,20 @@ def preprocess(image): return image -def preprocess_mask(mask): +def preprocess_mask(mask, batch_size: int = 1): if not isinstance(mask, torch.Tensor): # preprocess mask if isinstance(mask, PIL.Image.Image) or (isinstance(mask, np.ndarray) and mask.ndim < 3): mask = [mask] - if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): - mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) - mask = mask.astype(np.float32) / 255.0 - elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): - mask = np.concatenate([m[None, None, :] for m in mask], axis=0) - - mask = torch.from_numpy(mask) + if isinstance(mask, list): + if isinstance(mask[0], PIL.Image.Image): + mask = [np.array(m.convert("L")).astype(np.float32) / 255.0 for m in mask] + if isinstance(mask[0], np.ndarray): + mask = np.stack(mask, axis=0) + mask = torch.from_numpy(mask) + elif isinstance(mask[0], torch.Tensor): + mask = torch.stack(mask, dim=0) # Batch and add channel dim for single mask if mask.ndim == 2: @@ -131,9 +132,22 @@ def preprocess_mask(mask): else: mask = mask.unsqueeze(1) + # Check mask shape + if batch_size > 1: + if mask.shape[0] == 1: + mask = torch.cat([mask] * batch_size) + elif mask.shape[0] > 1 and mask.shape[0] != batch_size: + raise ValueError( + f"`mask_image` with batch size {mask.shape[0]} cannot be broadcasted to batch size {batch_size} " + f"inferred by prompt inputs" + ) + + if mask.shape[1] != 1: + raise ValueError(f"`mask_image` must have 1 channel, but has {mask.shape[1]} channels") + # Check mask is in [0, 1] if mask.min() < 0 or mask.max() > 1: - raise ValueError("Mask should be in [0, 1] range") + raise ValueError("`mask_image` should be in [0, 1] range") # Binarize mask mask[mask < 0.5] = 0 @@ -562,17 +576,12 @@ def decode_latents(self, latents): def check_inputs( self, prompt, - height, - width, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if (strength is None) or (strength is not None and (strength < 0 or strength > 1)): raise ValueError( f"The value of `strength` should in [0.0, 1.0] but is, but is {strength} of type {type(strength)}." @@ -754,8 +763,6 @@ def generate_mask( num_maps_per_mask: Optional[int] = 10, mask_encode_strength: Optional[float] = 0.5, mask_thresholding_ratio: Optional[float] = 3.0, - height: Optional[int] = None, - width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -808,10 +815,6 @@ def generate_mask( mask_thresholding_ratio (`float`, *optional*, defaults to 3.0): The maximum multiple of the mean absolute difference used to clamp the semantic guidance map before mask binarization. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -842,15 +845,9 @@ def generate_mask( self.vae_scale_factor)`. """ - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs + # 1. Check inputs (Provide dummy argument for callback_steps) self.check_inputs( target_prompt, - height, - width, mask_encode_strength, 1, target_negative_prompt, @@ -1104,12 +1101,8 @@ def invert( """ # 1. Check inputs - # provide dummy height and width arguments to check_inputs, as the spatial dimensions of the inverted latents - # will be determined by the spatial dimensions of the input image. self.check_inputs( prompt, - 0, - 0, inpaint_strength, callback_steps, negative_prompt, @@ -1249,11 +1242,9 @@ def invert( def __call__( self, prompt: Optional[Union[str, List[str]]] = None, - mask: Union[torch.FloatTensor, PIL.Image.Image] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, image_latents: torch.FloatTensor = None, inpaint_strength: Optional[float] = 0.8, - height: Optional[int] = None, - width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -1276,12 +1267,12 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - mask (`PIL.Image.Image`): + mask_image (`PIL.Image.Image`): `Image`, or tensor representing an image batch, to mask the generated image. White pixels in the mask - will be repainted, while black pixels will be preserved. If `mask` is a PIL image, it will be converted - to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) - instead of 3, so the expected shape would be `(B, 1, H, W)`. - image_latents (`torch.FloatTensor`): + will be repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be + converted to a single channel (luminance) before use. If it's a tensor, it should contain one color + channel (L) instead of 3, so the expected shape would be `(B, 1, H, W)`. + image_latents (`PIL.Image.Image` or `torch.FloatTensor`): Partially noised image latents from the inversion process to be used as inputs for image generation. inpaint_strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` @@ -1370,7 +1361,7 @@ def __call__( >>> mask_image = pipe.generate_mask(image=init_image, source_prompt=prompt, target_prompt=mask_prompt) >>> image_latents = pipe.invert(image=init_image, prompt=mask_prompt).latents - >>> image = pipe(prompt=prompt, mask=mask_image, image_latents=image_latents).images[0] + >>> image = pipe(prompt=prompt, mask_image=mask_image, image_latents=image_latents).images[0] ``` Returns: @@ -1380,16 +1371,10 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - vae_latent_size = (height // self.vae_scale_factor, width // self.vae_scale_factor) # 1. Check inputs self.check_inputs( prompt, - height, - width, inpaint_strength, callback_steps, negative_prompt, @@ -1397,9 +1382,9 @@ def __call__( negative_prompt_embeds, ) - if mask is None: + if mask_image is None: raise ValueError( - "`mask` input cannot be undefined. Use `generate_mask()` to compute `mask` from text prompts." + "`mask_image` input cannot be undefined. Use `generate_mask()` to compute `mask_image` from text prompts." ) if image_latents is None: raise ValueError( @@ -1434,26 +1419,30 @@ def __call__( ) # 4. Preprocess mask - mask = preprocess_mask(mask) - mask_shape = (batch_size, 1, *vae_latent_size) - if mask.shape != mask_shape: - raise ValueError(f"`mask` must have shape {mask_shape}, but has shape {mask.shape}") - mask = torch.cat([mask] * num_images_per_prompt) - mask = mask.to(device=device, dtype=prompt_embeds.dtype) + mask_image = preprocess_mask(mask_image, batch_size) + latent_height, latent_width = mask_image.shape[-2:] + mask_image = torch.cat([mask_image] * num_images_per_prompt) + mask_image = mask_image.to(device=device, dtype=prompt_embeds.dtype) # 5. Set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, inpaint_strength, device) # 6. Preprocess image latents - num_channels_latents = self.vae.config.latent_channels - image_latents_shape = (batch_size, len(timesteps), num_channels_latents, *vae_latent_size) - if image_latents.shape != image_latents_shape: + image_latents = preprocess(image_latents) + latent_shape = (self.vae.config.latent_channels, latent_height, latent_width) + if image_latents.shape[-3:] != latent_shape: + raise ValueError( + f"Each latent image in `image_latents` must have shape {latent_shape}, " + f"but has shape {image_latents.shape[-3:]}" + ) + if image_latents.ndim == 4: + image_latents = image_latents.reshape(batch_size, len(timesteps), *latent_shape) + if image_latents.shape[:2] != (batch_size, len(timesteps)): raise ValueError( - f"`image_latents` must have shape {image_latents_shape}, but has shape {image_latents.shape}" + f"`image_latents` must have batch size {batch_size} with latent images from {len(timesteps)} timesteps, " + f"but has batch size {image_latents.shape[0]} with latent images from {image_latents.shape[1]} timesteps." ) - if isinstance(image_latents, np.ndarray): - image_latents = torch.from_numpy(image_latents) image_latents = torch.cat([image_latents.transpose(0, 1)] * num_images_per_prompt, 1) image_latents = image_latents.to(device=device, dtype=prompt_embeds.dtype) @@ -1486,7 +1475,7 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # mask with inverted latents from appropriate timestep - use original image latent for last step - latents = latents * mask + image_latents[i] * (1 - mask) + latents = latents * mask_image + image_latents[i] * (1 - mask_image) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py index faf774583a94..a24817715015 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py @@ -32,16 +32,15 @@ from diffusers.utils import load_image, load_numpy, skip_mps, slow from diffusers.utils.testing_utils import floats_tensor, require_torch_gpu -from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS +from ...pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS from ...test_pipelines_common import PipelineTesterMixin @skip_mps class StableDiffusionDiffEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionDiffEditPipeline - test_attention_slicing = False - params = TEXT_TO_IMAGE_PARAMS - batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"token_indices"}) + params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"height", "width", "image"} | {"image_latents"} + batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS - {"image"} | {"image_latents"} def get_dummy_components(self): torch.manual_seed(0) @@ -115,14 +114,14 @@ def get_dummy_components(self): def get_dummy_inputs(self, device, seed=0): mask = floats_tensor((1, 16, 16), rng=random.Random(seed)).to(device) - latents = floats_tensor((2, 4, 16, 16), rng=random.Random(seed)).to(device) + latents = floats_tensor((1, 2, 4, 16, 16), rng=random.Random(seed)).to(device) if str(device).startswith("mps"): generator = torch.manual_seed(seed) else: generator = torch.Generator(device=device).manual_seed(seed) inputs = { "prompt": "a dog and a newt", - "mask": mask, + "mask_image": mask, "image_latents": latents, "generator": generator, "num_inference_steps": 2, From 862ce2e25d34c40060341c9deacc6bc01198802b Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Mon, 3 Apr 2023 00:28:34 -0700 Subject: [PATCH 21/41] Enable MPS Tests --- .../stable_diffusion_2/test_stable_diffusion_diffedit.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py index a24817715015..33d8d377a2ac 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py @@ -29,14 +29,13 @@ StableDiffusionDiffEditPipeline, UNet2DConditionModel, ) -from diffusers.utils import load_image, load_numpy, skip_mps, slow +from diffusers.utils import load_image, load_numpy, slow from diffusers.utils.testing_utils import floats_tensor, require_torch_gpu from ...pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS from ...test_pipelines_common import PipelineTesterMixin -@skip_mps class StableDiffusionDiffEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionDiffEditPipeline params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"height", "width", "image"} | {"image_latents"} From 7de158a826515104d15b3728fb0914b9804436ac Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Tue, 4 Apr 2023 00:49:32 -0700 Subject: [PATCH 22/41] Move example docstrings --- .../pipeline_stable_diffusion_diffedit.py | 139 ++++++++++-------- .../test_stable_diffusion_diffedit.py | 3 + 2 files changed, 77 insertions(+), 65 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 83ab0eae81a2..595e868f1377 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -34,6 +34,7 @@ is_accelerate_version, logging, randn_tensor, + replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput @@ -61,6 +62,77 @@ class DiffEditInversionPipelineOutput(BaseOutput): images: Union[List[PIL.Image.Image], np.ndarray] +EXAMPLE_DOC_STRING = """ + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionDiffEditPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" + + >>> init_image = download_image(img_url).resize((768, 768)) + + >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.enable_model_cpu_offload() + + >>> mask_prompt = "A bowl of fruits" + >>> prompt = "A bowl of pears" + + >>> mask_image = pipe.generate_mask(image=init_image, source_prompt=prompt, target_prompt=mask_prompt) + >>> image_latents = pipe.invert(image=init_image, prompt=mask_prompt).latents + >>> image = pipe(prompt=prompt, mask_image=mask_image, image_latents=image_latents).images[0] + ``` +""" + +EXAMPLE_INVERT_DOC_STRING = """ + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionDiffEditPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" + + >>> init_image = download_image(img_url).resize((768, 768)) + + >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.enable_model_cpu_offload() + + >>> prompt = "A bowl of fruits" + + >>> inverted_latents = pipe.invert(image=init_image, prompt=prompt).latents + ``` +""" + + def auto_corr_loss(hidden_states, generator=None): reg_loss = 0.0 for i in range(hidden_states.shape[0]): @@ -970,6 +1042,7 @@ def generate_mask( return mask_image @torch.no_grad() + @replace_example_docstring(EXAMPLE_INVERT_DOC_STRING) def invert( self, prompt: Optional[Union[str, List[str]]] = None, @@ -1061,37 +1134,6 @@ def invert( Examples: - ```py - >>> import PIL - >>> import requests - >>> import torch - >>> from io import BytesIO - - >>> from diffusers import StableDiffusionDiffEditPipeline - - - >>> def download_image(url): - ... response = requests.get(url) - ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") - - >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" - - >>> init_image = download_image(img_url).resize((768, 768)) - - >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained( - ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 - ... ) - >>> pipe = pipe.to("cuda") - - >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) - >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) - >>> pipeline.enable_model_cpu_offload() - - >>> prompt = "A bowl of fruits" - - >>> inverted_latents = pipe.invert(image=init_image, prompt=prompt).latents - ``` - Returns: [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] @@ -1239,6 +1281,7 @@ def invert( return DiffEditInversionPipelineOutput(latents=latents, images=image) @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Optional[Union[str, List[str]]] = None, @@ -1330,40 +1373,6 @@ def __call__( Examples: - ```py - >>> import PIL - >>> import requests - >>> import torch - >>> from io import BytesIO - - >>> from diffusers import StableDiffusionDiffEditPipeline - - - >>> def download_image(url): - ... response = requests.get(url) - ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") - - >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" - - >>> init_image = download_image(img_url).resize((768, 768)) - - >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained( - ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 - ... ) - >>> pipe = pipe.to("cuda") - - >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) - >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) - >>> pipeline.enable_model_cpu_offload() - - >>> mask_prompt = "A bowl of fruits" - >>> prompt = "A bowl of pears" - - >>> mask_image = pipe.generate_mask(image=init_image, source_prompt=prompt, target_prompt=mask_prompt) - >>> image_latents = pipe.invert(image=init_image, prompt=mask_prompt).latents - >>> image = pipe(prompt=prompt, mask_image=mask_image, image_latents=image_latents).images[0] - ``` - Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py index 33d8d377a2ac..1fe28d5e1ab8 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py @@ -36,6 +36,9 @@ from ...test_pipelines_common import PipelineTesterMixin +torch.backends.cuda.matmul.allow_tf32 = False + + class StableDiffusionDiffEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionDiffEditPipeline params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"height", "width", "image"} | {"image_latents"} From 79cf02baf35eadea270ea330d2a5dca131d95e94 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Tue, 4 Apr 2023 01:05:48 -0700 Subject: [PATCH 23/41] Fix test --- .../source/en/api/pipelines/stable_diffusion/diffedit.mdx | 4 +--- .../pipeline_stable_diffusion_diffedit.py | 8 +++++--- .../stable_diffusion_2/test_stable_diffusion_diffedit.py | 8 +++++--- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx b/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx index 88fa692c1fc7..a7cd906e0e77 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx @@ -154,9 +154,7 @@ Then, we define a utility to generate captions from an input image using the mod def generate_caption(images, caption_generator, caption_processor): text = "a photograph of" - inputs = caption_processor(images, text, return_tensors="pt").to( - device="cuda", dtype=caption_generator.dtype - ) + inputs = caption_processor(images, text, return_tensors="pt").to(device="cuda", dtype=caption_generator.dtype) caption_generator.to("cuda") outputs = caption_generator.generate(**inputs, max_new_tokens=128) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 595e868f1377..93071f779319 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -77,6 +77,7 @@ class DiffEditInversionPipelineOutput(BaseOutput): ... response = requests.get(url) ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" >>> init_image = download_image(img_url).resize((768, 768)) @@ -113,6 +114,7 @@ class DiffEditInversionPipelineOutput(BaseOutput): ... response = requests.get(url) ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" >>> init_image = download_image(img_url).resize((768, 768)) @@ -178,17 +180,17 @@ def preprocess(image): def preprocess_mask(mask, batch_size: int = 1): if not isinstance(mask, torch.Tensor): # preprocess mask - if isinstance(mask, PIL.Image.Image) or (isinstance(mask, np.ndarray) and mask.ndim < 3): + if isinstance(mask, PIL.Image.Image) or isinstance(mask, np.ndarray): mask = [mask] if isinstance(mask, list): if isinstance(mask[0], PIL.Image.Image): mask = [np.array(m.convert("L")).astype(np.float32) / 255.0 for m in mask] if isinstance(mask[0], np.ndarray): - mask = np.stack(mask, axis=0) + mask = np.stack(mask, axis=0) if mask[0].ndim < 3 else np.concatenate(mask, axis=0) mask = torch.from_numpy(mask) elif isinstance(mask[0], torch.Tensor): - mask = torch.stack(mask, dim=0) + mask = torch.stack(mask, dim=0) if mask[0].ndim < 3 else torch.cat(mask, dim=0) # Batch and add channel dim for single mask if mask.ndim == 2: diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py index 1fe28d5e1ab8..43ce70546d28 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py @@ -250,15 +250,17 @@ def test_stable_diffusion_diffedit_full(self): image=self.raw_image, source_prompt=source_prompt, target_prompt=target_prompt, - inpaint_strength=0.6, + mask_encode_strength=0.6, generator=generator, ) - inv_latents = pipe.invert(prompt=source_prompt, image=self.raw_image, generator=generator).latents + inv_latents = pipe.invert( + prompt=source_prompt, image=self.raw_image, inpaint_strength=0.6, generator=generator + ).latents image = pipe( prompt=target_prompt, - mask=mask_image, + mask_image=mask_image, image_latents=inv_latents, generator=generator, negative_prompt=source_prompt, From d4d8c8407d046769a6de9a28cf72122baeb264d5 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Tue, 4 Apr 2023 18:41:00 -0700 Subject: [PATCH 24/41] Fix test --- .../stable_diffusion_2/test_stable_diffusion_diffedit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py index 43ce70546d28..b25d7ae5f65b 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py @@ -29,7 +29,7 @@ StableDiffusionDiffEditPipeline, UNet2DConditionModel, ) -from diffusers.utils import load_image, load_numpy, slow +from diffusers.utils import load_image, slow from diffusers.utils.testing_utils import floats_tensor, require_torch_gpu from ...pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS @@ -268,7 +268,7 @@ def test_stable_diffusion_diffedit_full(self): output_type="numpy", ).images[0] - expected_image = load_numpy( - "https://raw.githubusercontent.com/Xiang-cd/DiffEdit-stable-diffusion/main/assets/target.png" + expected_image = np.array( + load_image("https://raw.githubusercontent.com/Xiang-cd/DiffEdit-stable-diffusion/main/assets/target.png") ) assert np.abs((expected_image - image).max()) < 1e-1 From ff0fbc27f7bfc0c6cef20005d609152168a90a20 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Tue, 4 Apr 2023 18:37:23 -0700 Subject: [PATCH 25/41] fix pipeline inheritance --- .../stable_diffusion/pipeline_stable_diffusion_diffedit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 540c68f8e980..594abc4d07d3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -230,7 +230,7 @@ def preprocess_mask(mask, batch_size: int = 1): return mask -class StableDiffusionDiffEditPipeline(DiffusionPipeline): +class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-guided image inpainting using Stable Diffusion using DiffEdit. *This is an experimental feature*. From 0db442d3bb7b3f3ea0c8d628fe7c31cb9402e087 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Wed, 5 Apr 2023 00:22:47 -0700 Subject: [PATCH 26/41] Harmonize `prepare_image_latents` with StableDiffusionPix2PixZeroPipeline --- .../pipeline_stable_diffusion_diffedit.py | 11 +++++- .../pipeline_stable_diffusion_pix2pix_zero.py | 39 ++++++++++++------- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 594abc4d07d3..e7cf513dac96 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -771,6 +771,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.StableDiffusionPix2PixZeroPipeline.prepare_image_latents def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( @@ -793,8 +794,16 @@ def prepare_image_latents(self, image, batch_size, dtype, device, generator=None latents = self.vae.config.scaling_factor * latents - if batch_size > latents.shape[0]: + if batch_size != latents.shape[0]: if batch_size % latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_latents_per_image = batch_size // latents.shape[0] latents = torch.cat([latents] * additional_latents_per_image, dim=0) else: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 6af923cb7743..b625f0b6b95e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -36,6 +36,7 @@ from ...utils import ( PIL_INTERPOLATION, BaseOutput, + deprecate, is_accelerate_available, is_accelerate_version, logging, @@ -721,23 +722,31 @@ def prepare_image_latents(self, image, batch_size, dtype, device, generator=None ) if isinstance(generator, list): - init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) + latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)] + latents = torch.cat(latents, dim=0) else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) - - init_latents = self.vae.config.scaling_factor * init_latents - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." - ) + latents = self.vae.encode(image).latent_dist.sample(generator) + + latents = self.vae.config.scaling_factor * latents + + if batch_size != latents.shape[0]: + if batch_size % latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_latents_per_image = batch_size // latents.shape[0] + latents = torch.cat([latents] * additional_latents_per_image, dim=0) + else: + raise ValueError( + f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts." + ) else: - init_latents = torch.cat([init_latents], dim=0) - - latents = init_latents + latents = torch.cat([latents], dim=0) return latents From a73f4db8da6c8143b040277668ea931b927557ff Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Wed, 5 Apr 2023 00:44:21 -0700 Subject: [PATCH 27/41] Register modules set to `None` in config for `test_save_load_optional_components` --- tests/test_pipelines_common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index 13fbe924c799..d5a443300844 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -394,9 +394,10 @@ def test_save_load_optional_components(self): pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - # set all optional components to None + # set all optional components to None and update pipeline config accordingly for optional_component in pipe._optional_components: setattr(pipe, optional_component, None) + pipe.register_modules(**{optional_component: None for optional_component in pipe._optional_components}) inputs = self.get_dummy_inputs(torch_device) output = pipe(**inputs)[0] From af62b302484f7593e6c075c70016dbc592df4586 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Thu, 6 Apr 2023 12:28:37 -0700 Subject: [PATCH 28/41] Move fixed logic to specific test class --- .../test_stable_diffusion_diffedit.py | 38 ++++++++++++++++++- tests/test_pipelines_common.py | 3 +- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py index b25d7ae5f65b..777844d7521f 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py @@ -15,6 +15,7 @@ import gc import random +import tempfile import unittest import numpy as np @@ -30,7 +31,7 @@ UNet2DConditionModel, ) from diffusers.utils import load_image, slow -from diffusers.utils.testing_utils import floats_tensor, require_torch_gpu +from diffusers.utils.testing_utils import floats_tensor, require_torch_gpu, torch_device from ...pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS from ...test_pipelines_common import PipelineTesterMixin @@ -176,6 +177,41 @@ def get_dummy_inversion_inputs(self, device, seed=0): } return inputs + def test_save_load_optional_components(self): + if not hasattr(self.pipeline_class, "_optional_components"): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # set all optional components to None and update pipeline config accordingly + for optional_component in pipe._optional_components: + setattr(pipe, optional_component, None) + pipe.register_modules(**{optional_component: None for optional_component in pipe._optional_components}) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for optional_component in pipe._optional_components: + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(output - output_loaded).max() + self.assertLess(max_diff, 1e-4) + def test_mask(self): device = "cpu" diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index d5a443300844..13fbe924c799 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -394,10 +394,9 @@ def test_save_load_optional_components(self): pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - # set all optional components to None and update pipeline config accordingly + # set all optional components to None for optional_component in pipe._optional_components: setattr(pipe, optional_component, None) - pipe.register_modules(**{optional_component: None for optional_component in pipe._optional_components}) inputs = self.get_dummy_inputs(torch_device) output = pipe(**inputs)[0] From 135c390b52c2f90a80c6b1426b61664ef330e474 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Wed, 12 Apr 2023 10:19:18 -0700 Subject: [PATCH 29/41] Clean changes to other pipelines --- .../pipeline_stable_diffusion_diffedit.py | 1 - .../pipeline_stable_diffusion_pix2pix_zero.py | 39 +++++++------------ 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index e7cf513dac96..e835efe790e5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -771,7 +771,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.StableDiffusionPix2PixZeroPipeline.prepare_image_latents def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index da2e82d172b5..0239c8128171 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -36,7 +36,6 @@ from ...utils import ( PIL_INTERPOLATION, BaseOutput, - deprecate, is_accelerate_available, is_accelerate_version, logging, @@ -722,31 +721,23 @@ def prepare_image_latents(self, image, batch_size, dtype, device, generator=None ) if isinstance(generator, list): - latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)] - latents = torch.cat(latents, dim=0) + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) else: - latents = self.vae.encode(image).latent_dist.sample(generator) - - latents = self.vae.config.scaling_factor * latents - - if batch_size != latents.shape[0]: - if batch_size % latents.shape[0] == 0: - # expand image_latents for batch_size - deprecation_message = ( - f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial" - " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" - " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" - " your script to pass as many initial images as text prompts to suppress this warning." - ) - deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) - additional_latents_per_image = batch_size // latents.shape[0] - latents = torch.cat([latents] * additional_latents_per_image, dim=0) - else: - raise ValueError( - f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts." - ) + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) else: - latents = torch.cat([latents], dim=0) + init_latents = torch.cat([init_latents], dim=0) + + latents = init_latents return latents From a46787d9d37c8368a7144a587bf0829b4f67abd2 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Wed, 12 Apr 2023 15:55:01 -0700 Subject: [PATCH 30/41] Update new tests to coordinate with #2953 --- .../stable_diffusion_2/test_stable_diffusion_diffedit.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py index 777844d7521f..4a444ec3c15f 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py @@ -228,7 +228,7 @@ def test_mask(self): expected_slice = np.array([0] * 9) max_diff = np.abs(mask_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) - self.assertEqual(mask[0, -3, -4], 1) + self.assertEqual(mask[0, -3, -4], 0) def test_inversion(self): device = "cpu" @@ -244,7 +244,7 @@ def test_inversion(self): self.assertEqual(image.shape, (2, 32, 32, 3)) expected_slice = np.array( - [0.5588859, 0.535619, 0.52224344, 0.55604255, 0.48608556, 0.51105076, 0.50301707, 0.44348782, 0.48488846], + [0.5150, 0.5134, 0.5043, 0.5376, 0.4694, 0.51050, 0.5015, 0.4407, 0.4799], ) max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) @@ -305,6 +305,8 @@ def test_stable_diffusion_diffedit_full(self): ).images[0] expected_image = np.array( - load_image("https://raw.githubusercontent.com/Xiang-cd/DiffEdit-stable-diffusion/main/assets/target.png") + load_image( + "https://raw.githubusercontent.com/Xiang-cd/DiffEdit-stable-diffusion/main/assets/target.png" + ).resize((768, 768)) ) assert np.abs((expected_image - image).max()) < 1e-1 From 37fb12d92443327551ff9b7434c9fce9950cb09a Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Wed, 12 Apr 2023 16:46:37 -0700 Subject: [PATCH 31/41] Update slow tests for better results --- .../test_stable_diffusion_diffedit.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py index 4a444ec3c15f..b6884bb1581c 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py @@ -261,7 +261,7 @@ def tearDown(self): @classmethod def setUpClass(cls): raw_image = load_image( - "https://raw.githubusercontent.com/Xiang-cd/DiffEdit-stable-diffusion/main/assets/origin.png" + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/diffedit/fruit.png" ) raw_image = raw_image.convert("RGB").resize((768, 768)) @@ -286,12 +286,11 @@ def test_stable_diffusion_diffedit_full(self): image=self.raw_image, source_prompt=source_prompt, target_prompt=target_prompt, - mask_encode_strength=0.6, generator=generator, ) inv_latents = pipe.invert( - prompt=source_prompt, image=self.raw_image, inpaint_strength=0.6, generator=generator + prompt=source_prompt, image=self.raw_image, inpaint_strength=0.7, generator=generator ).latents image = pipe( @@ -300,13 +299,14 @@ def test_stable_diffusion_diffedit_full(self): image_latents=inv_latents, generator=generator, negative_prompt=source_prompt, - inpaint_strength=0.6, + inpaint_strength=0.7, output_type="numpy", ).images[0] expected_image = np.array( load_image( - "https://raw.githubusercontent.com/Xiang-cd/DiffEdit-stable-diffusion/main/assets/target.png" + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/diffedit/pears.png" ).resize((768, 768)) ) assert np.abs((expected_image - image).max()) < 1e-1 From 0c67b2f058a3b83b5a55710dc159a09ead63521d Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Thu, 13 Apr 2023 16:26:50 -0700 Subject: [PATCH 32/41] Safety to avoid potential problems with torch.inference_mode --- .../pipeline_stable_diffusion_diffedit.py | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index e835efe790e5..f704f220ed6f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -1229,34 +1229,35 @@ def invert( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # regularization of the noise prediction (not in original code or paper but borrowed from Pix2PixZero) - with torch.enable_grad(): - for _ in range(num_reg_steps): - if lambda_auto_corr > 0: - for _ in range(num_auto_corr_rolls): - var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + if num_reg_steps > 0: + with torch.enable_grad(): + for _ in range(num_reg_steps): + if lambda_auto_corr > 0: + for _ in range(num_auto_corr_rolls): + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) - # Derive epsilon from model output before regularizing to IID standard normal - var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) - l_ac = self.auto_corr_loss(var_epsilon, generator=generator) - l_ac.backward() + l_ac = self.auto_corr_loss(var_epsilon, generator=generator) + l_ac.backward() - grad = var.grad.detach() / num_auto_corr_rolls - noise_pred = noise_pred - lambda_auto_corr * grad + grad = var.grad.detach() / num_auto_corr_rolls + noise_pred = noise_pred - lambda_auto_corr * grad - if lambda_kl > 0: - var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + if lambda_kl > 0: + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) - # Derive epsilon from model output before regularizing to IID standard normal - var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) - l_kld = self.kl_divergence(var_epsilon) - l_kld.backward() + l_kld = self.kl_divergence(var_epsilon) + l_kld.backward() - grad = var.grad.detach() - noise_pred = noise_pred - lambda_kl * grad + grad = var.grad.detach() + noise_pred = noise_pred - lambda_kl * grad - noise_pred = noise_pred.detach() + noise_pred = noise_pred.detach() # compute the previous noisy sample x_t -> x_t-1 latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample From d79cac5c2ebdb1d0da3d27dffb2d6510d2f0dc2e Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Thu, 13 Apr 2023 16:42:47 -0700 Subject: [PATCH 33/41] Add reference in SD Pipeline Overview --- docs/source/en/api/pipelines/stable_diffusion/overview.mdx | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/en/api/pipelines/stable_diffusion/overview.mdx b/docs/source/en/api/pipelines/stable_diffusion/overview.mdx index 70731fd294b9..a163b57f2a84 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/overview.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/overview.mdx @@ -36,6 +36,7 @@ For more details about how Stable Diffusion works and how it differs from the ba | [StableDiffusionAttendAndExcitePipeline](./attend_and_excite) | **Experimental** – *Text-to-Image Generation * | | [Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models](https://huggingface.co/spaces/AttendAndExcite/Attend-and-Excite) | [StableDiffusionPix2PixZeroPipeline](./pix2pix_zero) | **Experimental** – *Text-Based Image Editing * | | [Zero-shot Image-to-Image Translation](https://arxiv.org/abs/2302.03027) | [StableDiffusionModelEditingPipeline](./model_editing) | **Experimental** – *Text-to-Image Model Editing * | | [Editing Implicit Assumptions in Text-to-Image Diffusion Models](https://arxiv.org/abs/2303.08084) +| [StableDiffusionDiffEditPipeline](./diffedit) | **Experimental** – *Text-Based Image Editing * | | [DiffEdit: Diffusion-based semantic image editing with mask guidance](https://arxiv.org/abs/2210.11427) From b6a81e888fb2eb183fc67e27f536d78354a0d5a9 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Thu, 13 Apr 2023 16:42:47 -0700 Subject: [PATCH 34/41] Fix tests again --- .../test_stable_diffusion_diffedit.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py index 973bfab5e560..de69a40e82a8 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py @@ -303,10 +303,13 @@ def test_stable_diffusion_diffedit_full(self): output_type="numpy", ).images[0] - expected_image = np.array( - load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/diffedit/pears.png" - ).resize((768, 768)) + expected_image = ( + np.array( + load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/diffedit/pears.png" + ).resize((768, 768)) + ) + / 255 ) - assert np.abs((expected_image - image).max()) < 1e-1 + assert np.abs((expected_image - image).max()) < 1e-3 From 599750ebe0107f9073487e8e64c6f0771e1ac592 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Tue, 18 Apr 2023 14:28:18 -0700 Subject: [PATCH 35/41] Enforce determinism in noise for generate_mask --- .../stable_diffusion/pipeline_stable_diffusion_diffedit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index f704f220ed6f..1374c8ca804e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -1008,7 +1008,7 @@ def generate_mask( image_latents = self.prepare_image_latents( image, batch_size * num_maps_per_mask, self.vae.dtype, device, generator ) - noise = torch.randn_like(image_latents) + noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=self.vae.dtype) image_latents = self.scheduler.add_noise(image_latents, noise, encode_timestep) latent_model_input = torch.cat([image_latents] * (4 if do_classifier_free_guidance else 2)) From 55cad5b7dff1c0704ad957cee5e1320b7bc6b58a Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Tue, 18 Apr 2023 14:53:50 -0700 Subject: [PATCH 36/41] Fix copies --- .../stable_diffusion/pipeline_stable_diffusion_diffedit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 1374c8ca804e..153d688e4cb3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -736,7 +736,7 @@ def get_timesteps(self, num_inference_steps, strength, device): init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start:] + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] return timesteps, num_inference_steps - t_start From a5986a9077b014b96373c8dda5c4b9e8bce025a8 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Tue, 18 Apr 2023 15:08:16 -0700 Subject: [PATCH 37/41] Widen test tolerance for fp16 based on `test_stable_diffusion_upscale_pipeline_fp16` --- .../stable_diffusion_2/test_stable_diffusion_diffedit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py index de69a40e82a8..c20bc3b47d7b 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py @@ -312,4 +312,4 @@ def test_stable_diffusion_diffedit_full(self): ) / 255 ) - assert np.abs((expected_image - image).max()) < 1e-3 + assert np.abs((expected_image - image).max()) < 5e-1 From 6101d4ac03408b0b0f7be2ac90311a8950f24a2c Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Thu, 20 Apr 2023 12:21:46 -0700 Subject: [PATCH 38/41] Add LoraLoaderMixin and update `prepare_image_latents` --- .../pipeline_stable_diffusion_diffedit.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 153d688e4cb3..e80617c978d4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2023 DiffEdit Authors and Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict -from ...loaders import TextualInversionLoaderMixin +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDIMInverseScheduler, KarrasDiffusionSchedulers from ...utils import ( @@ -230,13 +230,20 @@ def preprocess_mask(mask, batch_size: int = 1): return mask -class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderMixin): +class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-guided image inpainting using Stable Diffusion using DiffEdit. *This is an experimental feature*. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. @@ -771,6 +778,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.StableDiffusionPix2PixZeroPipeline.prepare_image_latents def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( From 7224532a84bbf9f9908d3f58da24851ee6b511cb Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Sat, 22 Apr 2023 15:50:14 -0700 Subject: [PATCH 39/41] clean up repeat and reg --- .../stable_diffusion/pipeline_stable_diffusion_diffedit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index e80617c978d4..74ac68eb97c8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -1247,7 +1247,7 @@ def invert( # Derive epsilon from model output before regularizing to IID standard normal var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) - l_ac = self.auto_corr_loss(var_epsilon, generator=generator) + l_ac = auto_corr_loss(var_epsilon, generator=generator) l_ac.backward() grad = var.grad.detach() / num_auto_corr_rolls @@ -1472,7 +1472,7 @@ def __call__( f"`image_latents` must have batch size {batch_size} with latent images from {len(timesteps)} timesteps, " f"but has batch size {image_latents.shape[0]} with latent images from {image_latents.shape[1]} timesteps." ) - image_latents = torch.cat([image_latents.transpose(0, 1)] * num_images_per_prompt, 1) + image_latents = image_latents.transpose(0, 1).unsqueeze(1).repeat(1, num_images_per_prompt, 1, 1, 1, 1) image_latents = image_latents.to(device=device, dtype=prompt_embeds.dtype) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline From 6b91ca624dc754a39788b9d9b9b0a0fc30a07e27 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Sat, 22 Apr 2023 16:07:03 -0700 Subject: [PATCH 40/41] bugfix --- .../stable_diffusion/pipeline_stable_diffusion_diffedit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 74ac68eb97c8..ff98045c9229 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -1259,7 +1259,7 @@ def invert( # Derive epsilon from model output before regularizing to IID standard normal var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) - l_kld = self.kl_divergence(var_epsilon) + l_kld = kl_divergence(var_epsilon) l_kld.backward() grad = var.grad.detach() @@ -1472,7 +1472,7 @@ def __call__( f"`image_latents` must have batch size {batch_size} with latent images from {len(timesteps)} timesteps, " f"but has batch size {image_latents.shape[0]} with latent images from {image_latents.shape[1]} timesteps." ) - image_latents = image_latents.transpose(0, 1).unsqueeze(1).repeat(1, num_images_per_prompt, 1, 1, 1, 1) + image_latents = image_latents.transpose(0, 1).repeat_interleave(num_images_per_prompt, dim=1) image_latents = image_latents.to(device=device, dtype=prompt_embeds.dtype) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline From 9a49b01c889b66224bf8a990a84b76d3b34ad7c5 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Wed, 26 Apr 2023 00:05:24 -0700 Subject: [PATCH 41/41] Remove invalid args from docs Suppress spurious warning by repeating image before latent to mask gen --- .../stable_diffusion/pipeline_stable_diffusion_diffedit.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index ff98045c9229..9bef5269fa07 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -1005,7 +1005,7 @@ def generate_mask( ) # 4. Preprocess image - image = preprocess(image) + image = preprocess(image).repeat_interleave(num_maps_per_mask, dim=0) # 5. Set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -1111,8 +1111,6 @@ def invert( The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic.