From 5fd91e7d60b2f5dd7b9c0c35268c17aeae9ffca5 Mon Sep 17 00:00:00 2001 From: Joqsan Azocar Date: Mon, 17 Apr 2023 19:54:13 +0300 Subject: [PATCH 1/8] EDICT pipeline initial commit - Starting point taking from https://github.com/Joqsan/edict-diffusion --- examples/community/edict_pipeline.py | 328 +++++++++++++++++++++++++++ 1 file changed, 328 insertions(+) create mode 100644 examples/community/edict_pipeline.py diff --git a/examples/community/edict_pipeline.py b/examples/community/edict_pipeline.py new file mode 100644 index 000000000000..4ca2a41c5e9a --- /dev/null +++ b/examples/community/edict_pipeline.py @@ -0,0 +1,328 @@ +from typing import Optional, Union + +import numpy as np +import torch +from diffusers import AutoencoderKL, UNet2DConditionModel +from PIL import Image +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +def preprocess(image): + if isinstance(image, Image.Image): + w, h = image.size + w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 + + image = np.array(image.resize((w, h), resample=Image.Resampling.LANCZOS))[ + None, : + ] + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + else: + raise TypeError("Expected object of type PIL.Image.Image") + return image + + +class EDICTScheduler: + def __init__( + self, + p: float = 0.93, + beta_1: float = 0.00085, + beta_T: float = 0.012, + num_train_timesteps: int = 1000, # T = 1000 + set_alpha_to_one: bool = False, + ): + self.p = p + self.num_train_timesteps = num_train_timesteps + + # scaled linear + betas = ( + torch.linspace( + beta_1**0.5, beta_T**0.5, num_train_timesteps, dtype=torch.float32 + ) + ** 2 + ) + + alphas = 1.0 - betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + + self.final_alpha_cumprod = ( + torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + ) + + # For PEP 412's sake + self.num_inference_steps = None + self.timesteps = torch.from_numpy( + np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64) + ) + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device]): + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + + timesteps = ( + (np.arange(0, num_inference_steps) * step_ratio) + .round()[::-1] + .copy() + .astype(np.int64) + ) + self.timesteps = torch.from_numpy(timesteps).to(device) + + def denoise_mixing_layer(self, x: torch.Tensor, y: torch.Tensor): + x = self.p * x + (1 - self.p) * y + y = self.p * y + (1 - self.p) * x + + return [x, y] + + def noise_mixing_layer(self, x: torch.Tensor, y: torch.Tensor): + y = (y - (1 - self.p) * x) / self.p + x = (x - (1 - self.p) * y) / self.p + + return [x, y] + + def get_alpha_and_beta(self, t: torch.Tensor): + # as self.alphas_cumprod is always in cpu + t = int(t) + + alpha_prod = self.alphas_cumprod[t] if t >= 0 else self.final_alpha_cumprod + + return alpha_prod, 1 - alpha_prod + + def noise_step( + self, + base: torch.Tensor, + model_input: torch.Tensor, + model_output: torch.Tensor, + timestep: torch.Tensor, + ): + prev_timestep = timestep - self.num_train_timesteps / self.num_inference_steps + + alpha_prod_t, beta_prod_t = self.get_alpha_and_beta(timestep) + alpha_prod_t_prev, beta_prod_t_prev = self.get_alpha_and_beta(prev_timestep) + + a_t = (alpha_prod_t_prev / alpha_prod_t) ** 0.5 + b_t = -a_t * (beta_prod_t**0.5) + beta_prod_t_prev**0.5 + + next_model_input = (base 
- b_t * model_output) / a_t + + return model_input, next_model_input.to(base.dtype) + + def denoise_step( + self, + base: torch.Tensor, + model_input: torch.Tensor, + model_output: torch.Tensor, + timestep: torch.Tensor, + ): + prev_timestep = timestep - self.num_train_timesteps / self.num_inference_steps + + alpha_prod_t, beta_prod_t = self.get_alpha_and_beta(timestep) + alpha_prod_t_prev, beta_prod_t_prev = self.get_alpha_and_beta(prev_timestep) + + a_t = (alpha_prod_t_prev / alpha_prod_t) ** 0.5 + b_t = -a_t * (beta_prod_t**0.5) + beta_prod_t_prev**0.5 + next_model_input = a_t * base + b_t * model_output + + return model_input, next_model_input.to(base.dtype) + + + +class Pipeline: + def __init__( + self, + scheduler: EDICTScheduler, + clip_path: str = "openai/clip-vit-large-patch14", + sd_path: str = "CompVis/stable-diffusion-v1-4", + vae_path: str = None, + revision: str = "fp16", + torch_dtype: torch.dtype = torch.float16, + leapfrog_steps: bool = True, + device: Union[str, torch.device] = "cuda", + ): + self.scheduler = scheduler + self.leapfrog_steps = leapfrog_steps + self.device = device + + self.unet = UNet2DConditionModel.from_pretrained( + sd_path, + subfolder="unet", + revision=revision, + torch_dtype=torch_dtype, + ).to(device) + + self.vae = AutoencoderKL.from_pretrained( + sd_path if vae_path is None else vae_path, + subfolder="vae" if vae_path is None else None, + revision=revision, + torch_dtype=torch_dtype, + ).to(device) + + self.tokenizer = CLIPTokenizer.from_pretrained(clip_path) + self.encoder = CLIPTextModel.from_pretrained( + clip_path, torch_dtype=torch_dtype + ).to(device) + + def encode_prompt(self, prompt: str, negative_prompt: Optional[str] = None): + null_prompt = "" if negative_prompt is None else negative_prompt + + tokens_uncond = self.tokenizer( + null_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + return_overflowing_tokens=True, + ) + embeds_uncond = self.encoder( + tokens_uncond.input_ids.to(self.device) + ).last_hidden_state + + tokens_cond = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + return_overflowing_tokens=True, + ) + embeds_cond = self.encoder( + tokens_cond.input_ids.to(self.device) + ).last_hidden_state + + return torch.cat([embeds_uncond, embeds_cond]) + + @torch.no_grad() + def decode_latents(self, latents: torch.Tensor): + # latents = 1 / self.vae.config.scaling_factor * latents + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + @torch.no_grad() + def prepare_latents( + self, + image: Image.Image, + text_embeds: torch.Tensor, + timesteps: torch.Tensor, + guidance_scale: float, + ): + generator = torch.cuda.manual_seed(1) + image = image.to(device=self.device, dtype=text_embeds.dtype) + latent = self.vae.encode(image).latent_dist.sample(generator) + + # init_latents = self.vae.config.scaling_factor * init_latents + latent = 0.18215 * latent + + coupled_latents = [latent.clone(), latent.clone()] + + for i, t in tqdm(enumerate(timesteps), total=len(timesteps)): + coupled_latents = self.scheduler.noise_mixing_layer( + x=coupled_latents[0], y=coupled_latents[1] + ) + + # j - model_input index, k - base index + for j in range(2): + k = j ^ 1 + + if self.leapfrog_steps: + if i % 2 == 0: + k, j = j, k + + model_input = coupled_latents[j] + 
base = coupled_latents[k] + + latent_model_input = torch.cat([model_input] * 2) + + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=text_embeds + ).sample + + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + base, model_input = self.scheduler.noise_step( + base=base, + model_input=model_input, + model_output=noise_pred, + timestep=t, + ) + + coupled_latents[k] = model_input + + return coupled_latents + + @torch.no_grad() + def __call__( + self, + base_prompt: str, + target_prompt: str, + image: Image.Image, + guidance_scale: float = 3.0, + steps: int = 50, + strength: float = 0.8, + ): + + image = preprocess(image) # from PIL.Image to torch.Tensor + + base_embeds = self.encode_prompt(base_prompt) + target_embeds = self.encode_prompt(target_prompt) + + self.scheduler.set_timesteps(steps, self.device) + + t_limit = steps - int(steps * strength) + fwd_timesteps = self.scheduler.timesteps[t_limit:] + bwd_timesteps = fwd_timesteps.flip(0) + + latent_pair = self.prepare_latents( + image, base_embeds, bwd_timesteps, guidance_scale + ) + + for i, t in tqdm(enumerate(fwd_timesteps), total=len(fwd_timesteps)): + # j - model_input index, k - base index + for k in range(2): + j = k ^ 1 + + if self.leapfrog_steps: + if i % 2 == 1: + k, j = j, k + + model_input = latent_pair[j] + base = latent_pair[k] + + latent_model_input = torch.cat([model_input] * 2) + + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=target_embeds + ).sample + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + base, model_input = self.scheduler.denoise_step( + base=base, + model_input=model_input, + model_output=noise_pred, + timestep=t, + ) + + latent_pair[k] = model_input + + latent_pair = self.scheduler.denoise_mixing_layer( + x=latent_pair[0], y=latent_pair[1] + ) + + # either one is fine + final_latent = latent_pair[0] + + image = self.decode_latents(final_latent) + image = (image[0] * 255).round().astype("uint8") + pil_image = Image.fromarray(image) + + return pil_image \ No newline at end of file From f24a51c2f0687460b1a4a18ca41e27445b4320c4 Mon Sep 17 00:00:00 2001 From: Joqsan Azocar Date: Mon, 17 Apr 2023 19:58:58 +0300 Subject: [PATCH 2/8] refactor __init__() method --- examples/community/edict_pipeline.py | 43 +++++++++++----------------- 1 file changed, 16 insertions(+), 27 deletions(-) diff --git a/examples/community/edict_pipeline.py b/examples/community/edict_pipeline.py index 4ca2a41c5e9a..f1ddde9bf446 100644 --- a/examples/community/edict_pipeline.py +++ b/examples/community/edict_pipeline.py @@ -7,6 +7,8 @@ from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer +from diffusers import DiffusionPipeline + def preprocess(image): if isinstance(image, Image.Image): w, h = image.size @@ -128,40 +130,27 @@ def denoise_step( -class Pipeline: +class Pipeline(DiffusionPipeline): def __init__( self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, scheduler: EDICTScheduler, - clip_path: str = "openai/clip-vit-large-patch14", - sd_path: str = "CompVis/stable-diffusion-v1-4", - vae_path: str = None, - revision: str = "fp16", - torch_dtype: torch.dtype = torch.float16, leapfrog_steps: bool = True, - device: Union[str, torch.device] = "cuda", ): self.scheduler = scheduler 
self.leapfrog_steps = leapfrog_steps - self.device = device - - self.unet = UNet2DConditionModel.from_pretrained( - sd_path, - subfolder="unet", - revision=revision, - torch_dtype=torch_dtype, - ).to(device) - - self.vae = AutoencoderKL.from_pretrained( - sd_path if vae_path is None else vae_path, - subfolder="vae" if vae_path is None else None, - revision=revision, - torch_dtype=torch_dtype, - ).to(device) - - self.tokenizer = CLIPTokenizer.from_pretrained(clip_path) - self.encoder = CLIPTextModel.from_pretrained( - clip_path, torch_dtype=torch_dtype - ).to(device) + + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + ) + def encode_prompt(self, prompt: str, negative_prompt: Optional[str] = None): null_prompt = "" if negative_prompt is None else negative_prompt From 0d372f60ff19186fd89d1be0b94812355bd8cf06 Mon Sep 17 00:00:00 2001 From: Joqsan Azocar Date: Mon, 17 Apr 2023 22:03:19 +0300 Subject: [PATCH 3/8] minor refactoring --- examples/community/edict_pipeline.py | 84 ++++++++++------------------ 1 file changed, 29 insertions(+), 55 deletions(-) diff --git a/examples/community/edict_pipeline.py b/examples/community/edict_pipeline.py index f1ddde9bf446..020a4ae1e9ab 100644 --- a/examples/community/edict_pipeline.py +++ b/examples/community/edict_pipeline.py @@ -12,11 +12,9 @@ def preprocess(image): if isinstance(image, Image.Image): w, h = image.size - w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 - image = np.array(image.resize((w, h), resample=Image.Resampling.LANCZOS))[ - None, : - ] + image = np.array(image.resize((w, h), resample=Image.Resampling.LANCZOS))[None, :] image = np.array(image).astype(np.float32) / 255.0 image = image.transpose(0, 3, 1, 2) image = 2.0 * image - 1.0 @@ -32,43 +30,30 @@ def __init__( p: float = 0.93, beta_1: float = 0.00085, beta_T: float = 0.012, - num_train_timesteps: int = 1000, # T = 1000 + num_train_timesteps: int = 1000, set_alpha_to_one: bool = False, ): self.p = p self.num_train_timesteps = num_train_timesteps # scaled linear - betas = ( - torch.linspace( - beta_1**0.5, beta_T**0.5, num_train_timesteps, dtype=torch.float32 - ) - ** 2 - ) + betas = torch.linspace(beta_1**0.5, beta_T**0.5, num_train_timesteps, dtype=torch.float32) ** 2 alphas = 1.0 - betas self.alphas_cumprod = torch.cumprod(alphas, dim=0) - self.final_alpha_cumprod = ( - torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] - ) + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] # For PEP 412's sake self.num_inference_steps = None - self.timesteps = torch.from_numpy( - np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64) - ) + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device]): self.num_inference_steps = num_inference_steps step_ratio = self.num_train_timesteps // self.num_inference_steps - timesteps = ( - (np.arange(0, num_inference_steps) * step_ratio) - .round()[::-1] - .copy() - .astype(np.int64) - ) + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) def denoise_mixing_layer(self, x: torch.Tensor, y: torch.Tensor): @@ -130,7 +115,7 @@ def denoise_step( -class Pipeline(DiffusionPipeline): +class 
EDICTPipeline(DiffusionPipeline): def __init__( self, vae: AutoencoderKL, @@ -153,16 +138,16 @@ def __init__( def encode_prompt(self, prompt: str, negative_prompt: Optional[str] = None): - null_prompt = "" if negative_prompt is None else negative_prompt + negative_prompt = "" if negative_prompt is None else negative_prompt tokens_uncond = self.tokenizer( - null_prompt, + negative_prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", - return_overflowing_tokens=True, ) + embeds_uncond = self.encoder( tokens_uncond.input_ids.to(self.device) ).last_hidden_state @@ -173,8 +158,8 @@ def encode_prompt(self, prompt: str, negative_prompt: Optional[str] = None): max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", - return_overflowing_tokens=True, ) + embeds_cond = self.encoder( tokens_cond.input_ids.to(self.device) ).last_hidden_state @@ -208,9 +193,7 @@ def prepare_latents( coupled_latents = [latent.clone(), latent.clone()] for i, t in tqdm(enumerate(timesteps), total=len(timesteps)): - coupled_latents = self.scheduler.noise_mixing_layer( - x=coupled_latents[0], y=coupled_latents[1] - ) + coupled_latents = self.scheduler.noise_mixing_layer(x=coupled_latents[0], y=coupled_latents[1]) # j - model_input index, k - base index for j in range(2): @@ -225,9 +208,7 @@ def prepare_latents( latent_model_input = torch.cat([model_input] * 2) - noise_pred = self.unet( - latent_model_input, t, encoder_hidden_states=text_embeds - ).sample + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds).sample noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( @@ -252,24 +233,23 @@ def __call__( target_prompt: str, image: Image.Image, guidance_scale: float = 3.0, - steps: int = 50, + num_inference_steps: int = 50, strength: float = 0.8, + negative_prompt: Optional[str] = None, ): image = preprocess(image) # from PIL.Image to torch.Tensor - base_embeds = self.encode_prompt(base_prompt) - target_embeds = self.encode_prompt(target_prompt) + base_embeds = self.encode_prompt(base_prompt, negative_prompt) + target_embeds = self.encode_prompt(target_prompt, negative_prompt) - self.scheduler.set_timesteps(steps, self.device) + self.scheduler.set_timesteps(num_inference_steps, self.device) - t_limit = steps - int(steps * strength) + t_limit = num_inference_steps - int(num_inference_steps * strength) fwd_timesteps = self.scheduler.timesteps[t_limit:] bwd_timesteps = fwd_timesteps.flip(0) - latent_pair = self.prepare_latents( - image, base_embeds, bwd_timesteps, guidance_scale - ) + coupled_latents = self.prepare_latents(image, base_embeds, bwd_timesteps, guidance_scale) for i, t in tqdm(enumerate(fwd_timesteps), total=len(fwd_timesteps)): # j - model_input index, k - base index @@ -280,19 +260,15 @@ def __call__( if i % 2 == 1: k, j = j, k - model_input = latent_pair[j] - base = latent_pair[k] + model_input = coupled_latents[j] + base = coupled_latents[k] latent_model_input = torch.cat([model_input] * 2) - noise_pred = self.unet( - latent_model_input, t, encoder_hidden_states=target_embeds - ).sample + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=target_embeds).sample noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) base, model_input = 
self.scheduler.denoise_step( base=base, @@ -301,14 +277,12 @@ def __call__( timestep=t, ) - latent_pair[k] = model_input + coupled_latents[k] = model_input - latent_pair = self.scheduler.denoise_mixing_layer( - x=latent_pair[0], y=latent_pair[1] - ) + coupled_latents = self.scheduler.denoise_mixing_layer(x=coupled_latents[0], y=coupled_latents[1]) # either one is fine - final_latent = latent_pair[0] + final_latent = coupled_latents[0] image = self.decode_latents(final_latent) image = (image[0] * 255).round().astype("uint8") From b0de5615655d2b07eb3d46f3552e13bb1a0a1baa Mon Sep 17 00:00:00 2001 From: Joqsan Azocar Date: Tue, 18 Apr 2023 13:39:08 +0300 Subject: [PATCH 4/8] refactor scheduler code - remove scheduler and move its methods to the EDICTPipeline class --- examples/community/edict_pipeline.py | 135 ++++++++++----------------- 1 file changed, 51 insertions(+), 84 deletions(-) diff --git a/examples/community/edict_pipeline.py b/examples/community/edict_pipeline.py index 020a4ae1e9ab..46cb7200f131 100644 --- a/examples/community/edict_pipeline.py +++ b/examples/community/edict_pipeline.py @@ -2,12 +2,11 @@ import numpy as np import torch -from diffusers import AutoencoderKL, UNet2DConditionModel +from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler from PIL import Image from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer -from diffusers import DiffusionPipeline def preprocess(image): if isinstance(image, Image.Image): @@ -24,37 +23,58 @@ def preprocess(image): return image -class EDICTScheduler: +class EDICTPipeline(DiffusionPipeline): def __init__( self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, p: float = 0.93, - beta_1: float = 0.00085, - beta_T: float = 0.012, - num_train_timesteps: int = 1000, - set_alpha_to_one: bool = False, + leapfrog_steps: bool = True, ): + self.leapfrog_steps = leapfrog_steps self.p = p - self.num_train_timesteps = num_train_timesteps - - # scaled linear - betas = torch.linspace(beta_1**0.5, beta_T**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) - alphas = 1.0 - betas - self.alphas_cumprod = torch.cumprod(alphas, dim=0) - self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + def encode_prompt(self, prompt: str, negative_prompt: Optional[str] = None): + negative_prompt = "" if negative_prompt is None else negative_prompt - # For PEP 412's sake - self.num_inference_steps = None - self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + tokens_uncond = self.tokenizer( + negative_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device]): - self.num_inference_steps = num_inference_steps - step_ratio = self.num_train_timesteps // self.num_inference_steps + embeds_uncond = self.text_encoder( + tokens_uncond.input_ids.to(self.device) + ).last_hidden_state - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + tokens_cond = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) - 
self.timesteps = torch.from_numpy(timesteps).to(device) + embeds_cond = self.text_encoder( + tokens_cond.input_ids.to(self.device) + ).last_hidden_state + + return torch.cat([embeds_uncond, embeds_cond]) def denoise_mixing_layer(self, x: torch.Tensor, y: torch.Tensor): x = self.p * x + (1 - self.p) * y @@ -72,7 +92,7 @@ def get_alpha_and_beta(self, t: torch.Tensor): # as self.alphas_cumprod is always in cpu t = int(t) - alpha_prod = self.alphas_cumprod[t] if t >= 0 else self.final_alpha_cumprod + alpha_prod = self.scheduler.alphas_cumprod[t] if t >= 0 else self.scheduler.final_alpha_cumprod return alpha_prod, 1 - alpha_prod @@ -83,13 +103,13 @@ def noise_step( model_output: torch.Tensor, timestep: torch.Tensor, ): - prev_timestep = timestep - self.num_train_timesteps / self.num_inference_steps + prev_timestep = timestep - self.scheduler.config.num_train_timesteps / self.scheduler.num_inference_steps alpha_prod_t, beta_prod_t = self.get_alpha_and_beta(timestep) alpha_prod_t_prev, beta_prod_t_prev = self.get_alpha_and_beta(prev_timestep) a_t = (alpha_prod_t_prev / alpha_prod_t) ** 0.5 - b_t = -a_t * (beta_prod_t**0.5) + beta_prod_t_prev**0.5 + b_t = -a_t * (beta_prod_t**0.5) + beta_prod_t_prev ** 0.5 next_model_input = (base - b_t * model_output) / a_t @@ -102,7 +122,7 @@ def denoise_step( model_output: torch.Tensor, timestep: torch.Tensor, ): - prev_timestep = timestep - self.num_train_timesteps / self.num_inference_steps + prev_timestep = timestep - self.scheduler.config.num_train_timesteps / self.scheduler.num_inference_steps alpha_prod_t, beta_prod_t = self.get_alpha_and_beta(timestep) alpha_prod_t_prev, beta_prod_t_prev = self.get_alpha_and_beta(prev_timestep) @@ -112,60 +132,7 @@ def denoise_step( next_model_input = a_t * base + b_t * model_output return model_input, next_model_input.to(base.dtype) - - - -class EDICTPipeline(DiffusionPipeline): - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: EDICTScheduler, - leapfrog_steps: bool = True, - ): - self.scheduler = scheduler - self.leapfrog_steps = leapfrog_steps - - super().__init__() - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - ) - - - def encode_prompt(self, prompt: str, negative_prompt: Optional[str] = None): - negative_prompt = "" if negative_prompt is None else negative_prompt - - tokens_uncond = self.tokenizer( - negative_prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - embeds_uncond = self.encoder( - tokens_uncond.input_ids.to(self.device) - ).last_hidden_state - - tokens_cond = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - embeds_cond = self.encoder( - tokens_cond.input_ids.to(self.device) - ).last_hidden_state - - return torch.cat([embeds_uncond, embeds_cond]) - + @torch.no_grad() def decode_latents(self, latents: torch.Tensor): # latents = 1 / self.vae.config.scaling_factor * latents @@ -193,7 +160,7 @@ def prepare_latents( coupled_latents = [latent.clone(), latent.clone()] for i, t in tqdm(enumerate(timesteps), total=len(timesteps)): - coupled_latents = self.scheduler.noise_mixing_layer(x=coupled_latents[0], y=coupled_latents[1]) + coupled_latents = self.noise_mixing_layer(x=coupled_latents[0], y=coupled_latents[1]) # j - model_input index, k - base index for j in range(2): @@ -215,7 +182,7 @@ def 
prepare_latents( noise_pred_text - noise_pred_uncond ) - base, model_input = self.scheduler.noise_step( + base, model_input = self.noise_step( base=base, model_input=model_input, model_output=noise_pred, @@ -270,7 +237,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - base, model_input = self.scheduler.denoise_step( + base, model_input = self.denoise_step( base=base, model_input=model_input, model_output=noise_pred, @@ -279,7 +246,7 @@ def __call__( coupled_latents[k] = model_input - coupled_latents = self.scheduler.denoise_mixing_layer(x=coupled_latents[0], y=coupled_latents[1]) + coupled_latents = self.denoise_mixing_layer(x=coupled_latents[0], y=coupled_latents[1]) # either one is fine final_latent = coupled_latents[0] From 0ba76a260d1586ee66acce216de034660596d0e9 Mon Sep 17 00:00:00 2001 From: Joqsan Azocar Date: Tue, 18 Apr 2023 15:40:54 +0300 Subject: [PATCH 5/8] make CFG optional - refactor encode_prompt(). - include optional generator for sampling with vae. - minor variable renaming --- examples/community/edict_pipeline.py | 102 +++++++++++++++------------ 1 file changed, 56 insertions(+), 46 deletions(-) diff --git a/examples/community/edict_pipeline.py b/examples/community/edict_pipeline.py index 46cb7200f131..a3305fdbc4b6 100644 --- a/examples/community/edict_pipeline.py +++ b/examples/community/edict_pipeline.py @@ -6,14 +6,14 @@ from PIL import Image from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer - +from diffusers.utils import PIL_INTERPOLATION def preprocess(image): if isinstance(image, Image.Image): w, h = image.size w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 - image = np.array(image.resize((w, h), resample=Image.Resampling.LANCZOS))[None, :] + image = np.array(image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] image = np.array(image).astype(np.float32) / 255.0 image = image.transpose(0, 3, 1, 2) image = 2.0 * image - 1.0 @@ -31,12 +31,12 @@ def __init__( tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: DDIMScheduler, - p: float = 0.93, + mixing_coeff: float = 0.93, leapfrog_steps: bool = True, ): + self.mixing_coeff = mixing_coeff self.leapfrog_steps = leapfrog_steps - self.p = p - + super().__init__() self.register_modules( vae=vae, @@ -47,48 +47,54 @@ def __init__( ) - def encode_prompt(self, prompt: str, negative_prompt: Optional[str] = None): - negative_prompt = "" if negative_prompt is None else negative_prompt - - tokens_uncond = self.tokenizer( - negative_prompt, + def _encode_prompt(self, prompt: str, negative_prompt: Optional[str] = None, do_classifier_free_guidance: bool = False): + text_inputs = self.tokenizer( + prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) - embeds_uncond = self.text_encoder( - tokens_uncond.input_ids.to(self.device) + prompt_embeds = self.text_encoder( + text_inputs.input_ids.to(self.device) ).last_hidden_state - tokens_cond = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - embeds_cond = self.text_encoder( - tokens_cond.input_ids.to(self.device) - ).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=self.device) + + if do_classifier_free_guidance: + + uncond_tokens = "" if negative_prompt is None else negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + 
padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(self.device) + ).last_hidden_state - return torch.cat([embeds_uncond, embeds_cond]) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds def denoise_mixing_layer(self, x: torch.Tensor, y: torch.Tensor): - x = self.p * x + (1 - self.p) * y - y = self.p * y + (1 - self.p) * x + x = self.mixing_coeff * x + (1 - self.mixing_coeff) * y + y = self.mixing_coeff * y + (1 - self.mixing_coeff) * x return [x, y] def noise_mixing_layer(self, x: torch.Tensor, y: torch.Tensor): - y = (y - (1 - self.p) * x) / self.p - x = (x - (1 - self.p) * y) / self.p + y = (y - (1 - self.mixing_coeff) * x) / self.mixing_coeff + x = (x - (1 - self.mixing_coeff) * y) / self.mixing_coeff return [x, y] - def get_alpha_and_beta(self, t: torch.Tensor): + def _get_alpha_and_beta(self, t: torch.Tensor): # as self.alphas_cumprod is always in cpu t = int(t) @@ -105,8 +111,8 @@ def noise_step( ): prev_timestep = timestep - self.scheduler.config.num_train_timesteps / self.scheduler.num_inference_steps - alpha_prod_t, beta_prod_t = self.get_alpha_and_beta(timestep) - alpha_prod_t_prev, beta_prod_t_prev = self.get_alpha_and_beta(prev_timestep) + alpha_prod_t, beta_prod_t = self._get_alpha_and_beta(timestep) + alpha_prod_t_prev, beta_prod_t_prev = self._get_alpha_and_beta(prev_timestep) a_t = (alpha_prod_t_prev / alpha_prod_t) ** 0.5 b_t = -a_t * (beta_prod_t**0.5) + beta_prod_t_prev ** 0.5 @@ -124,8 +130,8 @@ def denoise_step( ): prev_timestep = timestep - self.scheduler.config.num_train_timesteps / self.scheduler.num_inference_steps - alpha_prod_t, beta_prod_t = self.get_alpha_and_beta(timestep) - alpha_prod_t_prev, beta_prod_t_prev = self.get_alpha_and_beta(prev_timestep) + alpha_prod_t, beta_prod_t = self._get_alpha_and_beta(timestep) + alpha_prod_t_prev, beta_prod_t_prev = self._get_alpha_and_beta(prev_timestep) a_t = (alpha_prod_t_prev / alpha_prod_t) ** 0.5 b_t = -a_t * (beta_prod_t**0.5) + beta_prod_t_prev**0.5 @@ -149,8 +155,10 @@ def prepare_latents( text_embeds: torch.Tensor, timesteps: torch.Tensor, guidance_scale: float, + generator: Optional[torch.Generator] = None, ): - generator = torch.cuda.manual_seed(1) + do_classifier_free_guidance = guidance_scale > 1.0 + image = image.to(device=self.device, dtype=text_embeds.dtype) latent = self.vae.encode(image).latent_dist.sample(generator) @@ -173,14 +181,13 @@ def prepare_latents( model_input = coupled_latents[j] base = coupled_latents[k] - latent_model_input = torch.cat([model_input] * 2) + latent_model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds).sample - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) base, model_input = self.noise_step( base=base, @@ -203,12 +210,14 @@ def __call__( num_inference_steps: int = 50, strength: float = 0.8, negative_prompt: Optional[str] = None, + generator: Optional[torch.Generator] = None, ): + do_classifier_free_guidance = guidance_scale > 1.0 image = preprocess(image) # from PIL.Image to torch.Tensor - 
base_embeds = self.encode_prompt(base_prompt, negative_prompt) - target_embeds = self.encode_prompt(target_prompt, negative_prompt) + base_embeds = self._encode_prompt(base_prompt, negative_prompt, do_classifier_free_guidance) + target_embeds = self._encode_prompt(target_prompt, negative_prompt, do_classifier_free_guidance) self.scheduler.set_timesteps(num_inference_steps, self.device) @@ -216,7 +225,7 @@ def __call__( fwd_timesteps = self.scheduler.timesteps[t_limit:] bwd_timesteps = fwd_timesteps.flip(0) - coupled_latents = self.prepare_latents(image, base_embeds, bwd_timesteps, guidance_scale) + coupled_latents = self.prepare_latents(image, base_embeds, bwd_timesteps, guidance_scale, generator) for i, t in tqdm(enumerate(fwd_timesteps), total=len(fwd_timesteps)): # j - model_input index, k - base index @@ -230,12 +239,13 @@ def __call__( model_input = coupled_latents[j] base = coupled_latents[k] - latent_model_input = torch.cat([model_input] * 2) + latent_model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=target_embeds).sample - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) base, model_input = self.denoise_step( base=base, From 33b54e0016fe4be169f9c6391f95af70d6d951eb Mon Sep 17 00:00:00 2001 From: Joqsan Azocar Date: Wed, 19 Apr 2023 11:37:59 +0300 Subject: [PATCH 6/8] add EDICT pipeline description to README.md --- examples/community/README.md | 87 ++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/examples/community/README.md b/examples/community/README.md index 11da90764579..502471621a74 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -31,6 +31,7 @@ MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) | | DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - |[Aengus (Duc-Anh)](https://github.com/aengusng8) | | CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) | +| EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | - | [Joqsan Azocar](https://github.com/Joqsan) | @@ -1130,3 +1131,89 @@ Init Image Output Image ![img2img_clip_guidance](https://huggingface.co/datasets/njindal/images/resolve/main/clip_guided_img2img.jpg) + + + +### EDICT Image Editing Pipeline + +This pipeline implements the text-guided image editing approach 
from the paper [EDICT: Exact Diffusion Inversion via Coupled Transformations](https://arxiv.org/abs/2211.12446). You have to pass: +- (`PIL`) `image` you want to edit. +- `base_prompt`: the text prompt describing the current image (before editing). +- `target_prompt`: the text prompt describing with the edits. + +```python +from diffusers import DiffusionPipeline, DDIMScheduler +from transformers import CLIPTextModel +import torch, PIL, requests +from io import BytesIO +from IPython.display import display + +def center_crop_and_resize(im): + + width, height = im.size + d = min(width, height) + left = (width - d) / 2 + upper = (height - d) / 2 + right = (width + d) / 2 + lower = (height + d) / 2 + + return im.crop((left, upper, right, lower)).resize((512, 512)) + +torch_dtype = torch.float16 +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +# scheduler and text_encoder param values as in the paper +scheduler = DDIMScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + set_alpha_to_one=False, + clip_sample=False, +) + +text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_name_or_path="openai/clip-vit-large-patch14", + torch_dtype=torch_dtype, +) + +# initialize pipeline +pipeline = DiffusionPipeline.from_pretrained( + pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4", + custom_pipeline="edict_pipeline", + revision="fp16", + scheduler=scheduler, + text_encoder=text_encoder, + leapfrog_steps=True, + torch_dtype=torch_dtype, +).to(device) + +# download image +image_url = "https://huggingface.co/datasets/Joqsan/images/resolve/main/imagenet_dog_1.jpeg" +response = requests.get(image_url) +image = PIL.Image.open(BytesIO(response.content)) + +# preprocess it +cropped_image = center_crop_and_resize(image) + +# define the prompts +base_prompt = "A dog" +target_prompt = "A golden retriever" + +# run the pipeline +result_image = pipeline( + base_prompt=base_prompt, + target_prompt=target_prompt, + image=cropped_image, +) + +display(result_image) +``` + +Init Image + +![img2img_init_edict_text_editing](https://huggingface.co/datasets/Joqsan/images/resolve/main/imagenet_dog_1.jpeg) + +Output Image + +![img2img_edict_text_editing](https://huggingface.co/datasets/Joqsan/images/resolve/main/imagenet_dog_1_cropped_generated.png) \ No newline at end of file From 20d2733fc5cddc2fd4bc7417b310a51ccd5bdbd4 Mon Sep 17 00:00:00 2001 From: Joqsan Azocar Date: Wed, 19 Apr 2023 12:24:10 +0300 Subject: [PATCH 7/8] replace preprocess() with VaeImageProcessor --- examples/community/edict_pipeline.py | 50 +++++++++++++++------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/examples/community/edict_pipeline.py b/examples/community/edict_pipeline.py index a3305fdbc4b6..6ac64080aa72 100644 --- a/examples/community/edict_pipeline.py +++ b/examples/community/edict_pipeline.py @@ -6,21 +6,12 @@ from PIL import Image from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer -from diffusers.utils import PIL_INTERPOLATION +from diffusers.utils import ( + PIL_INTERPOLATION, + deprecate, +) -def preprocess(image): - if isinstance(image, Image.Image): - w, h = image.size - w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 - - image = np.array(image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = 2.0 * image - 1.0 - image = torch.from_numpy(image) - 
else: - raise TypeError("Expected object of type PIL.Image.Image") - return image +from diffusers.image_processor import VaeImageProcessor class EDICTPipeline(DiffusionPipeline): @@ -46,6 +37,9 @@ def __init__( scheduler=scheduler, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + def _encode_prompt(self, prompt: str, negative_prompt: Optional[str] = None, do_classifier_free_guidance: bool = False): text_inputs = self.tokenizer( @@ -141,11 +135,9 @@ def denoise_step( @torch.no_grad() def decode_latents(self, latents: torch.Tensor): - # latents = 1 / self.vae.config.scaling_factor * latents - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image @torch.no_grad() @@ -162,8 +154,7 @@ def prepare_latents( image = image.to(device=self.device, dtype=text_embeds.dtype) latent = self.vae.encode(image).latent_dist.sample(generator) - # init_latents = self.vae.config.scaling_factor * init_latents - latent = 0.18215 * latent + latent = self.vae.config.scaling_factor * latent coupled_latents = [latent.clone(), latent.clone()] @@ -211,10 +202,11 @@ def __call__( strength: float = 0.8, negative_prompt: Optional[str] = None, generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", ): do_classifier_free_guidance = guidance_scale > 1.0 - image = preprocess(image) # from PIL.Image to torch.Tensor + image = self.image_processor.preprocess(image) base_embeds = self._encode_prompt(base_prompt, negative_prompt, do_classifier_free_guidance) target_embeds = self._encode_prompt(target_prompt, negative_prompt, do_classifier_free_guidance) @@ -261,8 +253,18 @@ def __call__( # either one is fine final_latent = coupled_latents[0] - image = self.decode_latents(final_latent) - image = (image[0] * 255).round().astype("uint8") - pil_image = Image.fromarray(image) + if output_type not in ["latent", "pt", "np", "pil"]: + deprecation_message = ( + f"the output_type {output_type} is outdated. 
Please make sure to set it to one of these instead: " + "`pil`, `np`, `pt`, `latent`" + ) + deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) + output_type = "np" - return pil_image \ No newline at end of file + if output_type == "latent": + image = final_latent + else: + image = self.decode_latents(final_latent) + image = self.image_processor.postprocess(image, output_type=output_type) + + return image From dd7d37385cedce3767410b5b6f019a1a881b0b6b Mon Sep 17 00:00:00 2001 From: Joqsan Azocar Date: Wed, 19 Apr 2023 12:40:11 +0300 Subject: [PATCH 8/8] run make style and make quality commands --- examples/community/edict_pipeline.py | 30 +++++++++++----------------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/examples/community/edict_pipeline.py b/examples/community/edict_pipeline.py index 6ac64080aa72..ac977f79abec 100644 --- a/examples/community/edict_pipeline.py +++ b/examples/community/edict_pipeline.py @@ -1,18 +1,16 @@ -from typing import Optional, Union +from typing import Optional -import numpy as np import torch -from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler from PIL import Image from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer + +from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers.image_processor import VaeImageProcessor from diffusers.utils import ( - PIL_INTERPOLATION, deprecate, ) -from diffusers.image_processor import VaeImageProcessor - class EDICTPipeline(DiffusionPipeline): def __init__( @@ -40,8 +38,9 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - - def _encode_prompt(self, prompt: str, negative_prompt: Optional[str] = None, do_classifier_free_guidance: bool = False): + def _encode_prompt( + self, prompt: str, negative_prompt: Optional[str] = None, do_classifier_free_guidance: bool = False + ): text_inputs = self.tokenizer( prompt, padding="max_length", @@ -50,14 +49,11 @@ def _encode_prompt(self, prompt: str, negative_prompt: Optional[str] = None, do_ return_tensors="pt", ) - prompt_embeds = self.text_encoder( - text_inputs.input_ids.to(self.device) - ).last_hidden_state + prompt_embeds = self.text_encoder(text_inputs.input_ids.to(self.device)).last_hidden_state prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=self.device) if do_classifier_free_guidance: - uncond_tokens = "" if negative_prompt is None else negative_prompt uncond_input = self.tokenizer( @@ -68,9 +64,7 @@ def _encode_prompt(self, prompt: str, negative_prompt: Optional[str] = None, do_ return_tensors="pt", ) - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(self.device) - ).last_hidden_state + negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(self.device)).last_hidden_state prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) @@ -109,7 +103,7 @@ def noise_step( alpha_prod_t_prev, beta_prod_t_prev = self._get_alpha_and_beta(prev_timestep) a_t = (alpha_prod_t_prev / alpha_prod_t) ** 0.5 - b_t = -a_t * (beta_prod_t**0.5) + beta_prod_t_prev ** 0.5 + b_t = -a_t * (beta_prod_t**0.5) + beta_prod_t_prev**0.5 next_model_input = (base - b_t * model_output) / a_t @@ -132,7 +126,7 @@ def denoise_step( next_model_input = a_t * base + b_t * model_output return model_input, next_model_input.to(base.dtype) - + @torch.no_grad() def 
decode_latents(self, latents: torch.Tensor): latents = 1 / self.vae.config.scaling_factor * latents @@ -234,7 +228,7 @@ def __call__( latent_model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=target_embeds).sample - + if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
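
A minimal standalone sketch of why the coupled EDICT updates implemented above are exactly invertible: the noise mixing layer undoes the denoise mixing layer, and the affine noise step undoes the denoise step when the same model output and per-step coefficients are reused. It mirrors the arithmetic of `denoise_mixing_layer`/`noise_mixing_layer` and `denoise_step`/`noise_step` from the pipeline; the numeric values of `p`, `a_t`, and `b_t` below are arbitrary stand-ins, not the scheduler's real coefficients.

```python
import torch

p = 0.93                 # mixing coefficient, as in the pipeline default
a_t, b_t = 0.98, 0.11    # placeholder per-step coefficients (stand-ins)
x, y, eps = torch.randn(3, 4, 64, 64)  # two coupled latents and a fake noise prediction

# denoise_mixing_layer followed by noise_mixing_layer recovers (x, y) exactly
x1 = p * x + (1 - p) * y
y1 = p * y + (1 - p) * x1
y2 = (y1 - (1 - p) * x1) / p
x2 = (x1 - (1 - p) * y2) / p
assert torch.allclose(x2, x, atol=1e-5) and torch.allclose(y2, y, atol=1e-5)

# denoise_step followed by noise_step is exact for the same eps and coefficients:
# next = a_t * base + b_t * eps   <->   base = (next - b_t * eps) / a_t
denoised = a_t * x + b_t * eps
recovered = (denoised - b_t * eps) / a_t
assert torch.allclose(recovered, x, atol=1e-5)
```

In the full pipeline the model output fed to each inverse step is computed from the partner latent of the coupled pair (with leapfrog ordering), which is what lets the round trip stay exact in practice rather than only in this idealized check.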