
Commit 61c9894

Switch implementation to v1 variant.
1 parent 69095dd commit 61c9894

8 files changed, +36 -128 lines changed

docs/source/models.rst

Lines changed: 1 addition & 1 deletion
@@ -459,7 +459,7 @@ pre-trained weights:
 .. toctree::
     :maxdepth: 1

-    models/video_mvitv2
+    models/video_mvit
     models/video_resnet

 |

docs/source/models/video_mvitv2.rst renamed to docs/source/models/video_mvit.rst

Lines changed: 4 additions & 6 deletions
@@ -12,17 +12,15 @@ The MViT V2 model is based on the
 Model builders
 --------------

-The following model builders can be used to instantiate a MViTV2 model, with or
+The following model builders can be used to instantiate a MViT model, with or
 without pre-trained weights. All the model builders internally rely on the
-``torchvision.models.video.MViTV2`` base class. Please refer to the `source
+``torchvision.models.video.MViT`` base class. Please refer to the `source
 code
-<https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvitv2.py>`_ for
+<https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_ for
 more details about this class.

 .. autosummary::
     :toctree: generated/
     :template: function.rst

-    mvit_v2_t
-    mvit_v2_s
-    mvit_v2_b
+    mvit_v1_b
-939 Bytes
Binary file not shown.
-939 Bytes
Binary file not shown.

test/test_models.py

Lines changed: 1 addition & 8 deletions
@@ -309,15 +309,9 @@ def _check_input_backprop(model, inputs):
         "image_size": 56,
         "input_shape": (1, 3, 56, 56),
     },
-    "mvit_v2_t": {
+    "mvit_v1_b": {
         "input_shape": (1, 3, 16, 224, 224),
     },
-    "mvit_v2_s": {
-        "input_shape": (1, 3, 16, 224, 224),
-    },
-    "mvit_v2_b": {
-        "input_shape": (1, 3, 32, 224, 224),
-    },
 }
 # speeding up slow models:
 slow_models = [
@@ -347,7 +341,6 @@ def _check_input_backprop(model, inputs):
 skipped_big_models = {
     "vit_h_14",
     "regnet_y_128gf",
-    "mvit_v2_b",
 }

 # The following contains configuration and expected values to be used tests that are model specific

torchvision/models/video/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
-from .mvitv2 import *
+from .mvit import *
 from .resnet import *

torchvision/models/video/mvitv2.py renamed to torchvision/models/video/mvit.py

Lines changed: 29 additions & 112 deletions
@@ -13,18 +13,12 @@


 __all__ = [
-    "MViTV2",
-    "MViT_V2_T_Weights",
-    "MViT_V2_S_Weights",
-    "MViT_V2_B_Weights",
-    "mvit_v2_t",
-    "mvit_v2_s",
-    "mvit_v2_b",
+    "MViT",
+    "MViT_V1_B_Weights",
+    "mvit_v1_b",
 ]


-# TODO: check if we should implement relative pos embedding (Section 4.1 in the paper). Ref:
-# https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py#L45
 # TODO: add weights
 # TODO: test on references

@@ -108,6 +102,7 @@ def __init__(
         kernel_kv: List[int],
         stride_q: List[int],
         stride_kv: List[int],
+        residual_pool: bool,
         dropout: float = 0.0,
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
     ) -> None:
@@ -116,6 +111,7 @@ def __init__(
         self.num_heads = num_heads
         self.head_dim = embed_dim // num_heads
         self.scaler = 1.0 / math.sqrt(self.head_dim)
+        self.residual_pool = residual_pool

         self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
         layers: List[nn.Module] = [nn.Linear(embed_dim, embed_dim)]
@@ -182,7 +178,9 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten
         attn = torch.matmul(self.scaler * q, k.transpose(2, 3))
         attn = attn.softmax(dim=-1)

-        x = torch.matmul(attn, v).add_(q)
+        x = torch.matmul(attn, v)
+        if self.residual_pool:
+            x.add_(q)
         x = x.transpose(1, 2).reshape(B, -1, C)
         x = self.project(x)

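The hunk above is the core behavioural change: the MViTv2-style residual pooling connection becomes an opt-in flag instead of being hard-coded via ``.add_(q)``. A minimal standalone sketch of what the flag toggles (the helper name and tensor shapes below are illustrative only, not part of the library):

import torch

def pooled_attention_output(attn: torch.Tensor, v: torch.Tensor, q: torch.Tensor, residual_pool: bool) -> torch.Tensor:
    # attn: (B, heads, N, N) softmax weights; q, v: (B, heads, N, head_dim) pooled query/value tensors.
    x = torch.matmul(attn, v)
    if residual_pool:
        # MViTv2-style residual pooling connection: add the pooled queries back in.
        x = x + q
    return x

# residual_pool=False reproduces the plain MViT (v1) attention output this commit switches to;
# residual_pool=True keeps the v2 behaviour that was previously always applied.
B, H, N, D = 1, 2, 8, 16
attn = torch.softmax(torch.randn(B, H, N, N), dim=-1)
q, v = torch.randn(B, H, N, D), torch.randn(B, H, N, D)
print(pooled_attention_output(attn, v, q, residual_pool=False).shape)  # torch.Size([1, 2, 8, 16])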
@@ -199,6 +197,7 @@ def __init__(
         kernel_kv: List[int],
         stride_q: List[int],
         stride_kv: List[int],
+        residual_pool: bool,
         dropout: float = 0.0,
         stochastic_depth_prob: float = 0.0,
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
@@ -224,6 +223,7 @@ def __init__(
             kernel_kv=kernel_kv,
             stride_q=stride_q,
             stride_kv=stride_kv,
+            residual_pool=residual_pool,
             dropout=dropout,
             norm_layer=norm_layer,
         )
@@ -274,7 +274,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return torch.cat((class_token, x), dim=1).add_(pos_embedding)


-class MViTV2(nn.Module):
+class MViT(nn.Module):
     def __init__(
         self,
         spatial_size: Tuple[int, int],
@@ -285,6 +285,7 @@ def __init__(
         pool_kv_stride: List[int],
         pool_q_stride: List[int],
         pool_kvq_kernel: List[int],
+        residual_pool: bool,
         dropout: float = 0.0,
         attention_dropout: float = 0.0,
         stochastic_depth_prob: float = 0.0,
@@ -293,7 +294,7 @@ def __init__(
         norm_layer: Optional[Callable[..., nn.Module]] = None,
     ) -> None:
         """
-        MViT V2 main class.
+        MViT main class.

         Args:
             spatial_size (tuple of ints): The spacial size of the input as ``(H, W)``.
@@ -374,6 +375,7 @@ def __init__(
                 kernel_kv=pool_kvq_kernel,
                 stride_q=stride_q,
                 stride_kv=stride_kv,
+                residual_pool=residual_pool,
                 dropout=attention_dropout,
                 stochastic_depth_prob=sd_prob,
                 norm_layer=norm_layer,
@@ -426,15 +428,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


-def _mvitv2(
+def _mvit(
     embed_channels: List[int],
     blocks: List[int],
     heads: List[int],
     stochastic_depth_prob: float,
     weights: Optional[WeightsEnum],
     progress: bool,
     **kwargs: Any,
-) -> MViTV2:
+) -> MViT:
     if weights is not None:
         _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
         assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
@@ -443,7 +445,7 @@ def _mvitv2(
     spatial_size = kwargs.pop("spatial_size", (224, 224))
     temporal_size = kwargs.pop("temporal_size", 16)

-    model = MViTV2(
+    model = MViT(
         spatial_size=spatial_size,
         temporal_size=temporal_size,
         embed_channels=embed_channels,
@@ -452,6 +454,7 @@ def _mvitv2(
         pool_kv_stride=kwargs.pop("pool_kv_stride", [1, 8, 8]),
         pool_q_stride=kwargs.pop("pool_q_stride", [1, 2, 2]),
         pool_kvq_kernel=kwargs.pop("pool_kvq_kernel", [3, 3, 3]),
+        residual_pool=kwargs.pop("residual_pool", False),
         stochastic_depth_prob=stochastic_depth_prob,
         **kwargs,
     )
@@ -462,82 +465,34 @@ def _mvitv2(
     return model


-class MViT_V2_T_Weights(WeightsEnum):
+class MViT_V1_B_Weights(WeightsEnum):
     pass


-class MViT_V2_S_Weights(WeightsEnum):
-    pass
-
-
-class MViT_V2_B_Weights(WeightsEnum):
-    pass
-
-
-def mvit_v2_t(*, weights: Optional[MViT_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2:
-    """
-    Constructs a tiny MViTV2 architecture from
-    `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection
-    <https://arxiv.org/abs/2112.01526>`__ and `Multiscale Vision Transformers
-    <https://arxiv.org/abs/2104.11227>`__.
-
-    Args:
-        weights (:class:`~torchvision.models.video.MViT_V2_T_Weights`, optional): The
-            pretrained weights to use. See
-            :class:`~torchvision.models.video.MViT_V2_T_Weights` below for
-            more details, and possible values. By default, no pre-trained
-            weights are used.
-        progress (bool, optional): If True, displays a progress bar of the
-            download to stderr. Default is True.
-        **kwargs: parameters passed to the ``torchvision.models.video.MViTV2``
-            base class. Please refer to the `source code
-            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvitv2.py>`_
-            for more details about this class.
-
-    .. autoclass:: torchvision.models.video.MViT_V2_T_Weights
-        :members:
-    """
-    weights = MViT_V2_T_Weights.verify(weights)
-
-    return _mvitv2(
-        spatial_size=(224, 224),
-        temporal_size=16,
-        embed_channels=[96, 192, 384, 768],
-        blocks=[1, 2, 5, 2],
-        heads=[1, 2, 4, 8],
-        stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.1),
-        weights=weights,
-        progress=progress,
-        **kwargs,
-    )
-
-
-def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2:
+def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
     """
-    Constructs a small MViTV2 architecture from
-    `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection
-    <https://arxiv.org/abs/2112.01526>`__ and `Multiscale Vision Transformers
-    <https://arxiv.org/abs/2104.11227>`__.
+    Constructs a base MViT-B architecture from
+    `Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__.

     Args:
-        weights (:class:`~torchvision.models.video.MViT_V2_S_Weights`, optional): The
+        weights (:class:`~torchvision.models.video.MViT_V1_B_Weights`, optional): The
             pretrained weights to use. See
-            :class:`~torchvision.models.video.MViT_V2_S_Weights` below for
+            :class:`~torchvision.models.video.MViT_V1_B_Weights` below for
             more details, and possible values. By default, no pre-trained
             weights are used.
         progress (bool, optional): If True, displays a progress bar of the
             download to stderr. Default is True.
-        **kwargs: parameters passed to the ``torchvision.models.video.MViTV2``
+        **kwargs: parameters passed to the ``torchvision.models.video.MViT``
             base class. Please refer to the `source code
-            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvitv2.py>`_
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_
             for more details about this class.

-    .. autoclass:: torchvision.models.video.MViT_V2_S_Weights
+    .. autoclass:: torchvision.models.video.MViT_V1_B_Weights
         :members:
     """
-    weights = MViT_V2_S_Weights.verify(weights)
+    weights = MViT_V1_B_Weights.verify(weights)

-    return _mvitv2(
+    return _mvit(
         spatial_size=(224, 224),
         temporal_size=16,
         embed_channels=[96, 192, 384, 768],
@@ -548,41 +503,3 @@ def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = T
         progress=progress,
         **kwargs,
     )
-
-
-def mvit_v2_b(*, weights: Optional[MViT_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2:
-    """
-    Constructs a base MViTV2 architecture from
-    `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection
-    <https://arxiv.org/abs/2112.01526>`__ and `Multiscale Vision Transformers
-    <https://arxiv.org/abs/2104.11227>`__.
-
-    Args:
-        weights (:class:`~torchvision.models.video.MViT_V2_B_Weights`, optional): The
-            pretrained weights to use. See
-            :class:`~torchvision.models.video.MViT_V2_B_Weights` below for
-            more details, and possible values. By default, no pre-trained
-            weights are used.
-        progress (bool, optional): If True, displays a progress bar of the
-            download to stderr. Default is True.
-        **kwargs: parameters passed to the ``torchvision.models.video.MViTV2``
-            base class. Please refer to the `source code
-            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvitv2.py>`_
-            for more details about this class.
-
-    .. autoclass:: torchvision.models.video.MViT_V2_B_Weights
-        :members:
-    """
-    weights = MViT_V2_B_Weights.verify(weights)
-
-    return _mvitv2(
-        spatial_size=(224, 224),
-        temporal_size=32,
-        embed_channels=[96, 192, 384, 768],
-        blocks=[2, 3, 16, 3],
-        heads=[1, 2, 4, 8],
-        stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.3),
-        weights=weights,
-        progress=progress,
-        **kwargs,
-    )
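For context, a minimal usage sketch of the renamed builder, assuming a torchvision build that includes this commit. No pre-trained weights are published yet (see the "add weights" TODO), so the model is randomly initialized; the input shape follows the updated test configuration:

import torch
from torchvision.models.video import mvit_v1_b

model = mvit_v1_b()  # MViT-B, 16-frame variant by default
model.eval()

# (batch, channels, frames, height, width), matching "input_shape" in test_models.py.
video = torch.rand(1, 3, 16, 224, 224)
with torch.no_grad():
    logits = model(video)
print(logits.shape)  # (1, num_classes) class scores for the clip

Because ``_mvit`` pops ``residual_pool`` from ``**kwargs`` with a default of ``False``, the v2-style residual pooling could presumably be re-enabled with ``mvit_v1_b(residual_pool=True)``.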
