Skip to content

Allow fp16 attn for x4 upscaler #3239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/diffusers/models/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def forward(self, z):
sample = z
sample = self.conv_in(sample)

upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if self.training and self.gradient_checkpointing:

def create_custom_forward(module):
Expand All @@ -222,13 +223,15 @@ def custom_forward(*inputs):

# middle
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
sample = sample.to(upscale_dtype)

# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
else:
# middle
sample = self.mid_block(sample)
sample = sample.to(upscale_dtype)

# up
for up_block in self.up_blocks:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
import PIL
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...loaders import TextualInversionLoaderMixin
Expand Down Expand Up @@ -698,10 +699,22 @@ def __call__(
# make sure the VAE is in float32 mode, as it overflows in float16
self.vae.to(dtype=torch.float32)

# TODO(Patrick, William) - clean up when attention is refactored
use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if not use_torch_2_0_attn and not use_xformers:
self.vae.post_quant_conv.to(latents.dtype)
self.vae.decoder.conv_in.to(latents.dtype)
self.vae.decoder.mid_block.to(latents.dtype)
else:
latents = latents.float()

# 11. Convert to PIL
# has_nsfw_concept = False
if output_type == "pil":
image = self.decode_latents(latents.float())
image = self.decode_latents(latents)

image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)

image = self.numpy_to_pil(image)
Expand All @@ -710,11 +723,11 @@ def __call__(
if self.watermarker is not None:
image = self.watermarker.apply_watermark(image)
elif output_type == "pt":
latents = 1 / self.vae.config.scaling_factor * latents.float()
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
has_nsfw_concept = None
else:
image = self.decode_latents(latents.float())
image = self.decode_latents(latents)
has_nsfw_concept = None

# Offload last model to CPU
Expand Down