From aa4456521fa49d883725f4a9a46d50449111c702 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 28 Apr 2023 23:09:23 +0000 Subject: [PATCH 01/11] fix more --- src/diffusers/models/attention.py | 2 ++ src/diffusers/models/attention_processor.py | 2 ++ src/diffusers/models/modeling_utils.py | 10 ++++++-- src/diffusers/models/unet_2d_blocks.py | 25 +++++++++++-------- src/diffusers/models/unet_2d_condition.py | 2 +- .../pipelines/deepfloyd_if/pipeline_if.py | 9 ++++--- .../pipeline_stable_diffusion.py | 7 +++--- 7 files changed, 36 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index fb5f6f48b324..37b0bbce888e 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -21,6 +21,7 @@ from ..utils.import_utils import is_xformers_available from .attention_processor import Attention from .embeddings import CombinedTimestepLabelEmbeddings +from torch._dynamo import allow_in_graph if is_xformers_available(): @@ -193,6 +194,7 @@ def forward(self, hidden_states): return hidden_states +@allow_in_graph class BasicTransformerBlock(nn.Module): r""" A basic Transformer block. diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index b8787aed91f2..9a701dc95a09 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -16,6 +16,7 @@ import torch import torch.nn.functional as F from torch import nn +from torch._dynamo import allow_in_graph from ..utils import deprecate, logging from ..utils.import_utils import is_xformers_available @@ -31,6 +32,7 @@ xformers = None +@allow_in_graph class Attention(nn.Module): r""" A cross attention layer. diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 521e99fdd69c..be6a8b58706f 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -77,8 +77,14 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: def get_parameter_dtype(parameter: torch.nn.Module): try: - parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) - return next(parameters_and_buffers).dtype + params = tuple(parameter.parameters()) + if len(params) > 0: + return params[0] + + buffers = tuple(parameter.buffers()) + if len(buffers) > 0: + return buffers[0] + except StopIteration: # For torch.nn.DataParallel compatibility in PyTorch 1.5 diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 439c5c34b601..6618c0198324 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -560,7 +560,8 @@ def forward( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False + )[0] hidden_states = resnet(hidden_states, temb) return hidden_states @@ -868,15 +869,16 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) return hidden_states, output_states @@ -949,13 +951,13 @@ def custom_forward(*inputs): else: hidden_states = resnet(hidden_states, temb) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) return hidden_states, output_states @@ -1342,13 +1344,13 @@ def custom_forward(*inputs): else: hidden_states = resnet(hidden_states, temb) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states, temb) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) return hidden_states, output_states @@ -1466,13 +1468,13 @@ def forward( **cross_attention_kwargs, ) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states, temb) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) return hidden_states, output_states @@ -1859,7 +1861,8 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] if self.upsamplers is not None: for upsampler in self.upsamplers: diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 38e0fa3b5b2e..86eabff4e3a9 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -682,7 +682,7 @@ def forward( # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) + t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py index 479ffa9e6635..448389b9f1f6 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py @@ -793,7 +793,8 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -805,8 +806,8 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 intermediate_images = self.scheduler.step( - noise_pred, t, intermediate_images, **extra_step_kwargs - ).prev_sample + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -829,7 +830,7 @@ def __call__( # 11. Apply watermark if self.watermarker is not None: - self.watermarker.apply_watermark(image, self.unet.config.sample_size) + image = self.watermarker.apply_watermark(image, self.unet.config.sample_size) elif output_type == "pt": nsfw_detected = None watermark_detected = None diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 7347d70c4023..4168dc7e9788 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -440,7 +440,7 @@ def run_safety_checker(self, image, device, dtype): def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -686,7 +686,8 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -694,7 +695,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): From 2190a90f58c828ca42f989986cd93d57dce75b43 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 29 Apr 2023 12:28:03 +0000 Subject: [PATCH 02/11] Fix more --- src/diffusers/models/attention.py | 4 ++-- src/diffusers/models/attention_processor.py | 4 ++-- src/diffusers/utils/__init__.py | 3 ++- src/diffusers/utils/torch_utils.py | 6 ++++++ 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 37b0bbce888e..134f84fc9d50 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -18,10 +18,10 @@ import torch.nn.functional as F from torch import nn +from ..utils import maybe_allow_in_graph from ..utils.import_utils import is_xformers_available from .attention_processor import Attention from .embeddings import CombinedTimestepLabelEmbeddings -from torch._dynamo import allow_in_graph if is_xformers_available(): @@ -194,7 +194,7 @@ def forward(self, hidden_states): return hidden_states -@allow_in_graph +@maybe_allow_in_graph class BasicTransformerBlock(nn.Module): r""" A basic Transformer block. diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 9a701dc95a09..1acdc9d0653a 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -18,7 +18,7 @@ from torch import nn from torch._dynamo import allow_in_graph -from ..utils import deprecate, logging +from ..utils import deprecate, logging, maybe_allow_in_graph from ..utils.import_utils import is_xformers_available @@ -32,7 +32,7 @@ xformers = None -@allow_in_graph +@maybe_allow_in_graph class Attention(nn.Module): r""" A cross attention layer. diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index f3e4c9d1d0ec..8af28681c97b 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -81,7 +81,7 @@ from .logging import get_logger from .outputs import BaseOutput from .pil_utils import PIL_INTERPOLATION, numpy_to_pil, pt_to_pil -from .torch_utils import is_compiled_module, randn_tensor +from .torch_utils import is_compiled_module, maybe_allow_in_graph, randn_tensor if is_torch_available(): @@ -101,6 +101,7 @@ torch_all_close, torch_device, ) + from .torch_utils import maybe_allow_in_graph from .testing_utils import export_to_video diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index b9815cbceede..fa4e463e773c 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -25,6 +25,12 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +try: + from torch._dynamo import allow_in_graph as maybe_allow_in_graph +except (ImportError, ModuleNotFoundError): + def maybe_allow_in_graph(cls): + return cls + def randn_tensor( shape: Union[Tuple, List], From 363edcf551a2aaeb17092350482f0016777e9765 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 29 Apr 2023 12:44:19 +0000 Subject: [PATCH 03/11] fix more --- src/diffusers/models/attention_processor.py | 1 - src/diffusers/models/unet_2d_condition.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 1acdc9d0653a..7ac88b17999a 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -16,7 +16,6 @@ import torch import torch.nn.functional as F from torch import nn -from torch._dynamo import allow_in_graph from ..utils import deprecate, logging, maybe_allow_in_graph from ..utils.import_utils import is_xformers_available diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 86eabff4e3a9..83169455fa3e 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -697,7 +697,7 @@ def forward( # there might be better ways to encapsulate this. class_labels = class_labels.to(dtype=sample.dtype) - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) From c8b1b1d0390b6f5a885990a7afb7a069167a2f5f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 29 Apr 2023 14:59:57 +0200 Subject: [PATCH 04/11] Apply suggestions from code review --- src/diffusers/utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 8af28681c97b..cd3a1b8f3dd4 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -81,7 +81,7 @@ from .logging import get_logger from .outputs import BaseOutput from .pil_utils import PIL_INTERPOLATION, numpy_to_pil, pt_to_pil -from .torch_utils import is_compiled_module, maybe_allow_in_graph, randn_tensor +from .torch_utils import is_compiled_module, randn_tensor if is_torch_available(): From a2b18a9e66c8bd981a3f619e68be65e5b8cf98ad Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 29 Apr 2023 13:30:02 +0000 Subject: [PATCH 05/11] fix --- src/diffusers/models/modeling_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index be6a8b58706f..6644042077d2 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -79,11 +79,11 @@ def get_parameter_dtype(parameter: torch.nn.Module): try: params = tuple(parameter.parameters()) if len(params) > 0: - return params[0] + return params[0].dtype buffers = tuple(parameter.buffers()) if len(buffers) > 0: - return buffers[0] + return buffers[0].dtype except StopIteration: # For torch.nn.DataParallel compatibility in PyTorch 1.5 From f51568a457d30a91cb1fa7e23283a69a4919be5f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 1 May 2023 13:40:59 +0000 Subject: [PATCH 06/11] make style --- src/diffusers/models/unet_2d_blocks.py | 2 +- src/diffusers/utils/torch_utils.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 6618c0198324..57153fa39807 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -560,7 +560,7 @@ def forward( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, - return_dict=False + return_dict=False, )[0] hidden_states = resnet(hidden_states, temb) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index fa4e463e773c..2b626a3b425a 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -28,6 +28,7 @@ try: from torch._dynamo import allow_in_graph as maybe_allow_in_graph except (ImportError, ModuleNotFoundError): + def maybe_allow_in_graph(cls): return cls From abb68a2cacc20025539ccf49fd21b9a44196e7d3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 1 May 2023 13:45:25 +0000 Subject: [PATCH 07/11] make fix-copies --- .../alt_diffusion/pipeline_alt_diffusion.py | 7 ++++--- .../pipeline_paint_by_example.py | 2 +- .../pipeline_semantic_stable_diffusion.py | 2 +- .../pipeline_cycle_diffusion.py | 2 +- ...line_stable_diffusion_attend_and_excite.py | 2 +- .../pipeline_stable_diffusion_controlnet.py | 2 +- .../pipeline_stable_diffusion_depth2img.py | 2 +- .../pipeline_stable_diffusion_diffedit.py | 2 +- ...peline_stable_diffusion_image_variation.py | 2 +- .../pipeline_stable_diffusion_inpaint.py | 2 +- ...ipeline_stable_diffusion_inpaint_legacy.py | 2 +- ...eline_stable_diffusion_instruct_pix2pix.py | 2 +- .../pipeline_stable_diffusion_k_diffusion.py | 2 +- ...ipeline_stable_diffusion_latent_upscale.py | 2 +- ...pipeline_stable_diffusion_model_editing.py | 2 +- .../pipeline_stable_diffusion_panorama.py | 2 +- .../pipeline_stable_diffusion_pix2pix_zero.py | 2 +- .../pipeline_stable_diffusion_sag.py | 2 +- .../pipeline_stable_diffusion_upscale.py | 2 +- .../pipeline_stable_unclip.py | 2 +- .../pipeline_stable_unclip_img2img.py | 2 +- .../pipeline_stable_diffusion_safe.py | 2 +- .../versatile_diffusion/modeling_text_unet.py | 21 +++++++++++-------- ...ipeline_versatile_diffusion_dual_guided.py | 2 +- ...ine_versatile_diffusion_image_variation.py | 2 +- ...eline_versatile_diffusion_text_to_image.py | 2 +- 26 files changed, 40 insertions(+), 36 deletions(-) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index ff9474ffd43a..b61703a2146d 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -437,7 +437,7 @@ def run_safety_checker(self, image, device, dtype): def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -683,7 +683,8 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -691,7 +692,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index ca0a90a5b5ca..d6c069bbb7d0 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -256,7 +256,7 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index 3d5374875d12..fbe436ec9666 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -134,7 +134,7 @@ def __init__( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index e2accb6d2d2a..a40ba75d04bd 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -516,7 +516,7 @@ def run_safety_checker(self, image, device, dtype): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py index fba2a4e32f88..eec7debc38b7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py @@ -454,7 +454,7 @@ def run_safety_checker(self, image, device, dtype): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index 3bd7f82d7eb6..e36b0bcdf759 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -496,7 +496,7 @@ def run_safety_checker(self, image, device, dtype): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index c4f9ae59a4e9..378eb927ca52 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -326,7 +326,7 @@ def run_safety_checker(self, image, device, dtype): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 9bef5269fa07..adada63b83f7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -648,7 +648,7 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index d543593fdbf5..2dc762d62529 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -195,7 +195,7 @@ def run_safety_checker(self, image, device, dtype): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index fb2e5dc424e3..cac7465298cc 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -525,7 +525,7 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 3ad1d5e92273..6d93fba2425e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -446,7 +446,7 @@ def run_safety_checker(self, image, device, dtype): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 49944cdcd636..225e3719b98f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -656,7 +656,7 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index 99aca66db809..5a21bcafccbc 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -358,7 +358,7 @@ def run_safety_checker(self, image, device, dtype): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py index 822bd49ce31c..fcda8d526c99 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -221,7 +221,7 @@ def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_p # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py index b7ded03d529b..3926a4e70ad0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py @@ -385,7 +385,7 @@ def run_safety_checker(self, image, device, dtype): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index 392b2a72a76f..facffd7a852a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -349,7 +349,7 @@ def run_safety_checker(self, image, device, dtype): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 6444ec7c8506..b60987edfaca 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -590,7 +590,7 @@ def run_safety_checker(self, image, device, dtype): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py index ebac58e18f62..e58edde1c9fe 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -366,7 +366,7 @@ def run_safety_checker(self, image, device, dtype): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 87014f52dfc2..a8c29f32e9e5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -373,7 +373,7 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index fafb8d1d2800..3e34dcb98132 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -475,7 +475,7 @@ def _encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index 22b7280f3679..9d6a6c8332fb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -430,7 +430,7 @@ def _encode_image( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index 87e7b3e6c9eb..f4f7eefcd07a 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -364,7 +364,7 @@ def run_safety_checker(self, image, device, dtype, enable_safety_guidance): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 0959e2bb3a8b..e9e31d67905b 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -785,7 +785,7 @@ def forward( # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) + t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) @@ -800,7 +800,7 @@ def forward( # there might be better ways to encapsulate this. class_labels = class_labels.to(dtype=sample.dtype) - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) @@ -1081,13 +1081,13 @@ def custom_forward(*inputs): else: hidden_states = resnet(hidden_states, temb) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) return hidden_states, output_states @@ -1211,15 +1211,16 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) return hidden_states, output_states @@ -1424,7 +1425,8 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -1528,7 +1530,8 @@ def forward( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] hidden_states = resnet(hidden_states, temb) return hidden_states diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index 661a1bd3cf73..2827ed4a7378 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -330,7 +330,7 @@ def normalize_embeddings(encoder_output): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index e3a2ee370362..46eee27bcbfc 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -190,7 +190,7 @@ def normalize_embeddings(encoder_output): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py index 26b9be2bfa76..cd5dd70a2cdc 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -247,7 +247,7 @@ def normalize_embeddings(encoder_output): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() From db0dee02a16222e11ac261e96b40abba2f2b5c72 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 1 May 2023 13:50:37 +0000 Subject: [PATCH 08/11] fix --- .../pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py index e58edde1c9fe..27ba46c8b3e7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -619,7 +619,7 @@ def __call__( def get_map_size(module, input, output): nonlocal map_size - map_size = output.sample.shape[-2:] + map_size = output[0].shape[-2:] with self.unet.mid_block.attentions[0].register_forward_hook(get_map_size): with self.progress_bar(total=num_inference_steps) as progress_bar: From 5b001cffebe534a57b7cd9c24c588091da7a5aa1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 1 May 2023 14:04:45 +0000 Subject: [PATCH 09/11] make sure torch compile --- .../stable_diffusion/test_stable_diffusion.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index fcfcd84c5d48..133223ce6e22 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -22,6 +22,7 @@ import numpy as np import torch from huggingface_hub import hf_hub_download +from packaging import version from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import ( @@ -865,6 +866,23 @@ def test_stable_diffusion_textual_inversion(self): max_diff = np.abs(expected_image - image).max() assert max_diff < 5e-2 + def test_stable_diffusion_compile(self): + if torch.__version__ + sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None) + sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs(torch_device) + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.38019, 0.28647, 0.27321, 0.40377, 0.38290, 0.35446, 0.39218, 0.38165, 0.42239]) + + + + @slow @require_torch_gpu @@ -921,6 +939,28 @@ def test_download_ckpt_diff_format_is_same(self): assert np.max(np.abs(image - image_ckpt)) < 1e-4 + def test_stable_diffusion_compile(self): + if version.parse(torch.__version__) >= version.parse('2.0'): + print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0") + return + + sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None) + sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(torch_device) + + sd_pipe.unet.to(memory_format=torch.channels_last) + sd_pipe.unet = torch.compile(stage_1.unet, mode="reduce-overhead", fullgraph=True) + + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs(torch_device) + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.38019, 0.28647, 0.27321, 0.40377, 0.38290, 0.35446, 0.39218, 0.38165, 0.42239]) + assert np.abs(image_slice - expected_slice).max() < 1e-4 + @nightly @require_torch_gpu From 9a6599a5709127fac1c036b120dca9d5e1826205 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 1 May 2023 14:07:05 +0000 Subject: [PATCH 10/11] Clean --- .../stable_diffusion/test_stable_diffusion.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 133223ce6e22..a44abeb862ce 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -866,23 +866,6 @@ def test_stable_diffusion_textual_inversion(self): max_diff = np.abs(expected_image - image).max() assert max_diff < 5e-2 - def test_stable_diffusion_compile(self): - if torch.__version__ - sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None) - sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - inputs = self.get_inputs(torch_device) - image = sd_pipe(**inputs).images - image_slice = image[0, -3:, -3:, -1].flatten() - - assert image.shape == (1, 512, 512, 3) - expected_slice = np.array([0.38019, 0.28647, 0.27321, 0.40377, 0.38290, 0.35446, 0.39218, 0.38165, 0.42239]) - - - - @slow @require_torch_gpu From 56908f4664607dcf99d3b48d44f228504350596c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 1 May 2023 14:07:59 +0000 Subject: [PATCH 11/11] fix test --- tests/pipelines/stable_diffusion/test_stable_diffusion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index a44abeb862ce..e1334e1ddd3b 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -923,7 +923,7 @@ def test_download_ckpt_diff_format_is_same(self): assert np.max(np.abs(image - image_ckpt)) < 1e-4 def test_stable_diffusion_compile(self): - if version.parse(torch.__version__) >= version.parse('2.0'): + if version.parse(torch.__version__) >= version.parse("2.0"): print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0") return @@ -932,7 +932,7 @@ def test_stable_diffusion_compile(self): sd_pipe = sd_pipe.to(torch_device) sd_pipe.unet.to(memory_format=torch.channels_last) - sd_pipe.unet = torch.compile(stage_1.unet, mode="reduce-overhead", fullgraph=True) + sd_pipe.unet = torch.compile(sd_pipe.unet, mode="reduce-overhead", fullgraph=True) sd_pipe.set_progress_bar_config(disable=None)