Commit 0e82fb1

Torch compile graph fix (#3286)
* fix more
* Fix more
* fix more
* Apply suggestions from code review
* fix
* make style
* make fix-copies
* fix
* make sure torch compile
* Clean
* fix test
1 parent 709cf55 commit 0e82fb1

36 files changed: +109 −60 lines changed
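Taken together, the changes below remove TorchDynamo graph breaks in the diffusers models and pipelines: attention classes are marked with @maybe_allow_in_graph, dtype lookups avoid iterator tricks, tuple accumulation is rewritten as explicit re-binding, and model/scheduler calls pass return_dict=False so outputs are plain tuples. A minimal usage sketch of what this enables (assuming PyTorch 2.0+ and the runwayml/stable-diffusion-v1-5 checkpoint; not part of the commit):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# With the graph breaks fixed, the UNet can be compiled end to end.
pipe.unet = torch.compile(pipe.unet)
image = pipe("an astronaut riding a horse").images[0]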

src/diffusers/models/attention.py

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,7 @@
 import torch.nn.functional as F
 from torch import nn

+from ..utils import maybe_allow_in_graph
 from ..utils.import_utils import is_xformers_available
 from .attention_processor import Attention
 from .embeddings import CombinedTimestepLabelEmbeddings
@@ -193,6 +194,7 @@ def forward(self, hidden_states):
         return hidden_states


+@maybe_allow_in_graph
 class BasicTransformerBlock(nn.Module):
     r"""
     A basic Transformer block.
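@maybe_allow_in_graph tells TorchDynamo to keep calls to the decorated class inside the captured graph instead of breaking the graph at each call. A rough sketch of how such a decorator can be implemented (the real helper lives in src/diffusers/utils and may differ):

import torch

def maybe_allow_in_graph(cls):
    # Only PyTorch 2.x ships TorchDynamo; otherwise return the class unchanged.
    if hasattr(torch, "_dynamo"):
        return torch._dynamo.allow_in_graph(cls)
    return cls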

src/diffusers/models/attention_processor.py

Lines changed: 2 additions & 1 deletion
@@ -17,7 +17,7 @@
 import torch.nn.functional as F
 from torch import nn

-from ..utils import deprecate, logging
+from ..utils import deprecate, logging, maybe_allow_in_graph
 from ..utils.import_utils import is_xformers_available


@@ -31,6 +31,7 @@
     xformers = None


+@maybe_allow_in_graph
 class Attention(nn.Module):
     r"""
     A cross attention layer.

src/diffusers/models/modeling_utils.py

Lines changed: 8 additions & 2 deletions
@@ -77,8 +77,14 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:

 def get_parameter_dtype(parameter: torch.nn.Module):
     try:
-        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
-        return next(parameters_and_buffers).dtype
+        params = tuple(parameter.parameters())
+        if len(params) > 0:
+            return params[0].dtype
+
+        buffers = tuple(parameter.buffers())
+        if len(buffers) > 0:
+            return buffers[0].dtype
+
     except StopIteration:
         # For torch.nn.DataParallel compatibility in PyTorch 1.5
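The old implementation chained two lazy iterators and consumed the first element with next(), which both hides the parameter-vs-buffer precedence and leans on iterator state that TorchDynamo traces poorly. Materializing small tuples and indexing is equivalent but trace-friendly. A quick check of the equivalence (illustrative, not from the commit):

import itertools
import torch

m = torch.nn.Linear(2, 2)

# Old approach: lazily chained iterators consumed with next().
dtype_old = next(itertools.chain(m.parameters(), m.buffers())).dtype

# New approach: materialize and index; no iterator protocol at trace time.
params = tuple(m.parameters())
dtype_new = params[0].dtype if params else tuple(m.buffers())[0].dtype

assert dtype_old == dtype_new == torch.float32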

src/diffusers/models/unet_2d_blocks.py

Lines changed: 14 additions & 11 deletions
@@ -560,7 +560,8 @@ def forward(
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
-            ).sample
+                return_dict=False,
+            )[0]
             hidden_states = resnet(hidden_states, temb)

         return hidden_states
@@ -868,15 +869,16 @@ def custom_forward(*inputs):
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]

-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)

-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)

         return hidden_states, output_states
@@ -949,13 +951,13 @@ def custom_forward(*inputs):
             else:
                 hidden_states = resnet(hidden_states, temb)

-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)

-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)

         return hidden_states, output_states
@@ -1342,13 +1344,13 @@ def custom_forward(*inputs):
             else:
                 hidden_states = resnet(hidden_states, temb)

-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states, temb)

-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)

         return hidden_states, output_states
@@ -1466,13 +1468,13 @@ def forward(
                 **cross_attention_kwargs,
             )

-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states, temb)

-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)

         return hidden_states, output_states
@@ -1859,7 +1861,8 @@ def custom_forward(*inputs):
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
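Two patterns repeat throughout this file: attention calls switch to return_dict=False with [0] indexing so the traced code never touches an output dataclass, and output_states += (hidden_states,) becomes an explicit output_states = output_states + (hidden_states,). For tuples the two spellings are semantically identical (both build a new tuple and re-bind the name); the explicit form was presumably chosen because TorchDynamo handled it more reliably at the time. A toy compiled loop using the same accumulation pattern (assumes PyTorch 2.0+):

import torch

@torch.compile
def accumulate(x: torch.Tensor):
    output_states = ()
    for _ in range(3):
        x = x + 1.0
        # explicit re-binding instead of `output_states += (x,)`
        output_states = output_states + (x,)
    return x, output_states

hidden, states = accumulate(torch.zeros(2))
print(len(states))  # 3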

src/diffusers/models/unet_2d_condition.py

Lines changed: 2 additions & 2 deletions
@@ -682,7 +682,7 @@ def forward(
         # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
-        t_emb = t_emb.to(dtype=self.dtype)
+        t_emb = t_emb.to(dtype=sample.dtype)

         emb = self.time_embedding(t_emb, timestep_cond)

@@ -697,7 +697,7 @@ def forward(
                 # there might be better ways to encapsulate this.
                 class_labels = class_labels.to(dtype=sample.dtype)

-            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

             if self.config.class_embeddings_concat:
                 emb = torch.cat([emb, class_emb], dim=-1)
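self.dtype on a diffusers model walks the module's parameters via get_parameter_dtype, which is exactly the kind of Python-level introspection that forces a graph break inside a compiled forward; reading the dtype off the sample tensor that is already being traced avoids it. A minimal sketch of the pattern (illustrative module, not the UNet):

import torch

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    def forward(self, sample):
        # `sample.dtype` is a plain attribute of the traced tensor, whereas a
        # property that iterates self.parameters() must be traced through.
        t_emb = torch.zeros(4, dtype=sample.dtype, device=sample.device)
        return self.proj(sample + t_emb)

compiled = torch.compile(Block())
print(compiled(torch.randn(2, 4)).shape)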

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 4 additions & 3 deletions
@@ -437,7 +437,7 @@ def run_safety_checker(self, image, device, dtype):

     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -683,15 +683,16 @@ def __call__(
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]

                 # perform guidance
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
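The return_dict=False convention turns diffusers output objects (the values previously accessed as .sample or .prev_sample) into plain tuples, so [0] replaces attribute access and no wrapper object is built inside the denoising loop. A self-contained toy illustrating the two call styles (stand-in model, not the diffusers API):

from dataclasses import dataclass
import torch

@dataclass
class ModelOutput:
    sample: torch.Tensor

class TinyModel(torch.nn.Module):
    def forward(self, x, return_dict: bool = True):
        y = x * 2.0
        if not return_dict:
            return (y,)  # plain tuple: nothing extra to trace
        return ModelOutput(sample=y)  # wrapper object built on every call

model = TinyModel()
x = torch.randn(2)
assert torch.equal(model(x).sample, model(x, return_dict=False)[0])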

src/diffusers/pipelines/deepfloyd_if/pipeline_if.py

Lines changed: 5 additions & 4 deletions
@@ -793,7 +793,8 @@ def __call__(
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]

                 # perform guidance
                 if do_classifier_free_guidance:
@@ -805,8 +806,8 @@ def __call__(

                 # compute the previous noisy sample x_t -> x_t-1
                 intermediate_images = self.scheduler.step(
-                    noise_pred, t, intermediate_images, **extra_step_kwargs
-                ).prev_sample
+                    noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+                )[0]

                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -829,7 +830,7 @@ def __call__(

             # 11. Apply watermark
             if self.watermarker is not None:
-                self.watermarker.apply_watermark(image, self.unet.config.sample_size)
+                image = self.watermarker.apply_watermark(image, self.unet.config.sample_size)
             elif output_type == "pt":
                 nsfw_detected = None
                 watermark_detected = None
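The watermark hunk is a plain bug fix rather than a compile fix: apply_watermark returns the watermarked images instead of modifying them in place, so discarding the return value silently skipped watermarking. A toy reproduction with a hypothetical stand-in for apply_watermark:

import torch

def apply_watermark(image: torch.Tensor, sample_size: int) -> torch.Tensor:
    # Stand-in: returns a new tensor, leaves the input untouched.
    return image + 0.01 * sample_size

image = torch.zeros(1, 3, 8, 8)
apply_watermark(image, 64)          # bug: result discarded, image unchanged
image = apply_watermark(image, 64)  # fix: re-assign the returned batch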

src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py

Lines changed: 1 addition & 1 deletion
@@ -256,7 +256,7 @@ def prepare_extra_step_kwargs(self, generator, eta):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()

src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ def __init__(
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()

src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py

Lines changed: 1 addition & 1 deletion
@@ -516,7 +516,7 @@ def run_safety_checker(self, image, device, dtype):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
