@@ -138,14 +138,14 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
138
138
self .norm2 = FP32LayerNorm (dim , elementwise_affine = False , bias = False )
139
139
self .ff = AuraFlowFeedForward (dim , dim * 4 )
140
140
141
- def forward (self , hidden_states : torch .FloatTensor , temb : torch .FloatTensor , i = 9999 ):
141
+ def forward (self , hidden_states : torch .FloatTensor , temb : torch .FloatTensor ):
142
142
residual = hidden_states
143
143
144
144
# Norm + Projection.
145
145
norm_hidden_states , gate_msa , shift_mlp , scale_mlp , gate_mlp = self .norm1 (hidden_states , emb = temb )
146
146
147
147
# Attention.
148
- attn_output = self .attn (hidden_states = norm_hidden_states , i = i )
148
+ attn_output = self .attn (hidden_states = norm_hidden_states )
149
149
150
150
# Process attention outputs for the `hidden_states`.
151
151
hidden_states = self .norm2 (residual + gate_msa .unsqueeze (1 ) * attn_output )
@@ -201,7 +201,7 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
201
201
self .ff_context = AuraFlowFeedForward (dim , dim * 4 )
202
202
203
203
def forward (
204
- self , hidden_states : torch .FloatTensor , encoder_hidden_states : torch .FloatTensor , temb : torch .FloatTensor , i = 0
204
+ self , hidden_states : torch .FloatTensor , encoder_hidden_states : torch .FloatTensor , temb : torch .FloatTensor
205
205
):
206
206
residual = hidden_states
207
207
residual_context = encoder_hidden_states
@@ -214,7 +214,7 @@ def forward(
214
214
215
215
# Attention.
216
216
attn_output , context_attn_output = self .attn (
217
- hidden_states = norm_hidden_states , encoder_hidden_states = norm_encoder_hidden_states , i = i
217
+ hidden_states = norm_hidden_states , encoder_hidden_states = norm_encoder_hidden_states
218
218
)
219
219
220
220
# Process attention outputs for the `hidden_states`.
@@ -366,7 +366,7 @@ def custom_forward(*inputs):
366
366
367
367
else :
368
368
encoder_hidden_states , hidden_states = block (
369
- hidden_states = hidden_states , encoder_hidden_states = encoder_hidden_states , temb = temb , i = index_block
369
+ hidden_states = hidden_states , encoder_hidden_states = encoder_hidden_states , temb = temb
370
370
)
371
371
372
372
# Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
0 commit comments