From 21fc39b0b8a25735cf81b2455eb12185d4e1460c Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 7 Oct 2022 13:57:21 +0200 Subject: [PATCH 1/4] handle dtype in vae and image2image pipeline --- src/diffusers/models/vae.py | 8 +++- .../pipeline_stable_diffusion_img2img.py | 46 ++++++++++--------- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index fe89b41c074e..55f1d757b8df 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -337,12 +337,16 @@ def __init__(self, parameters, deterministic=False): self.std = torch.exp(0.5 * self.logvar) self.var = torch.exp(self.logvar) if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device, dtype=self.parameters.dtype + ) def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: device = self.parameters.device sample_device = "cpu" if device.type == "mps" else device - sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device) + sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to( + device=device, dtype=self.parameters.dtype + ) x = self.mean + self.std * sample return x diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 15bdd0208825..37f4be658435 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -217,26 +217,6 @@ def __call__( if isinstance(init_image, PIL.Image.Image): init_image = preprocess(init_image) - # encode the init image into latents and scale the latents - init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - init_latents = 0.18215 * init_latents - - # expand init_latents for batch_size - init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0) - - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) - - # add noise to latents using the timesteps - noise = torch.randn(init_latents.shape, generator=generator, device=self.device) - init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) - # get prompt text embeddings text_inputs = self.tokenizer( prompt, @@ -297,6 +277,28 @@ def __call__( # to avoid doing two forward passes text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + # encode the init image into latents and scale the latents + latents_dtype = text_embeddings.dtype + init_image = init_image.to(device=self.device, dtype=latents_dtype) + init_latent_dist = self.vae.encode(init_image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents + + # expand init_latents for batch_size + init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0) + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + + # add noise to latents using the timesteps + noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype) + init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(latents_dtype) + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 @@ -341,7 +343,9 @@ def __call__( image = image.cpu().permute(0, 2, 3, 1).numpy() safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) + ) if output_type == "pil": image = self.numpy_to_pil(image) From faf41fde28554665f8acadd6fb6b02ca093564d3 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 7 Oct 2022 14:08:29 +0200 Subject: [PATCH 2/4] handle dtype in add noise --- src/diffusers/schedulers/scheduling_ddim.py | 4 ++++ src/diffusers/schedulers/scheduling_ddpm.py | 4 ++++ .../schedulers/scheduling_lms_discrete.py | 20 +++++++++++++++---- src/diffusers/schedulers/scheduling_pndm.py | 4 ++++ 4 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 2dc85a93adc9..da8e81ea2a9f 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -306,6 +306,10 @@ def add_noise( if timesteps.device != original_samples.device: timesteps = timesteps.to(original_samples.device) + # Make sure that alphas_cumprod are in the same dtype as the samples + if self.alphas_cumprod.dtype != original_samples.dtype: + self.alphas_cumprod = self.alphas_cumprod.to(original_samples.dtype) + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index e1db9079d149..2edd0f76b11f 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -300,6 +300,10 @@ def add_noise( if timesteps.device != original_samples.device: timesteps = timesteps.to(original_samples.device) + # Make sure that alphas_cumprod are in the same dtype as the samples + if self.alphas_cumprod.dtype != original_samples.dtype: + self.alphas_cumprod = self.alphas_cumprod.to(original_samples.dtype) + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 802da468cda6..665052d4b542 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -257,9 +257,21 @@ def add_noise( noise: torch.FloatTensor, timesteps: torch.FloatTensor, ) -> torch.FloatTensor: - sigmas = self.sigmas.to(original_samples.device) - schedule_timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) + if self.sigmas.device != original_samples.device: + self.sigmas = self.sigmas.to(original_samples.device) + + # Make sure sigmas are in the same dtype as the samples + if self.sigmas.dtype != original_samples.dtype: + self.sigmas = self.sigmas.to(original_samples.dtype) + + if timesteps.device != original_samples.device: + timesteps = timesteps.to(original_samples.device) + + if self.timesteps.device != original_samples.device: + self.timesteps = self.timesteps.to(original_samples.device) + + schedule_timesteps = self.timesteps + if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor): deprecate( "timesteps as indices", @@ -273,7 +285,7 @@ def add_noise( else: step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sigma = sigmas[step_indices].flatten() + sigma = self.sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index f6a6d6153be5..aafcbaf5aa45 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -406,6 +406,10 @@ def add_noise( if timesteps.device != original_samples.device: timesteps = timesteps.to(original_samples.device) + # Make sure that alphas_cumprod are in the same dtype as the samples + if self.alphas_cumprod.dtype != original_samples.dtype: + self.alphas_cumprod = self.alphas_cumprod.to(original_samples.dtype) + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): From 3536b4366b2ff1e13b253e1a72caf2340f6c3834 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 7 Oct 2022 16:05:17 +0200 Subject: [PATCH 3/4] don't modify vae and pipeline --- src/diffusers/models/vae.py | 8 +--- .../pipeline_stable_diffusion_img2img.py | 46 +++++++++---------- 2 files changed, 23 insertions(+), 31 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 55f1d757b8df..fe89b41c074e 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -337,16 +337,12 @@ def __init__(self, parameters, deterministic=False): self.std = torch.exp(0.5 * self.logvar) self.var = torch.exp(self.logvar) if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to( - device=self.parameters.device, dtype=self.parameters.dtype - ) + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: device = self.parameters.device sample_device = "cpu" if device.type == "mps" else device - sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to( - device=device, dtype=self.parameters.dtype - ) + sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device) x = self.mean + self.std * sample return x diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 37f4be658435..15bdd0208825 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -217,6 +217,26 @@ def __call__( if isinstance(init_image, PIL.Image.Image): init_image = preprocess(init_image) + # encode the init image into latents and scale the latents + init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents + + # expand init_latents for batch_size + init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0) + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + + # add noise to latents using the timesteps + noise = torch.randn(init_latents.shape, generator=generator, device=self.device) + init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) + # get prompt text embeddings text_inputs = self.tokenizer( prompt, @@ -277,28 +297,6 @@ def __call__( # to avoid doing two forward passes text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - # encode the init image into latents and scale the latents - latents_dtype = text_embeddings.dtype - init_image = init_image.to(device=self.device, dtype=latents_dtype) - init_latent_dist = self.vae.encode(init_image).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - init_latents = 0.18215 * init_latents - - # expand init_latents for batch_size - init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0) - - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) - - # add noise to latents using the timesteps - noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype) - init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(latents_dtype) - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 @@ -343,9 +341,7 @@ def __call__( image = image.cpu().permute(0, 2, 3, 1).numpy() safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) - ) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) if output_type == "pil": image = self.numpy_to_pil(image) From 9f48d7975a74a38b8a84de5bec6b4ea186b2a78a Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 7 Oct 2022 16:10:48 +0200 Subject: [PATCH 4/4] remove the if --- src/diffusers/schedulers/scheduling_ddim.py | 12 +++--------- src/diffusers/schedulers/scheduling_ddpm.py | 12 +++--------- .../schedulers/scheduling_lms_discrete.py | 16 ++++------------ src/diffusers/schedulers/scheduling_pndm.py | 12 +++--------- 4 files changed, 13 insertions(+), 39 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index da8e81ea2a9f..2d24ecac1d95 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -300,15 +300,9 @@ def add_noise( noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: - if self.alphas_cumprod.device != original_samples.device: - self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device) - - if timesteps.device != original_samples.device: - timesteps = timesteps.to(original_samples.device) - - # Make sure that alphas_cumprod are in the same dtype as the samples - if self.alphas_cumprod.dtype != original_samples.dtype: - self.alphas_cumprod = self.alphas_cumprod.to(original_samples.dtype) + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 2edd0f76b11f..77ed98137708 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -294,15 +294,9 @@ def add_noise( noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: - if self.alphas_cumprod.device != original_samples.device: - self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device) - - if timesteps.device != original_samples.device: - timesteps = timesteps.to(original_samples.device) - - # Make sure that alphas_cumprod are in the same dtype as the samples - if self.alphas_cumprod.dtype != original_samples.dtype: - self.alphas_cumprod = self.alphas_cumprod.to(original_samples.dtype) + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 665052d4b542..1f6187c727c9 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -257,18 +257,10 @@ def add_noise( noise: torch.FloatTensor, timesteps: torch.FloatTensor, ) -> torch.FloatTensor: - if self.sigmas.device != original_samples.device: - self.sigmas = self.sigmas.to(original_samples.device) - - # Make sure sigmas are in the same dtype as the samples - if self.sigmas.dtype != original_samples.dtype: - self.sigmas = self.sigmas.to(original_samples.dtype) - - if timesteps.device != original_samples.device: - timesteps = timesteps.to(original_samples.device) - - if self.timesteps.device != original_samples.device: - self.timesteps = self.timesteps.to(original_samples.device) + # Make sure sigmas and timesteps have the same device and dtype as original_samples + self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + self.timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) schedule_timesteps = self.timesteps diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index aafcbaf5aa45..b29712e1e736 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -400,15 +400,9 @@ def add_noise( noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.Tensor: - if self.alphas_cumprod.device != original_samples.device: - self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device) - - if timesteps.device != original_samples.device: - timesteps = timesteps.to(original_samples.device) - - # Make sure that alphas_cumprod are in the same dtype as the samples - if self.alphas_cumprod.dtype != original_samples.dtype: - self.alphas_cumprod = self.alphas_cumprod.to(original_samples.dtype) + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten()