@@ -152,6 +152,7 @@ def setup(self):
         self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

         self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
+        self.dropout_layer = nn.Dropout(rate=self.dropout)

     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -169,7 +170,6 @@ def reshape_batch_dim_to_heads(self, tensor):
         tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor

-    @nn.compact
     def __call__(self, hidden_states, context=None, deterministic=True):
         context = hidden_states if context is None else context

@@ -215,7 +215,7 @@ def __call__(self, hidden_states, context=None, deterministic=True):

         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         hidden_states = self.proj_attn(hidden_states)
-        return nn.Dropout(rate=self.dropout, deterministic=deterministic)(hidden_states)
+        return self.dropout_layer(hidden_states, deterministic=deterministic)


 class FlaxBasicTransformerBlock(nn.Module):
@@ -261,8 +261,8 @@ def setup(self):
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)

-    @nn.compact
     def __call__(self, hidden_states, context, deterministic=True):
         # self attention
         residual = hidden_states
@@ -282,7 +282,7 @@ def __call__(self, hidden_states, context, deterministic=True):
         hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
         hidden_states = hidden_states + residual

-        return nn.Dropout(rate=self.dropout, deterministic=deterministic)(hidden_states)
+        return self.dropout_layer(hidden_states, deterministic=deterministic)


 class FlaxTransformer2DModel(nn.Module):
@@ -358,7 +358,8 @@ def setup(self):
             dtype=self.dtype,
         )

-    @nn.compact
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
+
     def __call__(self, hidden_states, context, deterministic=True):
         batch, height, width, channels = hidden_states.shape
         residual = hidden_states
@@ -381,7 +382,7 @@ def __call__(self, hidden_states, context, deterministic=True):
         hidden_states = self.proj_out(hidden_states)

         hidden_states = hidden_states + residual
-        return nn.Dropout(rate=self.dropout, deterministic=deterministic)(hidden_states)
+        return self.dropout_layer(hidden_states, deterministic=deterministic)


 class FlaxFeedForward(nn.Module):
@@ -437,9 +438,9 @@ class FlaxGEGLU(nn.Module):
     def setup(self):
         inner_dim = self.dim * 4
         self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)

-    @nn.compact
     def __call__(self, hidden_states, deterministic=True):
         hidden_states = self.proj(hidden_states)
         hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-        return nn.Dropout(rate=self.dropout, deterministic=deterministic)(hidden_linear * nn.gelu(hidden_gelu))
+        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
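Note on the pattern applied in every hunk above: a Flax module that defines its submodules in setup() should not also create submodules inline under @nn.compact, so each nn.Dropout is now instantiated once in setup() and only applied in __call__, with deterministic passed at call time. Below is a minimal sketch of the same pattern, outside this diff; the module name DropoutBlock and the toy shapes are illustrative, not from the file being changed.

import jax
import jax.numpy as jnp
import flax.linen as nn


class DropoutBlock(nn.Module):
    # Illustrative only: mirrors the diff's setup()/__call__ split for nn.Dropout.
    features: int = 32
    dropout: float = 0.1

    def setup(self):
        self.dense = nn.Dense(self.features)
        # Submodule created once here, so no @nn.compact is needed on __call__.
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, x, deterministic=True):
        x = self.dense(x)
        # `deterministic` is decided per call; dropout is skipped when it is True.
        return self.dropout_layer(x, deterministic=deterministic)


block = DropoutBlock()
x = jnp.ones((1, 8))
params = block.init(jax.random.PRNGKey(0), x)  # deterministic=True by default, no dropout RNG needed
out = block.apply(
    params, x, deterministic=False, rngs={"dropout": jax.random.PRNGKey(1)}
)  # the stochastic path requires a "dropout" RNG stream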