Skip to content

Commit f77bab1

Browse files
Joao GomesYosuaMichael
authored andcommitted
[fbsync] Add Video SwinTransformer (#6521)
Summary: * Just start adding mere copy paste * Replace d with t and D with T * Align swin transformer video to image a bit * Rename d -> t * align with 2d impl * align with 2d impl * Add helpful comments and config for 3d * add docs * Add docs * Add configurations * Add docs * Fix bugs * Fix wrong edit * Fix wrong edit * Fix bugs * Fix bugs * Fix as per fx suggestions * Update torchvision/models/video/swin_transformer.py * Fix as per fx suggestions * Fix expect files and code * Update the expect files * Modify video swin * Add min size and min temporal size, num params * Add flops and size * Fix types * Fix url recipe Reviewed By: YosuaMichael Differential Revision: D41376277 fbshipit-source-id: 00ec3c40b12dff7d6404af7c327e6fc209fc6618 Co-authored-by: Yosua Michael Maranatha <[email protected]>
1 parent 8723e6c commit f77bab1

File tree

9 files changed

+773
-2
lines changed

9 files changed

+773
-2
lines changed

docs/source/models.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,7 @@ pre-trained weights:
518518
models/video_mvit
519519
models/video_resnet
520520
models/video_s3d
521+
models/video_swin_transformer
521522

522523
|
523524

docs/source/models/swin_transformer.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Model builders
1515
--------------
1616

1717
The following model builders can be used to instantiate an SwinTransformer model (original and V2) with and without pre-trained weights.
18-
All the model builders internally rely on the ``torchvision.models.swin_transformer.SwinTransformer``
18+
All the model builders internally rely on the ``torchvision.models.swin_transformer.SwinTransformer``
1919
base class. Please refer to the `source code
2020
<https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_ for
2121
more details about this class.
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
Video SwinTransformer
2+
=====================
3+
4+
.. currentmodule:: torchvision.models.video
5+
6+
The Video SwinTransformer model is based on the `Video Swin Transformer <https://arxiv.org/abs/2106.13230>`__ paper.
7+
8+
.. betastatus:: video module
9+
10+
11+
Model builders
12+
--------------
13+
14+
The following model builders can be used to instantiate a VideoResNet model, with or
15+
without pre-trained weights. All the model builders internally rely on the
16+
``torchvision.models.video.swin_transformer.SwinTransformer3d`` base class. Please refer to the `source
17+
code
18+
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/swin_transformer.py>`_ for
19+
more details about this class.
20+
21+
.. autosummary::
22+
:toctree: generated/
23+
:template: function.rst
24+
25+
swin3d_t
26+
swin3d_s
27+
swin3d_b
1.05 KB
Binary file not shown.
1.05 KB
Binary file not shown.
1.05 KB
Binary file not shown.

torchvision/models/swin_transformer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,8 @@ def __init__(
494494
)
495495

496496
def forward(self, x: Tensor):
497+
# Here is the difference, we apply norm after the attention in V2.
498+
# In V1 we applied norm before the attention.
497499
x = x + self.stochastic_depth(self.norm1(self.attn(x)))
498500
x = x + self.stochastic_depth(self.norm2(self.mlp(x)))
499501
return x
@@ -587,7 +589,7 @@ def __init__(
587589

588590
num_features = embed_dim * 2 ** (len(depths) - 1)
589591
self.norm = norm_layer(num_features)
590-
self.permute = Permute([0, 3, 1, 2])
592+
self.permute = Permute([0, 3, 1, 2]) # B H W C -> B C H W
591593
self.avgpool = nn.AdaptiveAvgPool2d(1)
592594
self.flatten = nn.Flatten(1)
593595
self.head = nn.Linear(num_features, num_classes)

torchvision/models/video/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .mvit import *
22
from .resnet import *
33
from .s3d import *
4+
from .swin_transformer import *

0 commit comments

Comments
 (0)