23 changes: 18 additions & 5 deletions src/diffusers/models/attention.py
@@ -244,7 +244,9 @@ def __init__(
         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

-        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(inner_dim, query_dim))
+        self.to_out.append(nn.Dropout(dropout))

     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
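
Since this hunk swaps `nn.Sequential` for `nn.ModuleList`, here is a small standalone sketch (not part of the PR; the dimensions are placeholder values) showing that both layouts register their sub-modules under the indices `0` and `1`, so checkpoint keys such as `to_out.0.weight` are unaffected:

```python
import torch.nn as nn

inner_dim, query_dim, dropout = 320, 320, 0.0  # placeholder values

# old layout
old_to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

# new layout
new_to_out = nn.ModuleList([])
new_to_out.append(nn.Linear(inner_dim, query_dim))
new_to_out.append(nn.Dropout(dropout))

# both expose the Linear parameters under index 0 (Dropout has no parameters)
print(sorted(old_to_out.state_dict()))  # ['0.bias', '0.weight']
print(sorted(new_to_out.state_dict()))  # ['0.bias', '0.weight']
```
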
@@ -283,7 +285,11 @@ def forward(self, hidden_states, context=None, mask=None):
         else:
             hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)

-        return self.to_out(hidden_states)
+        # linear proj
+        hidden_states = self.to_out[0](hidden_states)
+        # dropout
+        hidden_states = self.to_out[1](hidden_states)
+        return hidden_states

     def _attention(self, query, key, value):
         # TODO: use baddbmm for better performance
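
A minimal sketch (again not part of the PR) checking that indexing into the `ModuleList`, as the updated `forward` does, reproduces what the old `nn.Sequential` call did; the shapes and dimensions below are placeholder assumptions:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
inner_dim, query_dim = 320, 320  # placeholder values
linear, drop = nn.Linear(inner_dim, query_dim), nn.Dropout(0.0)

old_to_out = nn.Sequential(linear, drop)
new_to_out = nn.ModuleList([linear, drop])

hidden_states = torch.randn(2, 77, inner_dim)

# old path: one call through the Sequential
out_old = old_to_out(hidden_states)

# new path: linear proj, then dropout, as in the updated forward()
out_new = new_to_out[0](hidden_states)
out_new = new_to_out[1](out_new)

print(torch.allclose(out_old, out_new))  # True
```
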
@@ -354,12 +360,19 @@ def __init__(
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
-        project_in = GEGLU(dim, inner_dim)
+        self.net = nn.ModuleList([])

-        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
+        # project in
+        self.net.append(GEGLU(dim, inner_dim))
+        # project dropout
+        self.net.append(nn.Dropout(dropout))
+        # project out
+        self.net.append(nn.Linear(inner_dim, dim_out))

     def forward(self, hidden_states):
-        return self.net(hidden_states)
+        for module in self.net:
+            hidden_states = module(hidden_states)
+        return hidden_states


 # feedforward
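
To illustrate the reworked `FeedForward` construction and its loop-based `forward`, a short hedged usage sketch; it assumes `GEGLU` can be imported from this same `attention.py` module, and the `dim`, batch, and sequence sizes are placeholder values:

```python
import torch
import torch.nn as nn
from diffusers.models.attention import GEGLU  # defined in this same file

dim, mult = 320, 4  # placeholder values
inner_dim = int(dim * mult)
dim_out = dim

net = nn.ModuleList([])
net.append(GEGLU(dim, inner_dim))          # project in
net.append(nn.Dropout(0.0))                # project dropout
net.append(nn.Linear(inner_dim, dim_out))  # project out

hidden_states = torch.randn(2, 77, dim)
# iterate the ModuleList in order, exactly as the new forward() does
for module in net:
    hidden_states = module(hidden_states)
print(hidden_states.shape)  # torch.Size([2, 77, 320])
```
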