
Commit 3e7683f

Add support to MViT v1 (#6179)

* Switch implementation to the v1 variant.
* Fix docs.
* Add back a v2 pseudovariant.
* Change the way the networks are configured.
* Temporarily remove v2.
* Add weights.
* Expand _squeeze/_unsqueeze to support arbitrary dims.
* Update the references script.
* Fix tests.
* Fix frames and preprocessing.
* Fix std/mean values in transforms.
* Add permanent Dropout and update the weights.
* Update accuracies.

1 parent: c603159

9 files changed: +286 −240 lines


docs/source/models.rst
Lines changed: 1 addition & 1 deletion

@@ -465,7 +465,7 @@ pre-trained weights:
 .. toctree::
    :maxdepth: 1
 
-   models/video_mvitv2
+   models/video_mvit
    models/video_resnet
 
 |

docs/source/models/video_mvitv2.rst renamed to docs/source/models/video_mvit.rst
Lines changed: 4 additions & 6 deletions

@@ -12,17 +12,15 @@ The MViT V2 model is based on the
 Model builders
 --------------
 
-The following model builders can be used to instantiate a MViTV2 model, with or
+The following model builders can be used to instantiate a MViT model, with or
 without pre-trained weights. All the model builders internally rely on the
-``torchvision.models.video.MViTV2`` base class. Please refer to the `source
+``torchvision.models.video.MViT`` base class. Please refer to the `source
 code
-<https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvitv2.py>`_ for
+<https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_ for
 more details about this class.
 
 .. autosummary::
    :toctree: generated/
    :template: function.rst
 
-   mvit_v2_t
-   mvit_v2_s
-   mvit_v2_b
+   mvit_v1_b
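
With the rename, mvit_v1_b is now the only video MViT builder. A minimal usage sketch; the MViT_V1_B_Weights enum and its KINETICS400_V1 entry are assumed here from the "Add weights" bullet in the commit message and torchvision's usual <builder>_Weights naming convention:

    import torch
    from torchvision.models.video import mvit_v1_b, MViT_V1_B_Weights

    # Build MViT-B (v1) with the Kinetics-400 weights added in this commit.
    # MViT_V1_B_Weights.KINETICS400_V1 is an assumed enum name.
    weights = MViT_V1_B_Weights.KINETICS400_V1
    model = mvit_v1_b(weights=weights).eval()

    # weights.transforms() bundles the preprocessing preset, including the
    # std/mean values fixed by this commit.
    preprocess = weights.transforms()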

references/video_classification/train.py
Lines changed: 3 additions & 2 deletions

@@ -152,7 +152,7 @@ def main(args):
         split="train",
         step_between_clips=1,
         transform=transform_train,
-        frame_rate=15,
+        frame_rate=args.frame_rate,
         extensions=(
             "avi",
             "mp4",
@@ -189,7 +189,7 @@ def main(args):
         split="val",
         step_between_clips=1,
         transform=transform_test,
-        frame_rate=15,
+        frame_rate=args.frame_rate,
         extensions=(
             "avi",
             "mp4",
@@ -324,6 +324,7 @@ def parse_args():
     parser.add_argument("--model", default="r2plus1d_18", type=str, help="model name")
     parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
     parser.add_argument("--clip-len", default=16, type=int, metavar="N", help="number of frames per clip")
+    parser.add_argument("--frame-rate", default=15, type=int, metavar="N", help="the frame rate")
     parser.add_argument(
         "--clips-per-video", default=5, type=int, metavar="N", help="maximum number of clips per video to consider"
     )
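
The previously hard-coded frame_rate=15 is now exposed as --frame-rate. A minimal sketch of how the flag reaches the dataset (the Kinetics call is abbreviated to a comment):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--clip-len", default=16, type=int, metavar="N", help="number of frames per clip")
    parser.add_argument("--frame-rate", default=15, type=int, metavar="N", help="the frame rate")
    args = parser.parse_args(["--frame-rate", "7"])

    # train.py forwards the value to the dataset, roughly:
    #   torchvision.datasets.Kinetics(..., frame_rate=args.frame_rate)
    # so each args.clip_len-frame clip is resampled to 7 fps here instead of
    # the fixed 15 fps used before this commit.
    print(args.clip_len, args.frame_rate)  # 16 7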
(binary file changed: −939 bytes, not shown)

test/test_extended_models.py
Lines changed: 3 additions & 1 deletion

@@ -181,7 +181,7 @@ def test_transforms_jit(model_fn):
         "input_shape": (1, 3, 520, 520),
     },
     "video": {
-        "input_shape": (1, 4, 3, 112, 112),
+        "input_shape": (1, 3, 4, 112, 112),
     },
     "optical_flow": {
         "input_shape": (1, 3, 128, 128),
@@ -195,6 +195,8 @@ def test_transforms_jit(model_fn):
     if module_name == "optical_flow":
         args = (x, x)
     else:
+        if module_name == "video":
+            x = x.permute(0, 2, 1, 3, 4)
         args = (x,)
 
     problematic_weights = []
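
Both changes reflect the two clip layouts in play: the video models consume (B, C, T, H, W) tensors, while the video transform presets expect time before channels. A small sketch of the axis swap the test now performs:

    import torch

    # Model layout: batch, channels, time, height, width.
    model_layout = torch.rand(1, 3, 4, 112, 112)

    # Preset layout: batch, time, channels, height, width.
    # permute(0, 2, 1, 3, 4) swaps the C and T axes.
    preset_layout = model_layout.permute(0, 2, 1, 3, 4)
    print(preset_layout.shape)  # torch.Size([1, 4, 3, 112, 112])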

test/test_models.py
Lines changed: 1 addition & 8 deletions

@@ -309,15 +309,9 @@ def _check_input_backprop(model, inputs):
         "image_size": 56,
         "input_shape": (1, 3, 56, 56),
     },
-    "mvit_v2_t": {
+    "mvit_v1_b": {
         "input_shape": (1, 3, 16, 224, 224),
     },
-    "mvit_v2_s": {
-        "input_shape": (1, 3, 16, 224, 224),
-    },
-    "mvit_v2_b": {
-        "input_shape": (1, 3, 32, 224, 224),
-    },
 }
 # speeding up slow models:
 slow_models = [
@@ -347,7 +341,6 @@ def _check_input_backprop(model, inputs):
 skipped_big_models = {
     "vit_h_14",
     "regnet_y_128gf",
-    "mvit_v2_b",
 }
 
 # The following contains configuration and expected values to be used tests that are model specific
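
A sketch of what the updated test config exercises: a randomly initialised mvit_v1_b must run a forward pass on a 16-frame 224×224 clip. The head size of 400 assumed in the comment below is the usual Kinetics-400 default for torchvision video models:

    import torch
    from torchvision.models.video import mvit_v1_b

    model = mvit_v1_b()  # random init; no weights needed for the smoke test
    clip = torch.rand(1, 3, 16, 224, 224)  # the input_shape from the config above
    out = model(clip)
    print(out.shape)  # torch.Size([1, 400])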

torchvision/models/video/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -1,2 +1,2 @@
-from .mvitv2 import *
+from .mvit import *
 from .resnet import *
