Commit abbf3c1

Allow fp16 attn for x4 upscaler (#3239)
* Add all files
* update
* Make sure vae is memory efficient for PT 1
* make style
1 parent da2ce1a commit abbf3c1

2 files changed (+20 -4 lines changed)

src/diffusers/models/vae.py

Lines changed: 3 additions & 0 deletions
@@ -212,6 +212,7 @@ def forward(self, z):
         sample = z
         sample = self.conv_in(sample)
 
+        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
         if self.training and self.gradient_checkpointing:
 
             def create_custom_forward(module):
@@ -222,13 +223,15 @@ def custom_forward(*inputs):
 
             # middle
             sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            sample = sample.to(upscale_dtype)
 
             # up
             for up_block in self.up_blocks:
                 sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
         else:
             # middle
             sample = self.mid_block(sample)
+            sample = sample.to(upscale_dtype)
 
             # up
             for up_block in self.up_blocks:
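Note on the decoder change above: the mid block (which holds the VAE attention) may now sit in a different dtype than the up blocks, so the activation has to be cast to the up blocks' parameter dtype before the upsampling path. Below is a minimal toy sketch of that pattern, not diffusers code; the module names are illustrative, and it only uses float16 when a CUDA device is available.

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 convolutions are assumed to need a GPU here; fall back to float32 on CPU
up_dtype = torch.float16 if device == "cuda" else torch.float32

mid_block = nn.Conv2d(4, 4, 3, padding=1).to(device, torch.float32)       # kept in fp32
up_blocks = nn.ModuleList([nn.Conv2d(4, 4, 3, padding=1)]).to(device, up_dtype)

sample = torch.randn(1, 4, 8, 8, device=device)
sample = mid_block(sample)                                  # float32 activation

# same trick as in the diff: read the up blocks' parameter dtype and cast to it
upscale_dtype = next(iter(up_blocks.parameters())).dtype
sample = sample.to(upscale_dtype)

for up_block in up_blocks:
    sample = up_block(sample)                               # dtypes now agree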

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

Lines changed: 17 additions & 4 deletions
@@ -18,6 +18,7 @@
 import numpy as np
 import PIL
 import torch
+import torch.nn.functional as F
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...loaders import TextualInversionLoaderMixin
@@ -698,10 +699,22 @@ def __call__(
         # make sure the VAE is in float32 mode, as it overflows in float16
         self.vae.to(dtype=torch.float32)
 
+        # TODO(Patrick, William) - clean up when attention is refactored
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
+        use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if use_torch_2_0_attn or use_xformers:
+            self.vae.post_quant_conv.to(latents.dtype)
+            self.vae.decoder.conv_in.to(latents.dtype)
+            self.vae.decoder.mid_block.to(latents.dtype)
+        else:
+            latents = latents.float()
+
         # 11. Convert to PIL
-        # has_nsfw_concept = False
         if output_type == "pil":
-            image = self.decode_latents(latents.float())
+            image = self.decode_latents(latents)
+
             image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
 
             image = self.numpy_to_pil(image)
@@ -710,11 +723,11 @@ def __call__(
             if self.watermarker is not None:
                 image = self.watermarker.apply_watermark(image)
         elif output_type == "pt":
-            latents = 1 / self.vae.config.scaling_factor * latents.float()
+            latents = 1 / self.vae.config.scaling_factor * latents
             image = self.vae.decode(latents).sample
             has_nsfw_concept = None
         else:
-            image = self.decode_latents(latents.float())
+            image = self.decode_latents(latents)
             has_nsfw_concept = None
 
         # Offload last model to CPU
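Net effect of this diff: when the x4 upscaler runs with torch 2.0's scaled_dot_product_attention or with xformers, the VAE's post_quant_conv, conv_in and mid block can stay in the latents' float16 dtype instead of being upcast to float32, which saves a large amount of decode memory. A minimal usage sketch follows; it is not part of the commit and assumes a CUDA GPU, a diffusers build that includes this change, and the stabilityai/stable-diffusion-x4-upscaler checkpoint. The file names and prompt are placeholders.

import torch
from PIL import Image
from diffusers import StableDiffusionUpscalePipeline

pipe = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
# On PyTorch 2.0 the scaled_dot_product_attention path is used automatically;
# on PyTorch 1.x, xformers provides the same memory-efficient attention:
# pipe.enable_xformers_memory_efficient_attention()

low_res_img = Image.open("low_res_cat.png").convert("RGB")  # placeholder input image
upscaled = pipe(prompt="a white cat", image=low_res_img).images[0]
upscaled.save("upscaled_cat.png")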
