
Commit 364432b

fix: drop nn.compact
1 parent 63ecf70 commit 364432b


src/diffusers/models/attention_flax.py

Lines changed: 9 additions & 8 deletions
@@ -152,6 +152,7 @@ def setup(self):
         self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
 
         self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -169,7 +170,6 @@ def reshape_batch_dim_to_heads(self, tensor):
         tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
 
-    @nn.compact
     def __call__(self, hidden_states, context=None, deterministic=True):
         context = hidden_states if context is None else context
 
@@ -215,7 +215,7 @@ def __call__(self, hidden_states, context=None, deterministic=True):
 
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         hidden_states = self.proj_attn(hidden_states)
-        return nn.Dropout(rate=self.dropout, deterministic=deterministic)(hidden_states)
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxBasicTransformerBlock(nn.Module):
@@ -261,8 +261,8 @@ def setup(self):
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
-    @nn.compact
     def __call__(self, hidden_states, context, deterministic=True):
         # self attention
         residual = hidden_states
@@ -282,7 +282,7 @@ def __call__(self, hidden_states, context, deterministic=True):
         hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
         hidden_states = hidden_states + residual
 
-        return nn.Dropout(rate=self.dropout, deterministic=deterministic)(hidden_states)
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxTransformer2DModel(nn.Module):
@@ -358,7 +358,8 @@ def setup(self):
             dtype=self.dtype,
         )
 
-    @nn.compact
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
+
     def __call__(self, hidden_states, context, deterministic=True):
         batch, height, width, channels = hidden_states.shape
         residual = hidden_states
@@ -381,7 +382,7 @@ def __call__(self, hidden_states, context, deterministic=True):
         hidden_states = self.proj_out(hidden_states)
 
         hidden_states = hidden_states + residual
-        return nn.Dropout(rate=self.dropout, deterministic=deterministic)(hidden_states)
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxFeedForward(nn.Module):
@@ -437,9 +438,9 @@ class FlaxGEGLU(nn.Module):
     def setup(self):
         inner_dim = self.dim * 4
         self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
-    @nn.compact
     def __call__(self, hidden_states, deterministic=True):
         hidden_states = self.proj(hidden_states)
         hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-        return nn.Dropout(rate=self.dropout, deterministic=deterministic)(hidden_linear * nn.gelu(hidden_gelu))
+        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
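Below is a minimal, hypothetical sketch (not part of this commit) of the pattern the diff applies: Flax modules that declare submodules in setup() should not also construct them inline under @nn.compact, so the nn.Dropout is created once in setup() as dropout_layer and invoked with the deterministic flag in __call__. The module and field names (TinyBlock, features) are illustrative only.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyBlock(nn.Module):  # hypothetical example, not from diffusers
    features: int = 32
    dropout: float = 0.1
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Submodules are declared once here; no @nn.compact on __call__.
        self.dense = nn.Dense(self.features, dtype=self.dtype)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.dense(hidden_states)
        # deterministic=True disables dropout (inference); pass False when training.
        return self.dropout_layer(hidden_states, deterministic=deterministic)


x = jnp.ones((1, 8, 32))
model = TinyBlock()
params = model.init(jax.random.PRNGKey(0), x)
out = model.apply(params, x)  # dropout inactive (deterministic=True)
out_train = model.apply(
    params, x, deterministic=False, rngs={"dropout": jax.random.PRNGKey(1)}
)
```

The runtime behaviour is unchanged by the commit: the dropout rate and the deterministic flag are the same, only the place where the Dropout submodule is constructed moves from the @nn.compact __call__ into setup().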
