Skip to content

[Community Pipelines] EDICT pipeline implementation #3153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions examples/community/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt
| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - |[Aengus (Duc-Anh)](https://github.com/aengusng8) |
| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) |
| TensorRT Stable Diffusion Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - |[Asfiya Baig](https://github.com/asfiyab-nvidia) |
| EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | - | [Joqsan Azocar](https://github.com/Joqsan) |



To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
Expand Down Expand Up @@ -1161,3 +1163,87 @@ prompt = "a beautiful photograph of Mt. Fuji during cherry blossom"
image = pipe(prompt).images[0]
image.save('tensorrt_mt_fuji.png')
```

### EDICT Image Editing Pipeline

This pipeline implements the text-guided image editing approach from the paper [EDICT: Exact Diffusion Inversion via Coupled Transformations](https://arxiv.org/abs/2211.12446). You have to pass:
- (`PIL`) `image` you want to edit.
- `base_prompt`: the text prompt describing the current image (before editing).
- `target_prompt`: the text prompt describing with the edits.

```python
from diffusers import DiffusionPipeline, DDIMScheduler
from transformers import CLIPTextModel
import torch, PIL, requests
from io import BytesIO
from IPython.display import display

def center_crop_and_resize(im):

width, height = im.size
d = min(width, height)
left = (width - d) / 2
upper = (height - d) / 2
right = (width + d) / 2
lower = (height + d) / 2

return im.crop((left, upper, right, lower)).resize((512, 512))

torch_dtype = torch.float16
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# scheduler and text_encoder param values as in the paper
scheduler = DDIMScheduler(
num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
set_alpha_to_one=False,
clip_sample=False,
)

text_encoder = CLIPTextModel.from_pretrained(
pretrained_model_name_or_path="openai/clip-vit-large-patch14",
torch_dtype=torch_dtype,
)

# initialize pipeline
pipeline = DiffusionPipeline.from_pretrained(
pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
custom_pipeline="edict_pipeline",
revision="fp16",
scheduler=scheduler,
text_encoder=text_encoder,
leapfrog_steps=True,
torch_dtype=torch_dtype,
).to(device)

# download image
image_url = "https://huggingface.co/datasets/Joqsan/images/resolve/main/imagenet_dog_1.jpeg"
response = requests.get(image_url)
image = PIL.Image.open(BytesIO(response.content))

# preprocess it
cropped_image = center_crop_and_resize(image)

# define the prompts
base_prompt = "A dog"
target_prompt = "A golden retriever"

# run the pipeline
result_image = pipeline(
base_prompt=base_prompt,
target_prompt=target_prompt,
image=cropped_image,
)

display(result_image)
```

Init Image

![img2img_init_edict_text_editing](https://huggingface.co/datasets/Joqsan/images/resolve/main/imagenet_dog_1.jpeg)

Output Image

![img2img_edict_text_editing](https://huggingface.co/datasets/Joqsan/images/resolve/main/imagenet_dog_1_cropped_generated.png)
264 changes: 264 additions & 0 deletions examples/community/edict_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
from typing import Optional

import torch
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import (
deprecate,
)


class EDICTPipeline(DiffusionPipeline):
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: DDIMScheduler,
mixing_coeff: float = 0.93,
leapfrog_steps: bool = True,
):
self.mixing_coeff = mixing_coeff
self.leapfrog_steps = leapfrog_steps

super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
)

self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

def _encode_prompt(
self, prompt: str, negative_prompt: Optional[str] = None, do_classifier_free_guidance: bool = False
):
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)

prompt_embeds = self.text_encoder(text_inputs.input_ids.to(self.device)).last_hidden_state

prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=self.device)

if do_classifier_free_guidance:
uncond_tokens = "" if negative_prompt is None else negative_prompt

uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)

negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(self.device)).last_hidden_state

prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

return prompt_embeds

def denoise_mixing_layer(self, x: torch.Tensor, y: torch.Tensor):
x = self.mixing_coeff * x + (1 - self.mixing_coeff) * y
y = self.mixing_coeff * y + (1 - self.mixing_coeff) * x

return [x, y]

def noise_mixing_layer(self, x: torch.Tensor, y: torch.Tensor):
y = (y - (1 - self.mixing_coeff) * x) / self.mixing_coeff
x = (x - (1 - self.mixing_coeff) * y) / self.mixing_coeff

return [x, y]

def _get_alpha_and_beta(self, t: torch.Tensor):
# as self.alphas_cumprod is always in cpu
t = int(t)

alpha_prod = self.scheduler.alphas_cumprod[t] if t >= 0 else self.scheduler.final_alpha_cumprod

return alpha_prod, 1 - alpha_prod

def noise_step(
self,
base: torch.Tensor,
model_input: torch.Tensor,
model_output: torch.Tensor,
timestep: torch.Tensor,
):
prev_timestep = timestep - self.scheduler.config.num_train_timesteps / self.scheduler.num_inference_steps

alpha_prod_t, beta_prod_t = self._get_alpha_and_beta(timestep)
alpha_prod_t_prev, beta_prod_t_prev = self._get_alpha_and_beta(prev_timestep)

a_t = (alpha_prod_t_prev / alpha_prod_t) ** 0.5
b_t = -a_t * (beta_prod_t**0.5) + beta_prod_t_prev**0.5

next_model_input = (base - b_t * model_output) / a_t

return model_input, next_model_input.to(base.dtype)

def denoise_step(
self,
base: torch.Tensor,
model_input: torch.Tensor,
model_output: torch.Tensor,
timestep: torch.Tensor,
):
prev_timestep = timestep - self.scheduler.config.num_train_timesteps / self.scheduler.num_inference_steps

alpha_prod_t, beta_prod_t = self._get_alpha_and_beta(timestep)
alpha_prod_t_prev, beta_prod_t_prev = self._get_alpha_and_beta(prev_timestep)

a_t = (alpha_prod_t_prev / alpha_prod_t) ** 0.5
b_t = -a_t * (beta_prod_t**0.5) + beta_prod_t_prev**0.5
next_model_input = a_t * base + b_t * model_output

return model_input, next_model_input.to(base.dtype)

@torch.no_grad()
def decode_latents(self, latents: torch.Tensor):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
return image

@torch.no_grad()
def prepare_latents(
self,
image: Image.Image,
text_embeds: torch.Tensor,
timesteps: torch.Tensor,
guidance_scale: float,
generator: Optional[torch.Generator] = None,
):
do_classifier_free_guidance = guidance_scale > 1.0

image = image.to(device=self.device, dtype=text_embeds.dtype)
latent = self.vae.encode(image).latent_dist.sample(generator)

latent = self.vae.config.scaling_factor * latent

coupled_latents = [latent.clone(), latent.clone()]

for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
coupled_latents = self.noise_mixing_layer(x=coupled_latents[0], y=coupled_latents[1])

# j - model_input index, k - base index
for j in range(2):
k = j ^ 1

if self.leapfrog_steps:
if i % 2 == 0:
k, j = j, k

model_input = coupled_latents[j]
base = coupled_latents[k]

latent_model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input

noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds).sample

if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

base, model_input = self.noise_step(
base=base,
model_input=model_input,
model_output=noise_pred,
timestep=t,
)

coupled_latents[k] = model_input

return coupled_latents

@torch.no_grad()
def __call__(
self,
base_prompt: str,
target_prompt: str,
image: Image.Image,
guidance_scale: float = 3.0,
num_inference_steps: int = 50,
strength: float = 0.8,
negative_prompt: Optional[str] = None,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
):
do_classifier_free_guidance = guidance_scale > 1.0

image = self.image_processor.preprocess(image)

base_embeds = self._encode_prompt(base_prompt, negative_prompt, do_classifier_free_guidance)
target_embeds = self._encode_prompt(target_prompt, negative_prompt, do_classifier_free_guidance)

self.scheduler.set_timesteps(num_inference_steps, self.device)

t_limit = num_inference_steps - int(num_inference_steps * strength)
fwd_timesteps = self.scheduler.timesteps[t_limit:]
bwd_timesteps = fwd_timesteps.flip(0)

coupled_latents = self.prepare_latents(image, base_embeds, bwd_timesteps, guidance_scale, generator)

for i, t in tqdm(enumerate(fwd_timesteps), total=len(fwd_timesteps)):
# j - model_input index, k - base index
for k in range(2):
j = k ^ 1

if self.leapfrog_steps:
if i % 2 == 1:
k, j = j, k

model_input = coupled_latents[j]
base = coupled_latents[k]

latent_model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input

noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=target_embeds).sample

if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

base, model_input = self.denoise_step(
base=base,
model_input=model_input,
model_output=noise_pred,
timestep=t,
)

coupled_latents[k] = model_input

coupled_latents = self.denoise_mixing_layer(x=coupled_latents[0], y=coupled_latents[1])

# either one is fine
final_latent = coupled_latents[0]

if output_type not in ["latent", "pt", "np", "pil"]:
deprecation_message = (
f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
"`pil`, `np`, `pt`, `latent`"
)
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
output_type = "np"

if output_type == "latent":
image = final_latent
else:
image = self.decode_latents(final_latent)
image = self.image_processor.postprocess(image, output_type=output_type)

return image