src/diffusers/models/attention_flax.py: 15 changes (10 additions, 5 deletions)
@@ -152,6 +152,7 @@ def setup(self):
         self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
 
         self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -214,7 +215,7 @@ def __call__(self, hidden_states, context=None, deterministic=True):
 
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         hidden_states = self.proj_attn(hidden_states)
-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxBasicTransformerBlock(nn.Module):
@@ -260,6 +261,7 @@ def setup(self):
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
     def __call__(self, hidden_states, context, deterministic=True):
         # self attention
@@ -280,7 +282,7 @@ def __call__(self, hidden_states, context, deterministic=True):
         hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
         hidden_states = hidden_states + residual
 
-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxTransformer2DModel(nn.Module):
@@ -356,6 +358,8 @@ def setup(self):
             dtype=self.dtype,
         )
 
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
+
     def __call__(self, hidden_states, context, deterministic=True):
         batch, height, width, channels = hidden_states.shape
         residual = hidden_states
@@ -378,7 +382,7 @@ def __call__(self, hidden_states, context, deterministic=True):
         hidden_states = self.proj_out(hidden_states)
 
         hidden_states = hidden_states + residual
-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxFeedForward(nn.Module):
@@ -409,7 +413,7 @@ def setup(self):
         self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
 
     def __call__(self, hidden_states, deterministic=True):
-        hidden_states = self.net_0(hidden_states)
+        hidden_states = self.net_0(hidden_states, deterministic=deterministic)
         hidden_states = self.net_2(hidden_states)
         return hidden_states
 
@@ -434,8 +438,9 @@ class FlaxGEGLU(nn.Module):
     def setup(self):
         inner_dim = self.dim * 4
         self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
     def __call__(self, hidden_states, deterministic=True):
         hidden_states = self.proj(hidden_states)
         hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-        return hidden_linear * nn.gelu(hidden_gelu)
+        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
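
Not part of the diff above: a minimal sketch of how the nn.Dropout layers added here behave in Flax. Dropout only takes effect when a module is called with deterministic=False, and Flax then expects a "dropout" PRNG stream to be supplied via apply(..., rngs={"dropout": ...}). The TinyBlock module below is a hypothetical stand-in for the modules touched in this file, not code from the PR.

# Hypothetical illustration (not from this PR): mirrors the dropout_layer
# pattern added in attention_flax.py.
import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyBlock(nn.Module):
    dim: int = 8
    dropout: float = 0.1

    def setup(self):
        self.dense = nn.Dense(self.dim)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.dense(hidden_states)
        # Same call shape as in the diff: a no-op when deterministic=True.
        return self.dropout_layer(hidden_states, deterministic=deterministic)


x = jnp.ones((2, 4, 8))
block = TinyBlock()
params = block.init(jax.random.PRNGKey(0), x)  # init runs with deterministic=True

# Inference: no dropout RNG needed.
y_eval = block.apply(params, x, deterministic=True)

# Training: supply the "dropout" RNG so nn.Dropout can sample its mask.
y_train = block.apply(
    params, x, deterministic=False, rngs={"dropout": jax.random.PRNGKey(1)}
)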