@@ -13,18 +13,12 @@


 __all__ = [
-    "MViTV2",
-    "MViT_V2_T_Weights",
-    "MViT_V2_S_Weights",
-    "MViT_V2_B_Weights",
-    "mvit_v2_t",
-    "mvit_v2_s",
-    "mvit_v2_b",
+    "MViT",
+    "MViT_V1_B_Weights",
+    "mvit_v1_b",
 ]


-# TODO: check if we should implement relative pos embedding (Section 4.1 in the paper). Ref:
-# https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py#L45
 # TODO: add weights
 # TODO: test on references

@@ -108,6 +102,7 @@ def __init__(
         kernel_kv: List[int],
         stride_q: List[int],
         stride_kv: List[int],
+        residual_pool: bool,
         dropout: float = 0.0,
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
     ) -> None:
@@ -116,6 +111,7 @@ def __init__(
         self.num_heads = num_heads
         self.head_dim = embed_dim // num_heads
         self.scaler = 1.0 / math.sqrt(self.head_dim)
+        self.residual_pool = residual_pool

         self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
         layers: List[nn.Module] = [nn.Linear(embed_dim, embed_dim)]
@@ -182,7 +178,9 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten
         attn = torch.matmul(self.scaler * q, k.transpose(2, 3))
         attn = attn.softmax(dim=-1)

-        x = torch.matmul(attn, v).add_(q)
+        x = torch.matmul(attn, v)
+        if self.residual_pool:
+            x.add_(q)
         x = x.transpose(1, 2).reshape(B, -1, C)
         x = self.project(x)

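The hunk above turns the residual pooling connection into an opt-in: the attention output only has the pooled query tensor added back when `residual_pool` is set, instead of unconditionally. A minimal sketch of the two paths with hypothetical toy shapes (illustration only, not the actual module):

import math
import torch

# Hypothetical toy shapes: (batch, heads, tokens, head_dim).
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

scaler = 1.0 / math.sqrt(q.shape[-1])
attn = torch.matmul(scaler * q, k.transpose(2, 3)).softmax(dim=-1)

out_plain = torch.matmul(attn, v)         # residual_pool=False: plain attention
out_residual = torch.matmul(attn, v) + q  # residual_pool=True: pooled q added back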
@@ -199,6 +197,7 @@ def __init__(
         kernel_kv: List[int],
         stride_q: List[int],
         stride_kv: List[int],
+        residual_pool: bool,
         dropout: float = 0.0,
         stochastic_depth_prob: float = 0.0,
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
@@ -224,6 +223,7 @@ def __init__(
             kernel_kv=kernel_kv,
             stride_q=stride_q,
             stride_kv=stride_kv,
+            residual_pool=residual_pool,
             dropout=dropout,
             norm_layer=norm_layer,
         )
@@ -274,7 +274,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return torch.cat((class_token, x), dim=1).add_(pos_embedding)


-class MViTV2(nn.Module):
+class MViT(nn.Module):
     def __init__(
         self,
         spatial_size: Tuple[int, int],
@@ -285,6 +285,7 @@ def __init__(
         pool_kv_stride: List[int],
         pool_q_stride: List[int],
         pool_kvq_kernel: List[int],
+        residual_pool: bool,
         dropout: float = 0.0,
         attention_dropout: float = 0.0,
         stochastic_depth_prob: float = 0.0,
@@ -293,7 +294,7 @@ def __init__(
         norm_layer: Optional[Callable[..., nn.Module]] = None,
     ) -> None:
         """
-        MViT V2 main class.
+        MViT main class.

         Args:
             spatial_size (tuple of ints): The spatial size of the input as ``(H, W)``.
@@ -374,6 +375,7 @@ def __init__(
                     kernel_kv=pool_kvq_kernel,
                     stride_q=stride_q,
                     stride_kv=stride_kv,
+                    residual_pool=residual_pool,
                     dropout=attention_dropout,
                     stochastic_depth_prob=sd_prob,
                     norm_layer=norm_layer,
@@ -426,15 +428,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


-def _mvitv2(
+def _mvit(
    embed_channels: List[int],
    blocks: List[int],
    heads: List[int],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
-) -> MViTV2:
+) -> MViT:
     if weights is not None:
         _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
         assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
@@ -443,7 +445,7 @@ def _mvitv2(
     spatial_size = kwargs.pop("spatial_size", (224, 224))
     temporal_size = kwargs.pop("temporal_size", 16)

-    model = MViTV2(
+    model = MViT(
         spatial_size=spatial_size,
         temporal_size=temporal_size,
         embed_channels=embed_channels,
@@ -452,6 +454,7 @@ def _mvitv2(
         pool_kv_stride=kwargs.pop("pool_kv_stride", [1, 8, 8]),
         pool_q_stride=kwargs.pop("pool_q_stride", [1, 2, 2]),
         pool_kvq_kernel=kwargs.pop("pool_kvq_kernel", [3, 3, 3]),
+        residual_pool=kwargs.pop("residual_pool", False),
         stochastic_depth_prob=stochastic_depth_prob,
         **kwargs,
     )
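Since the builder pops `residual_pool` from `kwargs` with a default of `False`, plain (non-residual) pooling attention remains the default and callers must opt in explicitly. A small sketch of the pop-with-default pattern, using hypothetical values:

# Hypothetical caller kwargs, for illustration only.
kwargs = {"residual_pool": True, "num_classes": 400}
residual_pool = kwargs.pop("residual_pool", False)  # True here; False when the key is absent
print(residual_pool, kwargs)                        # True {'num_classes': 400}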
@@ -462,82 +465,34 @@ def _mvitv2(
     return model


-class MViT_V2_T_Weights(WeightsEnum):
+class MViT_V1_B_Weights(WeightsEnum):
     pass


-class MViT_V2_S_Weights(WeightsEnum):
-    pass
-
-
-class MViT_V2_B_Weights(WeightsEnum):
-    pass
-
-
-def mvit_v2_t(*, weights: Optional[MViT_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2:
-    """
-    Constructs a tiny MViTV2 architecture from
-    `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection
-    <https://arxiv.org/abs/2112.01526>`__ and `Multiscale Vision Transformers
-    <https://arxiv.org/abs/2104.11227>`__.
-
-    Args:
-        weights (:class:`~torchvision.models.video.MViT_V2_T_Weights`, optional): The
-            pretrained weights to use. See
-            :class:`~torchvision.models.video.MViT_V2_T_Weights` below for
-            more details, and possible values. By default, no pre-trained
-            weights are used.
-        progress (bool, optional): If True, displays a progress bar of the
-            download to stderr. Default is True.
-        **kwargs: parameters passed to the ``torchvision.models.video.MViTV2``
-            base class. Please refer to the `source code
-            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvitv2.py>`_
-            for more details about this class.
-
-    .. autoclass:: torchvision.models.video.MViT_V2_T_Weights
-        :members:
-    """
-    weights = MViT_V2_T_Weights.verify(weights)
-
-    return _mvitv2(
-        spatial_size=(224, 224),
-        temporal_size=16,
-        embed_channels=[96, 192, 384, 768],
-        blocks=[1, 2, 5, 2],
-        heads=[1, 2, 4, 8],
-        stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.1),
-        weights=weights,
-        progress=progress,
-        **kwargs,
-    )
-
-
-def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2:
+def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
     """
-    Constructs a small MViTV2 architecture from
-    `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection
-    <https://arxiv.org/abs/2112.01526>`__ and `Multiscale Vision Transformers
-    <https://arxiv.org/abs/2104.11227>`__.
+    Constructs a base MViT-B architecture from
+    `Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__.

     Args:
-        weights (:class:`~torchvision.models.video.MViT_V2_S_Weights`, optional): The
+        weights (:class:`~torchvision.models.video.MViT_V1_B_Weights`, optional): The
             pretrained weights to use. See
-            :class:`~torchvision.models.video.MViT_V2_S_Weights` below for
+            :class:`~torchvision.models.video.MViT_V1_B_Weights` below for
             more details, and possible values. By default, no pre-trained
             weights are used.
         progress (bool, optional): If True, displays a progress bar of the
             download to stderr. Default is True.
-        **kwargs: parameters passed to the ``torchvision.models.video.MViTV2``
+        **kwargs: parameters passed to the ``torchvision.models.video.MViT``
             base class. Please refer to the `source code
-            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvitv2.py>`_
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_
             for more details about this class.

-    .. autoclass:: torchvision.models.video.MViT_V2_S_Weights
+    .. autoclass:: torchvision.models.video.MViT_V1_B_Weights
         :members:
     """
-    weights = MViT_V2_S_Weights.verify(weights)
+    weights = MViT_V1_B_Weights.verify(weights)

-    return _mvitv2(
+    return _mvit(
         spatial_size=(224, 224),
         temporal_size=16,
         embed_channels=[96, 192, 384, 768],
@@ -548,41 +503,3 @@ def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = T
         progress=progress,
         **kwargs,
     )
-
-
-def mvit_v2_b(*, weights: Optional[MViT_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2:
-    """
-    Constructs a base MViTV2 architecture from
-    `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection
-    <https://arxiv.org/abs/2112.01526>`__ and `Multiscale Vision Transformers
-    <https://arxiv.org/abs/2104.11227>`__.
-
-    Args:
-        weights (:class:`~torchvision.models.video.MViT_V2_B_Weights`, optional): The
-            pretrained weights to use. See
-            :class:`~torchvision.models.video.MViT_V2_B_Weights` below for
-            more details, and possible values. By default, no pre-trained
-            weights are used.
-        progress (bool, optional): If True, displays a progress bar of the
-            download to stderr. Default is True.
-        **kwargs: parameters passed to the ``torchvision.models.video.MViTV2``
-            base class. Please refer to the `source code
-            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvitv2.py>`_
-            for more details about this class.
-
-    .. autoclass:: torchvision.models.video.MViT_V2_B_Weights
-        :members:
-    """
-    weights = MViT_V2_B_Weights.verify(weights)
-
-    return _mvitv2(
-        spatial_size=(224, 224),
-        temporal_size=32,
-        embed_channels=[96, 192, 384, 768],
-        blocks=[2, 3, 16, 3],
-        heads=[1, 2, 4, 8],
-        stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.3),
-        weights=weights,
-        progress=progress,
-        **kwargs,
-    )
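For reference, a minimal usage sketch of the renamed builder under this revision (assuming the function is importable as below; per the TODOs above there are no pretrained weights yet, so `weights` stays `None`):

import torch
from torchvision.models.video import mvit_v1_b  # name taken from the new __all__

model = mvit_v1_b()  # weights=None: no pretrained weights yet, per the "add weights" TODO
model.eval()

# One clip: (batch, channels, frames, height, width), matching the builder
# defaults of temporal_size=16 and spatial_size=(224, 224).
clip = torch.rand(1, 3, 16, 224, 224)
with torch.no_grad():
    logits = model(clip)
print(logits.shape)  # (1, num_classes)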