@@ -309,11 +309,7 @@ def __init__(
309
309
num_heads : int ,
310
310
qkv_bias : bool = True ,
311
311
proj_bias : bool = True ,
312
- pretrained_window_size : Optional [List [int ]] = None ,
313
312
):
314
- if pretrained_window_size is None :
315
- pretrained_window_size = [0 , 0 ]
316
- self .pretrained_window_size = pretrained_window_size
317
313
super ().__init__ (
318
314
dim ,
319
315
window_size ,
@@ -338,12 +334,10 @@ def define_relative_position_bias_table(self):
338
334
relative_coords_w = torch .arange (- (self .window_size [1 ] - 1 ), self .window_size [1 ], dtype = torch .float32 )
339
335
relative_coords_table = torch .stack (torch .meshgrid ([relative_coords_h , relative_coords_w ], indexing = "ij" ))
340
336
relative_coords_table = relative_coords_table .permute (1 , 2 , 0 ).contiguous ().unsqueeze (0 ) # 1, 2*Wh-1, 2*Ww-1, 2
341
- if self .pretrained_window_size [0 ] > 0 :
342
- relative_coords_table [:, :, :, 0 ] /= self .pretrained_window_size [0 ] - 1
343
- relative_coords_table [:, :, :, 1 ] /= self .pretrained_window_size [1 ] - 1
344
- else :
345
- relative_coords_table [:, :, :, 0 ] /= self .window_size [0 ] - 1
346
- relative_coords_table [:, :, :, 1 ] /= self .window_size [1 ] - 1
337
+
338
+ relative_coords_table [:, :, :, 0 ] /= self .window_size [0 ] - 1
339
+ relative_coords_table [:, :, :, 1 ] /= self .window_size [1 ] - 1
340
+
347
341
relative_coords_table *= 8 # normalize to -8, 8
348
342
relative_coords_table = (
349
343
torch .sign (relative_coords_table ) * torch .log2 (torch .abs (relative_coords_table ) + 1.0 ) / 3.0
@@ -446,7 +440,6 @@ class SwinTransformerBlockV2(SwinTransformerBlock):
446
440
stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0.
447
441
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
448
442
attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttentionV2.
449
- pretrained_window_size (int): Local window size in pre-training. Default: 0.
450
443
"""
451
444
452
445
def __init__ (
@@ -459,7 +452,6 @@ def __init__(
459
452
stochastic_depth_prob : float = 0.0 ,
460
453
norm_layer : Callable [..., nn .Module ] = nn .LayerNorm ,
461
454
attn_layer : Callable [..., nn .Module ] = ShiftedWindowAttentionV2 ,
462
- pretrained_window_size : int = 0 ,
463
455
):
464
456
super ().__init__ (
465
457
dim ,
@@ -470,7 +462,6 @@ def __init__(
470
462
stochastic_depth_prob = stochastic_depth_prob ,
471
463
norm_layer = norm_layer ,
472
464
attn_layer = attn_layer ,
473
- pretrained_window_size = [pretrained_window_size , pretrained_window_size ],
474
465
)
475
466
476
467
def forward (self , x : Tensor ):
@@ -494,7 +485,6 @@ class SwinTransformer(nn.Module):
494
485
num_classes (int): Number of classes for classification head. Default: 1000.
495
486
block (nn.Module, optional): SwinTransformer Block. Default: None.
496
487
norm_layer (nn.Module, optional): Normalization layer. Default: None.
497
- pretrained_window_sizes (List[int]): Pretrained window sizes of each layer for Swin Transformer V2. Default: [0, 0, 0, 0].
498
488
"""
499
489
500
490
def __init__ (
@@ -510,7 +500,6 @@ def __init__(
510
500
block : Callable [..., nn .Module ] = SwinTransformerBlock ,
511
501
norm_layer : Optional [Callable [..., nn .Module ]] = None ,
512
502
downsample_layer : Callable [..., nn .Module ] = PatchMerging ,
513
- pretrained_window_sizes : Optional [List [int ]] = None ,
514
503
):
515
504
super ().__init__ ()
516
505
_log_api_usage_once (self )
@@ -537,9 +526,6 @@ def __init__(
537
526
for i_stage in range (len (depths )):
538
527
stage : List [nn .Module ] = []
539
528
dim = embed_dim * 2 ** i_stage
540
- kwargs : Dict [str , Any ] = {}
541
- if pretrained_window_sizes is not None :
542
- kwargs ["pretrained_window_size" ] = pretrained_window_sizes [i_stage ]
543
529
for i_layer in range (depths [i_stage ]):
544
530
# adjust stochastic depth probability based on the depth of the stage block
545
531
sd_prob = stochastic_depth_prob * float (stage_block_id ) / (total_stage_blocks - 1 )
@@ -552,7 +538,6 @@ def __init__(
552
538
mlp_ratio = mlp_ratio ,
553
539
stochastic_depth_prob = sd_prob ,
554
540
norm_layer = norm_layer ,
555
- ** kwargs ,
556
541
)
557
542
)
558
543
stage_block_id += 1
0 commit comments