Commit f061544

move relative_position_bias to __init__
1 parent 2500ff3 commit f061544

1 file changed (+11 -14 lines)

torchvision/models/swin_transformer.py

Lines changed: 11 additions & 14 deletions
@@ -179,15 +179,9 @@ def __init__(
         self.num_heads = num_heads
         self.attention_dropout = attention_dropout
         self.dropout = dropout
-
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.proj = nn.Linear(dim, dim, bias=proj_bias)
 
-        # define a parameter table of relative position bias
-        self.relative_position_bias_table = nn.Parameter(
-            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
-        )  # 2*Wh-1 * 2*Ww-1, nH
-
         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.window_size)
         coords_w = torch.arange(self.window_size)
@@ -199,22 +193,25 @@ def __init__(
         relative_coords[:, :, 1] += self.window_size - 1
         relative_coords[:, :, 0] *= 2 * self.window_size - 1
         relative_position_index = relative_coords.sum(-1).view(-1)  # Wh*Ww*Wh*Ww
-        self.register_buffer("relative_position_index", relative_position_index)
-
-        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
-
-    def forward(self, x: Tensor):
-        relative_position_bias = self.relative_position_bias_table[self.relative_position_index]  # type: ignore[index]
+
+        # define a parameter table of relative position bias
+        relative_position_bias_table = torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)  # 2*Wh-1 * 2*Ww-1, nH
+        nn.init.trunc_normal_(relative_position_bias_table, std=0.02)
+
+        relative_position_bias = relative_position_bias_table[relative_position_index]  # type: ignore[index]
         relative_position_bias = relative_position_bias.view(
             self.window_size * self.window_size, self.window_size * self.window_size, -1
         )
-        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
+        self.relative_position_bias = nn.Parameter(relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0))
+
+    def forward(self, x: Tensor):
+
 
         return shifted_window_attention(
             x,
             self.qkv.weight,
             self.proj.weight,
-            relative_position_bias,
+            self.relative_position_bias,
             self.window_size,
             self.num_heads,
             shift_size=self.shift_size,
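
For readers skimming the diff: after this change the relative position bias is fully materialized in `__init__` and stored as a single `nn.Parameter`, so `forward` only passes it through to `shifted_window_attention`. Below is a minimal, self-contained sketch of that precompute-in-`__init__` pattern; the `ToyRelativeBias` module, its constructor arguments, and the way it is exercised are illustrative assumptions, not the actual torchvision class.

import torch
from torch import nn


class ToyRelativeBias(nn.Module):
    """Hypothetical toy module: build the relative position bias once in __init__."""

    def __init__(self, window_size: int, num_heads: int):
        super().__init__()
        # Bias table with one entry per relative (dh, dw) offset and head,
        # mirroring the (2*Wh-1)*(2*Ww-1) x nH table in the diff above.
        table = torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
        nn.init.trunc_normal_(table, std=0.02)

        # Pair-wise relative position index for every token pair in the window
        # (sketched here; the commit elides these lines between the two hunks).
        coords = torch.stack(torch.meshgrid(torch.arange(window_size), torch.arange(window_size), indexing="ij"))
        coords_flat = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative = coords_flat[:, :, None] - coords_flat[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative = relative.permute(1, 2, 0)  # Wh*Ww, Wh*Ww, 2
        relative[:, :, 0] += window_size - 1
        relative[:, :, 1] += window_size - 1
        relative[:, :, 0] *= 2 * window_size - 1
        index = relative.sum(-1).view(-1)  # Wh*Ww*Wh*Ww

        # Gather and reshape once, then register the result as the learnable parameter.
        bias = table[index].view(window_size * window_size, window_size * window_size, -1)
        self.relative_position_bias = nn.Parameter(bias.permute(2, 0, 1).contiguous().unsqueeze(0))

    def forward(self) -> torch.Tensor:
        # Nothing to recompute per call; the bias is already a parameter.
        return self.relative_position_bias


bias = ToyRelativeBias(window_size=7, num_heads=3)()
print(bias.shape)  # torch.Size([1, 3, 49, 49])

As the diff shows, the learnable state after this commit is the expanded (1, nH, Wh*Ww, Wh*Ww) bias rather than the compact (2*Wh-1)*(2*Ww-1) x nH table that was previously registered.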
