diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index bf04c3e6a3ca..af441ef86181 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -244,7 +244,9 @@ def __init__(
         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
 
-        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(inner_dim, query_dim))
+        self.to_out.append(nn.Dropout(dropout))
 
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -283,7 +285,11 @@ def forward(self, hidden_states, context=None, mask=None):
         else:
             hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
 
-        return self.to_out(hidden_states)
+        # linear proj
+        hidden_states = self.to_out[0](hidden_states)
+        # dropout
+        hidden_states = self.to_out[1](hidden_states)
+        return hidden_states
 
     def _attention(self, query, key, value):
         # TODO: use baddbmm for better performance
@@ -354,12 +360,19 @@ def __init__(
         super().__init__()
         inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
-        project_in = GEGLU(dim, inner_dim)
+        self.net = nn.ModuleList([])
 
-        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
+        # project in
+        self.net.append(GEGLU(dim, inner_dim))
+        # project dropout
+        self.net.append(nn.Dropout(dropout))
+        # project out
+        self.net.append(nn.Linear(inner_dim, dim_out))
 
     def forward(self, hidden_states):
-        return self.net(hidden_states)
+        for module in self.net:
+            hidden_states = module(hidden_states)
+        return hidden_states
 
 
 # feedforward
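
Not part of the patch: a minimal standalone sketch (PyTorch only, made-up tensor shapes) of why this refactor is safe. Both nn.Sequential and nn.ModuleList register their children by index, so the to_out parameters keep the same "0.weight" / "0.bias" keys and existing checkpoints saved with to_out.0.* should still load, while the explicit forward pass produces the same output as calling the Sequential container.

import torch
import torch.nn as nn

# Hypothetical dimensions, dropout disabled so outputs are deterministic.
inner_dim, query_dim, dropout = 64, 32, 0.0

# Old layout: a callable Sequential container.
old_to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

# New layout: a ModuleList built the same way as in the diff.
new_to_out = nn.ModuleList([])
new_to_out.append(nn.Linear(inner_dim, query_dim))
new_to_out.append(nn.Dropout(dropout))

# Index-based parameter names match, so the state dict copies one-to-one.
assert old_to_out.state_dict().keys() == new_to_out.state_dict().keys()
new_to_out.load_state_dict(old_to_out.state_dict())

hidden_states = torch.randn(2, 10, inner_dim)

# Old path: call the container directly.
out_old = old_to_out(hidden_states)

# New path: linear proj, then dropout, mirroring the rewritten forward().
out_new = new_to_out[0](hidden_states)
out_new = new_to_out[1](out_new)

assert torch.allclose(out_old, out_new)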