 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import AttentionMixin
-from ..attention_processor import Attention, AttentionProcessor
+from ..attention_processor import Attention
 from ..embeddings import get_timestep_embedding
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -78,6 +78,7 @@ def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor:
     xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
     return xq_out.reshape(*xq.shape).type_as(xq)

+
 class PhotonAttnProcessor2_0:
     r"""
     Processor for implementing Photon-style attention with multi-source tokens and RoPE. Properly integrates with
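
For reference, a minimal sketch of the rotation performed by the `apply_rope` lines in the hunk above. The pairing reshape and the `freqs_cis` layout (stacked 2x2 rotation matrices, as in the Flux-style RoPE this file references) are assumptions; only the two lines shown in the hunk are taken verbatim from the source.

import torch

def rope_rotation_sketch(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Assumed shapes: xq is (..., dim); freqs_cis is (..., dim // 2, 2, 2), where each
    # trailing 2x2 block is [[cos, -sin], [sin, cos]] for one rotary frequency.
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)  # group channels into (even, odd) pairs
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq)  # restore original shape and dtype
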
@@ -133,6 +134,8 @@ def __call__(
         attn_output = attn.to_out[1](attn_output)  # dropout if present

         return attn_output
+
+
 # copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
 class EmbedND(nn.Module):
     r"""
@@ -299,9 +302,8 @@ class PhotonBlock(nn.Module):
             Produces scale/shift/gating parameters for modulated layers.

     Methods:
-        attn_forward(img, txt, pe, modulation, spatial_conditioning=None, attention_mask=None):
-            Compute cross-attention between image and text tokens, with optional spatial conditioning and attention
-            masking.
+        attn_forward(img, txt, pe, modulation, attention_mask=None):
+            Compute cross-attention between image and text tokens, with optional attention masking.

     Parameters:
         img (`torch.Tensor`):
@@ -312,8 +314,6 @@ class PhotonBlock(nn.Module):
             Rotary positional embeddings to apply to queries and keys.
         modulation (`ModulationOut`):
             Scale and shift parameters for modulating image tokens.
-        spatial_conditioning (`torch.Tensor`, *optional*):
-            Extra conditioning tokens of shape `(B, L_cond, hidden_size)`.
         attention_mask (`torch.Tensor`, *optional*):
             Boolean mask of shape `(B, L_txt)` where 0 marks padding.

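
As an aside, a minimal sketch of how a boolean text mask of shape `(B, L_txt)` (0 marking padding) is commonly expanded for scaled dot-product attention when image tokens attend to text tokens. The exact masking inside PhotonBlock may differ; the tensor names are illustrative only.

import torch
import torch.nn.functional as F

def masked_cross_attention_sketch(q_img, k_txt, v_txt, text_mask):
    # q_img: (B, H, L_img, D); k_txt / v_txt: (B, H, L_txt, D); text_mask: (B, L_txt) bool
    attn_mask = text_mask[:, None, None, :]  # broadcast to (B, 1, 1, L_txt) over heads and queries
    return F.scaled_dot_product_attention(q_img, k_txt, v_txt, attn_mask=attn_mask)
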
@@ -372,7 +372,6 @@ def _attn_forward(
         txt: Tensor,
         pe: Tensor,
         modulation: ModulationOut,
-        spatial_conditioning: None | Tensor = None,
         attention_mask: None | Tensor = None,
     ) -> Tensor:
         # image tokens proj and norm
@@ -444,7 +443,6 @@ def forward(
         txt: Tensor,
         vec: Tensor,
         pe: Tensor,
-        spatial_conditioning: Tensor | None = None,
         attention_mask: Tensor | None = None,
         **_: dict[str, Any],
     ) -> Tensor:
@@ -461,9 +459,6 @@ def forward(
                 broadcastable).
             pe (`torch.Tensor`):
                 Rotary positional embeddings applied inside attention.
-            spatial_conditioning (`torch.Tensor`, *optional*):
-                Extra conditioning tokens of shape `(B, L_cond, hidden_size)`. Used only if spatial conditioning is
-                enabled in the block.
             attention_mask (`torch.Tensor`, *optional*):
                 Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding.
             **_:
@@ -481,7 +476,6 @@ def forward(
             txt,
             pe,
             mod_attn,
-            spatial_conditioning=spatial_conditioning,
             attention_mask=attention_mask,
         )
         img = img + mod_mlp.gate * self._ffn_forward(img, mod_mlp)
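
A minimal sketch of the gated-residual pattern used in the two lines above, where a `ModulationOut`-style triple (shift, scale, gate) derived from the conditioning vector modulates the normalized tokens and gates each sublayer's contribution. The field names and the placement of the normalization are assumptions, not taken from this file.

import torch

def modulated_residual_sketch(x, sublayer, norm, shift, scale, gate):
    # x: (B, L, D); shift / scale / gate broadcast over the sequence, e.g. (B, 1, D)
    h = norm(x) * (1 + scale) + shift  # scale/shift modulation of the normalized tokens
    return x + gate * sublayer(h)      # gate controls how strongly the sublayer output is added
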
@@ -698,14 +692,6 @@ def __init__(

         self.gradient_checkpointing = False

-    def _process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]:
-        txt = self.txt_in(txt)
-        img = img2seq(image_latent, self.patch_size)
-        bs, _, h, w = image_latent.shape
-        img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device)
-        pe = self.pe_embedder(img_ids)
-        return img, txt, pe
-
     def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor:
         return self.time_in(
             get_timestep_embedding(
@@ -717,43 +703,6 @@ def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> T
             ).to(dtype)
         )

-    def _forward_transformers(
-        self,
-        image_latent: Tensor,
-        cross_attn_conditioning: Tensor,
-        timestep: Optional[Tensor] = None,
-        time_embedding: Optional[Tensor] = None,
-        attention_mask: Optional[Tensor] = None,
-        **block_kwargs: Any,
-    ) -> Tensor:
-        img = self.img_in(image_latent)
-
-        if time_embedding is not None:
-            vec = time_embedding
-        else:
-            if timestep is None:
-                raise ValueError("Please provide either a timestep or a timestep_embedding")
-            vec = self._compute_timestep_embedding(timestep, dtype=img.dtype)
-
-        for block in self.blocks:
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                img = self._gradient_checkpointing_func(
-                    block.__call__,
-                    img,
-                    cross_attn_conditioning,
-                    vec,
-                    block_kwargs.get("pe"),
-                    block_kwargs.get("spatial_conditioning"),
-                    attention_mask,
-                )
-            else:
-                img = block(
-                    img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs
-                )
-
-        img = self.final_layer(img, vec)
-        return img
-
     def forward(
         self,
         image_latent: Tensor,
@@ -797,6 +746,7 @@ def forward(
             lora_scale = attention_kwargs.pop("scale", 1.0)
         else:
             lora_scale = 1.0
+
         if USE_PEFT_BACKEND:
             # weight the lora layers by setting `lora_scale` for each PEFT layer
             scale_lora_layers(self, lora_scale)
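
A minimal sketch of the scale/unscale bracket this forward pass wraps around the transformer when the PEFT backend is active. The caller-side `attention_kwargs={"scale": ...}` convention and the helper name below are assumptions for illustration; `scale_lora_layers` and `unscale_lora_layers` are the diffusers utilities imported at the top of the file.

from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers

def lora_scaled_call_sketch(model, lora_scale, run_model):
    # Temporarily weight every LoRA layer by `lora_scale`, run the model, then restore the weights.
    if USE_PEFT_BACKEND:
        scale_lora_layers(model, lora_scale)
    try:
        return run_model()
    finally:
        if USE_PEFT_BACKEND:
            unscale_lora_layers(model, lora_scale)
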
@@ -805,12 +755,50 @@ def forward(
                 logger.warning(
                     "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                 )
-        img_seq, txt, pe = self._process_inputs(image_latent, cross_attn_conditioning)
-        img_seq = self._forward_transformers(img_seq, txt, timestep, pe=pe, attention_mask=cross_attn_mask)
-        output = seq2img(img_seq, self.patch_size, image_latent.shape)
+
+        # Process text conditioning
+        txt = self.txt_in(cross_attn_conditioning)
+
+        # Convert image to sequence and embed
+        img = img2seq(image_latent, self.patch_size)
+        img = self.img_in(img)
+
+        # Generate positional embeddings
+        bs, _, h, w = image_latent.shape
+        img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device)
+        pe = self.pe_embedder(img_ids)
+
+        # Compute time embedding
+        vec = self._compute_timestep_embedding(timestep, dtype=img.dtype)
+
+        # Apply transformer blocks
+        for block in self.blocks:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                img = self._gradient_checkpointing_func(
+                    block.__call__,
+                    img,
+                    txt,
+                    vec,
+                    pe,
+                    cross_attn_mask,
+                )
+            else:
+                img = block(
+                    img=img,
+                    txt=txt,
+                    vec=vec,
+                    pe=pe,
+                    attention_mask=cross_attn_mask,
+                )
+
+        # Final layer and convert back to image
+        img = self.final_layer(img, vec)
+        output = seq2img(img, self.patch_size, image_latent.shape)
+
         if USE_PEFT_BACKEND:
             # remove `lora_scale` from each PEFT layer
             unscale_lora_layers(self, lora_scale)
+
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
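
For readers unfamiliar with the checkpointing branch above, a minimal sketch of what it effectively does per block. `_gradient_checkpointing_func` is wired up by the diffusers ModelMixin machinery; calling torch.utils.checkpoint directly and the `use_reentrant=False` flag are assumptions made only to show the recompute-on-backward behavior explicitly.

import torch
from torch.utils.checkpoint import checkpoint

def run_blocks_sketch(blocks, img, txt, vec, pe, cross_attn_mask, gradient_checkpointing):
    for block in blocks:
        if torch.is_grad_enabled() and gradient_checkpointing:
            # Drop intermediate activations now and recompute them during backward to save memory.
            img = checkpoint(block.__call__, img, txt, vec, pe, cross_attn_mask, use_reentrant=False)
        else:
            img = block(img=img, txt=txt, vec=vec, pe=pe, attention_mask=cross_attn_mask)
    return img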