From b33274e3de72b2b971bf58c22fff329ba96de45e Mon Sep 17 00:00:00 2001
From: Saurav Maheshkar
Date: Wed, 28 Jun 2023 21:52:32 +0530
Subject: [PATCH 1/3] feat: add Dropout to Flax UNet

---
 src/diffusers/models/attention_flax.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py
index 4f78b324a8e2..83e9ea3283b2 100644
--- a/src/diffusers/models/attention_flax.py
+++ b/src/diffusers/models/attention_flax.py
@@ -214,7 +214,7 @@ def __call__(self, hidden_states, context=None, deterministic=True):
 
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         hidden_states = self.proj_attn(hidden_states)
-        return hidden_states
+        return nn.Dropout(rate=self.dropout, deterministic=deterministic)(hidden_states)
 
 
 class FlaxBasicTransformerBlock(nn.Module):
@@ -280,7 +280,7 @@ def __call__(self, hidden_states, context, deterministic=True):
         hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
         hidden_states = hidden_states + residual
 
-        return hidden_states
+        return nn.Dropout(rate=self.dropout, deterministic=deterministic)(hidden_states)
 
 
 class FlaxTransformer2DModel(nn.Module):
@@ -378,7 +378,7 @@ def __call__(self, hidden_states, context, deterministic=True):
             hidden_states = self.proj_out(hidden_states)
 
         hidden_states = hidden_states + residual
-        return hidden_states
+        return nn.Dropout(rate=self.dropout, deterministic=deterministic)(hidden_states)
 
 
 class FlaxFeedForward(nn.Module):
@@ -409,7 +409,7 @@ def setup(self):
         self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
 
     def __call__(self, hidden_states, deterministic=True):
-        hidden_states = self.net_0(hidden_states)
+        hidden_states = self.net_0(hidden_states, deterministic=deterministic)
         hidden_states = self.net_2(hidden_states)
         return hidden_states
 
@@ -438,4 +438,4 @@ def setup(self):
     def __call__(self, hidden_states, deterministic=True):
         hidden_states = self.proj(hidden_states)
         hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-        return hidden_linear * nn.gelu(hidden_gelu)
+        return nn.Dropout(rate=self.dropout, deterministic=deterministic)(hidden_linear * nn.gelu(hidden_gelu))

From d4da099177659b5fe07dd5851f979e4a252e2cd6 Mon Sep 17 00:00:00 2001
From: Saurav Maheshkar
Date: Thu, 29 Jun 2023 13:25:48 +0530
Subject: [PATCH 2/3] feat: add @compact decorator

---
 src/diffusers/models/attention_flax.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py
index 83e9ea3283b2..ac59d4f4cfcd 100644
--- a/src/diffusers/models/attention_flax.py
+++ b/src/diffusers/models/attention_flax.py
@@ -169,6 +169,7 @@ def reshape_batch_dim_to_heads(self, tensor):
         tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
 
+    @nn.compact
     def __call__(self, hidden_states, context=None, deterministic=True):
         context = hidden_states if context is None else context
 
@@ -261,6 +262,7 @@ def setup(self):
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
 
+    @nn.compact
     def __call__(self, hidden_states, context, deterministic=True):
         # self attention
         residual = hidden_states
@@ -356,6 +358,7 @@ def setup(self):
                 dtype=self.dtype,
             )
 
+    @nn.compact
     def __call__(self, hidden_states, context, deterministic=True):
         batch, height, width, channels = hidden_states.shape
         residual = hidden_states
@@ -435,6 +438,7 @@ def setup(self):
         inner_dim = self.dim * 4
         self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
 
+    @nn.compact
     def __call__(self, hidden_states, deterministic=True):
         hidden_states = self.proj(hidden_states)
         hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)

From 364432b5a58e988ad5a2f186c8a876148696bc62 Mon Sep 17 00:00:00 2001
From: Saurav Maheshkar
Date: Fri, 7 Jul 2023 00:31:37 +0530
Subject: [PATCH 3/3] fix: drop nn.compact

---
 src/diffusers/models/attention_flax.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py
index ac59d4f4cfcd..0b160d238431 100644
--- a/src/diffusers/models/attention_flax.py
+++ b/src/diffusers/models/attention_flax.py
@@ -152,6 +152,7 @@ def setup(self):
         self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
 
         self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -169,7 +170,6 @@ def reshape_batch_dim_to_heads(self, tensor):
         tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
 
-    @nn.compact
     def __call__(self, hidden_states, context=None, deterministic=True):
         context = hidden_states if context is None else context
 
@@ -215,7 +215,7 @@ def __call__(self, hidden_states, context=None, deterministic=True):
 
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         hidden_states = self.proj_attn(hidden_states)
-        return nn.Dropout(rate=self.dropout, deterministic=deterministic)(hidden_states)
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxBasicTransformerBlock(nn.Module):
@@ -261,8 +261,8 @@ def setup(self):
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
-    @nn.compact
     def __call__(self, hidden_states, context, deterministic=True):
         # self attention
         residual = hidden_states
@@ -282,7 +282,7 @@ def __call__(self, hidden_states, context, deterministic=True):
         hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
         hidden_states = hidden_states + residual
 
-        return nn.Dropout(rate=self.dropout, deterministic=deterministic)(hidden_states)
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxTransformer2DModel(nn.Module):
@@ -358,7 +358,8 @@ def setup(self):
                 dtype=self.dtype,
             )
 
-    @nn.compact
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
+
     def __call__(self, hidden_states, context, deterministic=True):
         batch, height, width, channels = hidden_states.shape
         residual = hidden_states
@@ -381,7 +382,7 @@ def __call__(self, hidden_states, context, deterministic=True):
             hidden_states = self.proj_out(hidden_states)
 
         hidden_states = hidden_states + residual
-        return nn.Dropout(rate=self.dropout, deterministic=deterministic)(hidden_states)
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxFeedForward(nn.Module):
@@ -437,9 +438,9 @@ class FlaxGEGLU(nn.Module):
     def setup(self):
         inner_dim = self.dim * 4
         self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
-    @nn.compact
     def __call__(self, hidden_states, deterministic=True):
         hidden_states = self.proj(hidden_states)
         hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-        return nn.Dropout(rate=self.dropout, deterministic=deterministic)(hidden_linear * nn.gelu(hidden_gelu))
+        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
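Note (reviewer addendum, not part of the patch series): after PATCH 3/3 each module creates its nn.Dropout once in setup() as self.dropout_layer and only threads the deterministic flag through __call__, rather than instantiating nn.Dropout inline (PATCH 1/3) or decorating __call__ with @nn.compact (PATCH 2/3). The sketch below shows how that pattern behaves from the caller side; FlaxGEGLUSketch and the tensor shapes are illustrative stand-ins, while nn.Dropout, Module.init, and Module.apply(..., rngs={"dropout": ...}) are standard Flax Linen usage, not APIs introduced by this PR.

# Minimal sketch of the setup()-defined dropout pattern used in the final patch.
import jax
import jax.numpy as jnp
import flax.linen as nn


class FlaxGEGLUSketch(nn.Module):
    dim: int
    dropout: float = 0.1

    def setup(self):
        # Mirrors the patch: the Dropout submodule is created once in setup() ...
        self.proj = nn.Dense(self.dim * 4 * 2)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.proj(hidden_states)
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
        # ... and only the `deterministic` flag is threaded through the call.
        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)


model = FlaxGEGLUSketch(dim=32, dropout=0.1)
x = jnp.ones((1, 4, 32))  # (batch, sequence, features); shapes are arbitrary here
params_rng, dropout_rng = jax.random.split(jax.random.PRNGKey(0))

# Inference/eval: deterministic=True (the default), so dropout is a no-op.
variables = model.init(params_rng, x)
y_eval = model.apply(variables, x)

# Training: deterministic=False requires a "dropout" PRNG stream.
y_train = model.apply(variables, x, deterministic=False, rngs={"dropout": dropout_rng})

Keeping the submodule in setup() (and dropping @nn.compact, per the PATCH 3/3 subject) presumably avoids mixing setup-defined and inline-defined submodules in the same module, which Flax Linen treats as two distinct definition styles.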