@@ -179,15 +179,9 @@ def __init__(
         self.num_heads = num_heads
         self.attention_dropout = attention_dropout
         self.dropout = dropout
-
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.proj = nn.Linear(dim, dim, bias=proj_bias)
 
-        # define a parameter table of relative position bias
-        self.relative_position_bias_table = nn.Parameter(
-            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
-        )  # 2*Wh-1 * 2*Ww-1, nH
-
         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.window_size)
         coords_w = torch.arange(self.window_size)
@@ -199,22 +193,25 @@ def __init__(
         relative_coords[:, :, 1] += self.window_size - 1
         relative_coords[:, :, 0] *= 2 * self.window_size - 1
         relative_position_index = relative_coords.sum(-1).view(-1)  # Wh*Ww*Wh*Ww
-        self.register_buffer("relative_position_index", relative_position_index)
-
-        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
-
-    def forward(self, x: Tensor):
-        relative_position_bias = self.relative_position_bias_table[self.relative_position_index]  # type: ignore[index]
+
+        # define a parameter table of relative position bias
+        relative_position_bias_table = torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)  # 2*Wh-1 * 2*Ww-1, nH
+        nn.init.trunc_normal_(relative_position_bias_table, std=0.02)
+
+        relative_position_bias = relative_position_bias_table[relative_position_index]  # type: ignore[index]
         relative_position_bias = relative_position_bias.view(
             self.window_size * self.window_size, self.window_size * self.window_size, -1
         )
-        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
+        self.relative_position_bias = nn.Parameter(relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0))
+
+    def forward(self, x: Tensor):
+
 
         return shifted_window_attention(
             x,
             self.qkv.weight,
             self.proj.weight,
-            relative_position_bias,
+            self.relative_position_bias,
             self.window_size,
             self.num_heads,
             shift_size=self.shift_size,
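
For reference, below is a minimal, self-contained sketch of the precomputation this diff moves into __init__: the pair-wise relative position index is built once, used to gather rows from the zero-initialized bias table, and the gathered result is reshaped and stored as a single nn.Parameter that forward() can reuse. The helper name precompute_relative_position_bias and the standalone layout are illustrative (not part of the patch); the shapes and tensor ops mirror the diff above, and a square window is assumed.

import torch
import torch.nn as nn


def precompute_relative_position_bias(window_size: int, num_heads: int) -> nn.Parameter:
    # pair-wise relative position index for each token inside the window
    coords_h = torch.arange(window_size)
    coords_w = torch.arange(window_size)
    coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
    relative_coords[:, :, 0] += window_size - 1  # shift to start from 0
    relative_coords[:, :, 1] += window_size - 1
    relative_coords[:, :, 0] *= 2 * window_size - 1
    relative_position_index = relative_coords.sum(-1).view(-1)  # Wh*Ww*Wh*Ww

    # table of relative position bias: one row per relative offset, one column per head
    relative_position_bias_table = torch.zeros(
        (2 * window_size - 1) * (2 * window_size - 1), num_heads
    )  # 2*Wh-1 * 2*Ww-1, nH
    nn.init.trunc_normal_(relative_position_bias_table, std=0.02)

    # gather, reshape to (Wh*Ww, Wh*Ww, nH), then to (1, nH, Wh*Ww, Wh*Ww)
    bias = relative_position_bias_table[relative_position_index]
    bias = bias.view(window_size * window_size, window_size * window_size, -1)
    return nn.Parameter(bias.permute(2, 0, 1).contiguous().unsqueeze(0))


if __name__ == "__main__":
    bias = precompute_relative_position_bias(window_size=7, num_heads=3)
    print(bias.shape)  # torch.Size([1, 3, 49, 49])

Note the consequence of this arrangement, visible directly in the diff: the learnable state becomes the gathered (1, nH, Wh*Ww, Wh*Ww) bias itself rather than the (2*Wh-1)*(2*Ww-1) x nH table plus a registered index buffer, and forward() simply passes self.relative_position_bias to shifted_window_attention instead of rebuilding it on every call.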