Commit 50e66f2

[Chore] remove all `i` arguments from auraflow. (#8980)
remove all `i` arguments from auraflow.
1 parent 9b8c860 commit 50e66f2

File tree

1 file changed (+5, -5 lines)
src/diffusers/models/transformers/auraflow_transformer_2d.py

Lines changed: 5 additions & 5 deletions
@@ -138,14 +138,14 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
         self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
         self.ff = AuraFlowFeedForward(dim, dim * 4)
 
-    def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=9999):
+    def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor):
         residual = hidden_states
 
         # Norm + Projection.
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         # Attention.
-        attn_output = self.attn(hidden_states=norm_hidden_states, i=i)
+        attn_output = self.attn(hidden_states=norm_hidden_states)
 
         # Process attention outputs for the `hidden_states`.
         hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
@@ -201,7 +201,7 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
         self.ff_context = AuraFlowFeedForward(dim, dim * 4)
 
     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=0
+        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
     ):
         residual = hidden_states
         residual_context = encoder_hidden_states
@@ -214,7 +214,7 @@ def forward(
 
         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, i=i
+            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
         )
 
         # Process attention outputs for the `hidden_states`.
@@ -366,7 +366,7 @@ def custom_forward(*inputs):
 
             else:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, i=index_block
+                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
                 )
 
        # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
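
After this change, the block forward methods accept only the tensor arguments shown in the diff. A minimal sketch of calling the updated AuraFlowSingleTransformerBlock without the removed `i` argument follows; the toy sizes and tensor shapes are illustrative assumptions, not taken from the commit.

# Minimal sketch (assumed toy sizes and random tensors) of the updated call
# signature: the stray `i` argument is no longer accepted.
import torch
from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowSingleTransformerBlock

dim, num_heads, head_dim = 256, 8, 32  # assumed sizes, dim == num_heads * head_dim
block = AuraFlowSingleTransformerBlock(dim=dim, num_attention_heads=num_heads, attention_head_dim=head_dim)

hidden_states = torch.randn(1, 16, dim)  # (batch, tokens, dim), shape assumed
temb = torch.randn(1, dim)               # conditioning embedding, shape assumed

# Only the tensors are passed; passing i=... would now raise a TypeError.
out = block(hidden_states=hidden_states, temb=temb)
print(out.shape)  # torch.Size([1, 16, 256])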

0 commit comments