From 11b728dc2f9908d5db023af4ac0b8c6e67df1d98 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 26 Apr 2023 08:09:36 +0000
Subject: [PATCH 1/4] Add all files

---
 .../pipeline_stable_diffusion_upscale.py       | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index 45b26de284af..88710641fca4 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -697,15 +697,11 @@ def __call__(
         # 10. Post-processing
         # make sure the VAE is in float32 mode, as it overflows in float16
         self.vae.to(dtype=torch.float32)
-        image = self.decode_latents(latents.float())
-
-        # Offload last model to CPU
-        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-            self.final_offload_hook.offload()
 
         # 11. Convert to PIL
         # has_nsfw_concept = False
         if output_type == "pil":
+            image = self.decode_latents(latents.float())
             image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
             image = self.numpy_to_pil(image)
 
@@ -713,9 +709,19 @@ def __call__(
             # 11. Apply watermark
             if self.watermarker is not None:
                 image = self.watermarker.apply_watermark(image)
+        elif output_type == "pt":
+            latents = 1 / self.vae.config.scaling_factor * latents.float()
+            image = self.vae.decode(latents).sample
+            has_nsfw_concept = None
         else:
+            image = self.decode_latents(latents.float())
             has_nsfw_concept = None
 
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
+
         if not return_dict:
             return (image, has_nsfw_concept)
 

From 5ae6f9ebbb3276b62b40ed84b298d2aaf186d472 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 26 Apr 2023 08:10:02 +0000
Subject: [PATCH 2/4] update

---
 .../stable_diffusion/pipeline_stable_diffusion_upscale.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index 88710641fca4..14e5c4ab7cd1 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -721,7 +721,6 @@ def __call__(
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()
 
-
         if not return_dict:
             return (image, has_nsfw_concept)
 

From b56562a7ff1cc9ee5774968b6a30f4160d896920 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 26 Apr 2023 09:55:57 +0000
Subject: [PATCH 3/4] Make sure vae is memory efficient for PT 1

---
 src/diffusers/models/vae.py                    |  3 +++
 .../pipeline_stable_diffusion_upscale.py       | 20 +++++++++++++++----
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py
index b4484823ac3d..400c3030af90 100644
--- a/src/diffusers/models/vae.py
+++ b/src/diffusers/models/vae.py
@@ -212,6 +212,7 @@ def forward(self, z):
         sample = z
         sample = self.conv_in(sample)
 
+        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
         if self.training and self.gradient_checkpointing:
 
             def create_custom_forward(module):
@@ -222,6 +223,7 @@ def custom_forward(*inputs):
 
             # middle
             sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            sample = sample.to(upscale_dtype)
 
             # up
             for up_block in self.up_blocks:
@@ -229,6 +231,7 @@ def custom_forward(*inputs):
         else:
             # middle
             sample = self.mid_block(sample)
+            sample = sample.to(upscale_dtype)
 
             # up
             for up_block in self.up_blocks:

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index 14e5c4ab7cd1..930d119fa0f0 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -16,6 +16,7 @@
 from typing import Any, Callable, List, Optional, Union
 
 import numpy as np
+import torch.nn.functional as F
 import PIL
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
@@ -698,10 +699,21 @@ def __call__(
         # make sure the VAE is in float32 mode, as it overflows in float16
         self.vae.to(dtype=torch.float32)
 
+        # TODO(Patrick, William) - clean up when attention is refactored
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
+        use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if not use_torch_2_0_attn and not use_xformers:
+            self.vae.post_quant_conv.to(latents.dtype)
+            self.vae.decoder.conv_in.to(latents.dtype)
+            self.vae.decoder.mid_block.to(latents.dtype)
+        else:
+            latents = latents.float()
+
         # 11. Convert to PIL
-        # has_nsfw_concept = False
         if output_type == "pil":
-            image = self.decode_latents(latents.float())
+            image = self.decode_latents(latents)
             image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
             image = self.numpy_to_pil(image)
 
@@ -710,11 +722,11 @@ def __call__(
             # 11. Apply watermark
             if self.watermarker is not None:
                 image = self.watermarker.apply_watermark(image)
         elif output_type == "pt":
-            latents = 1 / self.vae.config.scaling_factor * latents.float()
+            latents = 1 / self.vae.config.scaling_factor * latents
             image = self.vae.decode(latents).sample
             has_nsfw_concept = None
         else:
-            image = self.decode_latents(latents.float())
+            image = self.decode_latents(latents)
             has_nsfw_concept = None
 
         # Offload last model to CPU

From 4a7a893aead4810405d3ce1db8cacc61ed513526 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 26 Apr 2023 09:58:39 +0000
Subject: [PATCH 4/4] make style

---
 .../stable_diffusion/pipeline_stable_diffusion_upscale.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index 9e753232959e..87014f52dfc2 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -16,9 +16,9 @@
 from typing import Any, Callable, List, Optional, Union
 
 import numpy as np
-import torch.nn.functional as F
 import PIL
 import torch
+import torch.nn.functional as F
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...loaders import TextualInversionLoaderMixin
@@ -702,7 +702,7 @@ def __call__(
         # TODO(Patrick, William) - clean up when attention is refactored
         use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
         use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers
-        # if xformers or torch_2_0 is used attention block does not need 
+        # if xformers or torch_2_0 is used attention block does not need
         # to be in float32 which can save lots of memory
         if not use_torch_2_0_attn and not use_xformers:
             self.vae.post_quant_conv.to(latents.dtype)
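
Taken together, the four patches reroute decoding by output_type: "pil" keeps the safety checker and watermarker, the new "pt" branch decodes straight through the VAE, and patch 3 upcasts only the VAE layers that actually overflow in float16 when neither xformers nor PyTorch 2.0 attention is available. A minimal usage sketch follows (not part of the patch series itself), assuming a diffusers build that contains these commits; the checkpoint ID and test image URL are the standard ones from the diffusers upscaler example:

    import requests
    import torch
    from io import BytesIO
    from PIL import Image

    from diffusers import StableDiffusionUpscalePipeline

    # Load the x4 upscaler in fp16; with patch 3, only the VAE layers that
    # overflow in float16 are upcast to float32 during decoding.
    pipe = StableDiffusionUpscalePipeline.from_pretrained(
        "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
    ).to("cuda")

    # Low-resolution input image (the example image from the diffusers docs).
    url = (
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images"
        "/resolve/main/sd2-upscale/low_res_cat.png"
    )
    low_res_img = Image.open(BytesIO(requests.get(url).content)).convert("RGB")
    low_res_img = low_res_img.resize((128, 128))

    # Default "pil" output still runs the safety checker and watermarker.
    pil_image = pipe(prompt="a white cat", image=low_res_img).images[0]

    # With patch 1, output_type="pt" returns the raw decoded tensor
    # (NCHW, values roughly in [-1, 1]) and skips the safety checker.
    pt_images = pipe(prompt="a white cat", image=low_res_img, output_type="pt").images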