
Commit 5f6359f

Author: davidb
unify the structure of the forward block
1 parent c86aed2 commit 5f6359f

File tree: 3 files changed (+82 additions, -101 deletions)


src/diffusers/models/attention_processor.py

Lines changed: 0 additions & 1 deletion
@@ -5605,7 +5605,6 @@ def __new__(cls, *args, **kwargs):
         return processor


-
 ADDED_KV_ATTENTION_PROCESSORS = (
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,

src/diffusers/models/transformers/transformer_photon.py

Lines changed: 48 additions & 60 deletions
@@ -22,7 +22,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import AttentionMixin
-from ..attention_processor import Attention, AttentionProcessor
+from ..attention_processor import Attention
 from ..embeddings import get_timestep_embedding
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -78,6 +78,7 @@ def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor:
     xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
     return xq_out.reshape(*xq.shape).type_as(xq)

+
 class PhotonAttnProcessor2_0:
     r"""
     Processor for implementing Photon-style attention with multi-source tokens and RoPE. Properly integrates with
@@ -133,6 +134,8 @@ def __call__(
         attn_output = attn.to_out[1](attn_output)  # dropout if present

         return attn_output
+
+
 # copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
 class EmbedND(nn.Module):
     r"""
@@ -299,9 +302,8 @@ class PhotonBlock(nn.Module):
             Produces scale/shift/gating parameters for modulated layers.

     Methods:
-        attn_forward(img, txt, pe, modulation, spatial_conditioning=None, attention_mask=None):
-            Compute cross-attention between image and text tokens, with optional spatial conditioning and attention
-            masking.
+        attn_forward(img, txt, pe, modulation, attention_mask=None):
+            Compute cross-attention between image and text tokens, with optional attention masking.

     Parameters:
         img (`torch.Tensor`):
@@ -312,8 +314,6 @@ class PhotonBlock(nn.Module):
             Rotary positional embeddings to apply to queries and keys.
         modulation (`ModulationOut`):
             Scale and shift parameters for modulating image tokens.
-        spatial_conditioning (`torch.Tensor`, *optional*):
-            Extra conditioning tokens of shape `(B, L_cond, hidden_size)`.
         attention_mask (`torch.Tensor`, *optional*):
             Boolean mask of shape `(B, L_txt)` where 0 marks padding.

@@ -372,7 +372,6 @@ def _attn_forward(
         txt: Tensor,
         pe: Tensor,
         modulation: ModulationOut,
-        spatial_conditioning: None | Tensor = None,
         attention_mask: None | Tensor = None,
     ) -> Tensor:
         # image tokens proj and norm
@@ -444,7 +443,6 @@ def forward(
         txt: Tensor,
         vec: Tensor,
         pe: Tensor,
-        spatial_conditioning: Tensor | None = None,
         attention_mask: Tensor | None = None,
         **_: dict[str, Any],
     ) -> Tensor:
@@ -461,9 +459,6 @@ def forward(
                 broadcastable).
             pe (`torch.Tensor`):
                 Rotary positional embeddings applied inside attention.
-            spatial_conditioning (`torch.Tensor`, *optional*):
-                Extra conditioning tokens of shape `(B, L_cond, hidden_size)`. Used only if spatial conditioning is
-                enabled in the block.
             attention_mask (`torch.Tensor`, *optional*):
                 Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding.
             **_:
@@ -481,7 +476,6 @@ def forward(
             txt,
             pe,
             mod_attn,
-            spatial_conditioning=spatial_conditioning,
             attention_mask=attention_mask,
         )
         img = img + mod_mlp.gate * self._ffn_forward(img, mod_mlp)
@@ -698,14 +692,6 @@ def __init__(

         self.gradient_checkpointing = False

-    def _process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]:
-        txt = self.txt_in(txt)
-        img = img2seq(image_latent, self.patch_size)
-        bs, _, h, w = image_latent.shape
-        img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device)
-        pe = self.pe_embedder(img_ids)
-        return img, txt, pe
-
     def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor:
         return self.time_in(
             get_timestep_embedding(
@@ -717,43 +703,6 @@ def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor:
             ).to(dtype)
         )

-    def _forward_transformers(
-        self,
-        image_latent: Tensor,
-        cross_attn_conditioning: Tensor,
-        timestep: Optional[Tensor] = None,
-        time_embedding: Optional[Tensor] = None,
-        attention_mask: Optional[Tensor] = None,
-        **block_kwargs: Any,
-    ) -> Tensor:
-        img = self.img_in(image_latent)
-
-        if time_embedding is not None:
-            vec = time_embedding
-        else:
-            if timestep is None:
-                raise ValueError("Please provide either a timestep or a timestep_embedding")
-            vec = self._compute_timestep_embedding(timestep, dtype=img.dtype)
-
-        for block in self.blocks:
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                img = self._gradient_checkpointing_func(
-                    block.__call__,
-                    img,
-                    cross_attn_conditioning,
-                    vec,
-                    block_kwargs.get("pe"),
-                    block_kwargs.get("spatial_conditioning"),
-                    attention_mask,
-                )
-            else:
-                img = block(
-                    img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs
-                )
-
-        img = self.final_layer(img, vec)
-        return img
-
     def forward(
         self,
         image_latent: Tensor,
@@ -797,6 +746,7 @@ def forward(
             lora_scale = attention_kwargs.pop("scale", 1.0)
         else:
             lora_scale = 1.0
+
         if USE_PEFT_BACKEND:
             # weight the lora layers by setting `lora_scale` for each PEFT layer
             scale_lora_layers(self, lora_scale)
@@ -805,12 +755,50 @@ def forward(
             logger.warning(
                 "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
             )
-        img_seq, txt, pe = self._process_inputs(image_latent, cross_attn_conditioning)
-        img_seq = self._forward_transformers(img_seq, txt, timestep, pe=pe, attention_mask=cross_attn_mask)
-        output = seq2img(img_seq, self.patch_size, image_latent.shape)
+
+        # Process text conditioning
+        txt = self.txt_in(cross_attn_conditioning)
+
+        # Convert image to sequence and embed
+        img = img2seq(image_latent, self.patch_size)
+        img = self.img_in(img)
+
+        # Generate positional embeddings
+        bs, _, h, w = image_latent.shape
+        img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device)
+        pe = self.pe_embedder(img_ids)
+
+        # Compute time embedding
+        vec = self._compute_timestep_embedding(timestep, dtype=img.dtype)
+
+        # Apply transformer blocks
+        for block in self.blocks:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                img = self._gradient_checkpointing_func(
+                    block.__call__,
+                    img,
+                    txt,
+                    vec,
+                    pe,
+                    cross_attn_mask,
+                )
+            else:
+                img = block(
+                    img=img,
+                    txt=txt,
+                    vec=vec,
+                    pe=pe,
+                    attention_mask=cross_attn_mask,
+                )
+
+        # Final layer and convert back to image
+        img = self.final_layer(img, vec)
+        output = seq2img(img, self.patch_size, image_latent.shape)
+
         if USE_PEFT_BACKEND:
             # remove `lora_scale` from each PEFT layer
             unscale_lora_layers(self, lora_scale)
+
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
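
With `_process_inputs` and `_forward_transformers` folded into a single `forward`, callers now go through one public entry point. Below is a minimal sketch of that unified call, assuming `transformer` is an already-loaded `PhotonTransformer2DModel`; the latent/text shapes and the normalized timestep value are illustrative assumptions, not values taken from this commit.

```python
import torch

# Placeholder shapes -- the real channel counts and embedding width come from the
# model config, which this diff does not show.
batch, latent_channels, height, width = 1, 16, 32, 32
text_len, text_dim = 64, 1024

image_latent = torch.randn(batch, latent_channels, height, width)
cross_attn_conditioning = torch.randn(batch, text_len, text_dim)
cross_attn_mask = torch.ones(batch, text_len, dtype=torch.bool)
timestep = torch.tensor([0.5])  # the pipeline normalizes t by num_train_timesteps

# `transformer` is assumed to be a PhotonTransformer2DModel instance loaded elsewhere.
noise_pred = transformer(
    image_latent=image_latent,
    timestep=timestep,
    cross_attn_conditioning=cross_attn_conditioning,
    micro_conditioning=None,
    cross_attn_mask=cross_attn_mask,
    return_dict=False,
)[0]
```

The internal steps (text projection, `img2seq` patching, RoPE ids, timestep embedding, the block loop, `final_layer`, `seq2img`) now all live inside this one call, so gradient checkpointing and LoRA scaling behave the same whether the model is driven by the pipeline or called directly.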

src/diffusers/pipelines/photon/pipeline_photon.py

Lines changed: 34 additions & 40 deletions
@@ -16,22 +16,21 @@
 import inspect
 import re
 import urllib.parse as ul
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union

 import ftfy
 import torch
 from transformers import (
     AutoTokenizer,
     GemmaTokenizerFast,
-    T5EncoderModel,
     T5TokenizerFast,
 )
 from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder

 from diffusers.image_processor import PixArtImageProcessor
 from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderDC, AutoencoderKL
-from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel, seq2img
+from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
 from diffusers.pipelines.photon.pipeline_output import PhotonPipelineOutput
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
@@ -45,29 +44,29 @@
 DEFAULT_RESOLUTION = 512

 ASPECT_RATIO_256_BIN = {
-    "0.46": [160, 352],
-    "0.6": [192, 320],
-    "0.78": [224, 288],
-    "1.0": [256, 256],
-    "1.29": [288, 224],
-    "1.67": [320, 192],
-    "2.2": [352, 160],
+    "0.46": [160, 352],
+    "0.6": [192, 320],
+    "0.78": [224, 288],
+    "1.0": [256, 256],
+    "1.29": [288, 224],
+    "1.67": [320, 192],
+    "2.2": [352, 160],
 }

 ASPECT_RATIO_512_BIN = {
-    "0.5": [352, 704],
-    "0.57": [384, 672],
-    "0.6": [384, 640],
-    "0.68": [416, 608],
-    "0.78": [448, 576],
-    "0.88": [480, 544],
-    "1.0": [512, 512],
-    "1.13": [544, 480],
-    "1.29": [576, 448],
-    "1.46": [608, 416],
-    "1.67": [640, 384],
-    "1.75": [672, 384],
-    "2.0": [704, 352],
+    "0.5": [352, 704],
+    "0.57": [384, 672],
+    "0.6": [384, 640],
+    "0.68": [416, 608],
+    "0.78": [448, 576],
+    "0.88": [480, 544],
+    "1.0": [512, 512],
+    "1.13": [544, 480],
+    "1.29": [576, 448],
+    "1.46": [608, 416],
+    "1.67": [640, 384],
+    "1.75": [672, 384],
+    "2.0": [704, 352],
 }

 logger = logging.get_logger(__name__)
@@ -283,7 +282,7 @@ def __init__(
     def vae_spatial_compression_ratio(self):
         if hasattr(self.vae, "spatial_compression_ratio"):
             return self.vae.spatial_compression_ratio
-        else: # Flux VAE
+        else:  # Flux VAE
             return 2 ** (len(self.vae.config.block_out_channels) - 1)

     @property
@@ -461,8 +460,8 @@ def __call__(
                 Whether or not to return a [`~pipelines.photon.PhotonPipelineOutput`] instead of a plain tuple.
             use_resolution_binning (`bool`, *optional*, defaults to `True`):
                 If set to `True`, the requested height and width are first mapped to the closest resolutions using
-                predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back to
-                the requested resolution. Useful for generating non-square images at optimal resolutions.
+                predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back
+                to the requested resolution. Useful for generating non-square images at optimal resolutions.
             callback_on_step_end (`Callable`, *optional*):
                 A function that calls at the end of each denoising steps during the inference. The function is called
                 with the following arguments: `callback_on_step_end(self, step, timestep, callback_kwargs)`.
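
For context on `use_resolution_binning`: the requested size is snapped to the bin whose aspect ratio is closest to `height / width` (the bins store `[height, width]`), and the decoded image is resized back afterwards. The sketch below illustrates that lookup against the `ASPECT_RATIO_512_BIN` table from this file; it is a simplified stand-in, not necessarily the exact helper the pipeline uses (the pipeline imports `PixArtImageProcessor` for its image handling).

```python
# Simplified sketch of aspect-ratio binning: pick the closest ratio key.
ASPECT_RATIO_512_BIN = {
    "0.5": [352, 704], "0.57": [384, 672], "0.6": [384, 640], "0.68": [416, 608],
    "0.78": [448, 576], "0.88": [480, 544], "1.0": [512, 512], "1.13": [544, 480],
    "1.29": [576, 448], "1.46": [608, 416], "1.67": [640, 384], "1.75": [672, 384],
    "2.0": [704, 352],
}

def closest_bin(height: int, width: int, bins: dict[str, list[int]]) -> tuple[int, int]:
    ratio = height / width
    key = min(bins, key=lambda k: abs(float(k) - ratio))
    binned_height, binned_width = bins[key]
    return binned_height, binned_width

# ratio 1.5 -> closest key "1.46" -> generate at 608x416, then resize back to 600x400
print(closest_bin(600, 400, ASPECT_RATIO_512_BIN))
```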
@@ -572,20 +571,15 @@ def __call__(
                 # Normalize timestep for the transformer
                 t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device)

-                # Process inputs for transformer
-                img_seq, txt, pe = self.transformer._process_inputs(latents_in, ca_embed)
-
-                # Forward through transformer layers
-                img_seq = self.transformer._forward_transformers(
-                    img_seq,
-                    txt,
-                    time_embedding=self.transformer._compute_timestep_embedding(t_cont, img_seq.dtype),
-                    pe=pe,
-                    attention_mask=ca_mask,
-                )
-
-                # Convert back to image format
-                noise_pred = seq2img(img_seq, self.transformer.patch_size, latents_in.shape)
+                # Forward through transformer
+                noise_pred = self.transformer(
+                    image_latent=latents_in,
+                    timestep=t_cont,
+                    cross_attn_conditioning=ca_embed,
+                    micro_conditioning=None,
+                    cross_attn_mask=ca_mask,
+                    return_dict=False,
+                )[0]

                 # Apply CFG
                 if self.do_classifier_free_guidance:
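
The hunk ends just before the classifier-free guidance step, which this commit leaves untouched and which is not shown here. For reference, a typical CFG combination looks like the sketch below; the names (`noise_pred_text`, `noise_pred_uncond`, `guidance_scale`) are assumptions for illustration, not identifiers from this file.

```python
import torch

def apply_cfg(noise_pred_text: torch.Tensor,
              noise_pred_uncond: torch.Tensor,
              guidance_scale: float) -> torch.Tensor:
    # Standard classifier-free guidance: move the unconditional prediction toward
    # the conditional one, scaled by `guidance_scale`.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# With guidance_scale = 1.0 the result equals the conditional prediction
# (up to floating-point rounding).
cond = torch.randn(1, 16, 32, 32)
uncond = torch.randn(1, 16, 32, 32)
assert torch.allclose(apply_cfg(cond, uncond, 1.0), cond, atol=1e-6)
```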

0 commit comments

Comments
 (0)