diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py
index 4f78b324a8e2..0b160d238431 100644
--- a/src/diffusers/models/attention_flax.py
+++ b/src/diffusers/models/attention_flax.py
@@ -152,6 +152,7 @@ def setup(self):
         self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
 
         self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -214,7 +215,7 @@ def __call__(self, hidden_states, context=None, deterministic=True):
             hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
 
         hidden_states = self.proj_attn(hidden_states)
-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxBasicTransformerBlock(nn.Module):
@@ -260,6 +261,7 @@ def setup(self):
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
     def __call__(self, hidden_states, context, deterministic=True):
         # self attention
@@ -280,7 +282,7 @@ def __call__(self, hidden_states, context, deterministic=True):
         hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
         hidden_states = hidden_states + residual
 
-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxTransformer2DModel(nn.Module):
@@ -356,6 +358,8 @@ def setup(self):
                 dtype=self.dtype,
             )
 
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
+
     def __call__(self, hidden_states, context, deterministic=True):
         batch, height, width, channels = hidden_states.shape
         residual = hidden_states
@@ -378,7 +382,7 @@ def __call__(self, hidden_states, context, deterministic=True):
             hidden_states = self.proj_out(hidden_states)
 
         hidden_states = hidden_states + residual
-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxFeedForward(nn.Module):
@@ -409,7 +413,7 @@ def setup(self):
         self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
 
     def __call__(self, hidden_states, deterministic=True):
-        hidden_states = self.net_0(hidden_states)
+        hidden_states = self.net_0(hidden_states, deterministic=deterministic)
         hidden_states = self.net_2(hidden_states)
         return hidden_states
 
@@ -434,8 +438,9 @@ class FlaxGEGLU(nn.Module):
     def setup(self):
         inner_dim = self.dim * 4
         self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
     def __call__(self, hidden_states, deterministic=True):
         hidden_states = self.proj(hidden_states)
         hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-        return hidden_linear * nn.gelu(hidden_gelu)
+        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
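
Every hunk follows the same pattern: declare an nn.Dropout in setup() and apply it in __call__, gated by the deterministic flag these Flax modules already accept. Below is a minimal sketch of how a caller would exercise that pattern; TinyBlock is a hypothetical stand-in, not a class from this file, and the dimensions, dropout rate, and PRNG keys are illustrative assumptions.

# Minimal sketch (not part of the patch): mirrors the dropout pattern the diff
# applies to the diffusers Flax attention modules, using a hypothetical TinyBlock.
import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyBlock(nn.Module):
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.dense = nn.Dense(self.dim, dtype=self.dtype)
        # Same as the patch: declare the dropout layer in setup() ...
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.dense(hidden_states)
        # ... and gate it on the `deterministic` flag threaded through __call__.
        return self.dropout_layer(hidden_states, deterministic=deterministic)


x = jnp.ones((1, 16, 32))
block = TinyBlock(dim=32, dropout=0.1)
params = block.init(jax.random.PRNGKey(0), x)

# Inference: deterministic=True makes dropout a no-op, no extra rng is needed.
y_eval = block.apply(params, x, deterministic=True)

# Training: deterministic=False activates dropout, so Flax needs an rng stream
# named "dropout" to sample the mask.
y_train = block.apply(params, x, deterministic=False, rngs={"dropout": jax.random.PRNGKey(1)})

This is also why the FlaxFeedForward hunk only forwards deterministic to net_0: net_0 is the FlaxGEGLU submodule, which now owns the nn.Dropout, while net_2 is a plain nn.Dense and takes no such flag.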