diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py
index b4484823ac3d..400c3030af90 100644
--- a/src/diffusers/models/vae.py
+++ b/src/diffusers/models/vae.py
@@ -212,6 +212,7 @@ def forward(self, z):
         sample = z
         sample = self.conv_in(sample)
 
+        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
         if self.training and self.gradient_checkpointing:
 
             def create_custom_forward(module):
@@ -222,6 +223,7 @@ def custom_forward(*inputs):
 
             # middle
             sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            sample = sample.to(upscale_dtype)
 
             # up
             for up_block in self.up_blocks:
@@ -229,6 +231,7 @@ def custom_forward(*inputs):
         else:
             # middle
             sample = self.mid_block(sample)
+            sample = sample.to(upscale_dtype)
 
             # up
             for up_block in self.up_blocks:
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index 14e5c4ab7cd1..87014f52dfc2 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -18,6 +18,7 @@
 import numpy as np
 import PIL
 import torch
+import torch.nn.functional as F
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...loaders import TextualInversionLoaderMixin
@@ -698,10 +699,22 @@ def __call__(
         # make sure the VAE is in float32 mode, as it overflows in float16
         self.vae.to(dtype=torch.float32)
 
+        # TODO(Patrick, William) - clean up when attention is refactored
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
+        use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if not use_torch_2_0_attn and not use_xformers:
+            self.vae.post_quant_conv.to(latents.dtype)
+            self.vae.decoder.conv_in.to(latents.dtype)
+            self.vae.decoder.mid_block.to(latents.dtype)
+        else:
+            latents = latents.float()
+
         # 11. Convert to PIL
-        # has_nsfw_concept = False
         if output_type == "pil":
-            image = self.decode_latents(latents.float())
+            image = self.decode_latents(latents)
+
             image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
 
             image = self.numpy_to_pil(image)
@@ -710,11 +723,11 @@
             if self.watermarker is not None:
                 image = self.watermarker.apply_watermark(image)
         elif output_type == "pt":
-            latents = 1 / self.vae.config.scaling_factor * latents.float()
+            latents = 1 / self.vae.config.scaling_factor * latents
             image = self.vae.decode(latents).sample
             has_nsfw_concept = None
         else:
-            image = self.decode_latents(latents.float())
+            image = self.decode_latents(latents)
             has_nsfw_concept = None
 
         # Offload last model to CPU
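
Illustrative sketch (an addition, not part of the patch): a minimal, standalone
Python example of the mixed-precision pattern the vae.py hunks implement. The
numerically sensitive front of the decoder runs in float32 while the up-blocks
stay in float16, with activations cast to the up-blocks' dtype in between.
`ToyDecoder` and its layer sizes are hypothetical stand-ins, not the diffusers
classes.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_in = nn.Conv2d(4, 8, 3, padding=1)
        self.mid_block = nn.Conv2d(8, 8, 3, padding=1)  # stand-in for the attention-bearing mid block
        self.up_blocks = nn.ModuleList([nn.Conv2d(8, 8, 3, padding=1)])

    def forward(self, z):
        sample = self.conv_in(z)
        # Same trick as the vae.py patch: read the up-blocks' dtype once,
        # run the overflow-prone front in float32, then cast the activations
        # before entering the half-precision tail.
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        sample = self.mid_block(sample)
        sample = sample.to(upscale_dtype)
        for up_block in self.up_blocks:
            sample = up_block(sample)
        return sample

# float16 convolutions need a GPU; fall back to a single dtype on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
tail_dtype = torch.float16 if device == "cuda" else torch.float32

decoder = ToyDecoder().to(device)   # whole model starts in float32
decoder.up_blocks.to(tail_dtype)    # only the memory-hungry tail is downcast

latents = torch.randn(1, 4, 64, 64, device=device)  # float32 latents
print(decoder(latents).dtype)       # tail_dtype

# The pipeline-side change goes one step further: when memory-efficient
# attention is available, even the mid block can stay in half precision.
if hasattr(F, "scaled_dot_product_attention"):
    print("torch >= 2.0 SDPA available; the float32 island can shrink further")

The design point the patch exploits is that the fp16 overflow in this VAE is
localized to the attention mid block, so only a small float32 island is needed
when memory-efficient attention is unavailable, and none when it is.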