Skip to content

Commit 78c6805

Browse files
Vincent Moens and facebook-github-bot
authored and committed
[fbsync] Add raft builders and presets in prototypes (#5043)
Summary: Co-authored-by: Vasilis Vryniotis <[email protected]> Reviewed By: NicolasHug Differential Revision: D32950935 fbshipit-source-id: b936d2879243f7d5b8e05f329ebdfd074f5af05b
1 parent 683167f commit 78c6805

File tree

7 files changed

+224
-7
lines changed

7 files changed

+224
-7
lines changed

test/test_prototype_models.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
import test_models as TM
55
import torch
6-
from common_utils import cpu_and_gpu, run_on_env_var
6+
from common_utils import cpu_and_gpu, run_on_env_var, needs_cuda
77
from torchvision.prototype import models
88
from torchvision.prototype.models._api import WeightsEnum, Weights
99
from torchvision.prototype.models._utils import handle_legacy_interface
@@ -75,10 +75,12 @@ def test_get_weight(name, weight):
7575
+ TM.get_models_from_module(models.detection)
7676
+ TM.get_models_from_module(models.quantization)
7777
+ TM.get_models_from_module(models.segmentation)
78-
+ TM.get_models_from_module(models.video),
78+
+ TM.get_models_from_module(models.video)
79+
+ TM.get_models_from_module(models.optical_flow),
7980
)
8081
def test_naming_conventions(model_fn):
8182
weights_enum = _get_model_weights(model_fn)
83+
print(weights_enum)
8284
assert weights_enum is not None
8385
assert len(weights_enum) == 0 or hasattr(weights_enum, "default")
8486

@@ -149,13 +151,22 @@ def test_video_model(model_fn, dev):
149151
TM.test_video_model(model_fn, dev)
150152

151153

154+
@needs_cuda
155+
@pytest.mark.parametrize("model_builder", TM.get_models_from_module(models.optical_flow))
156+
@pytest.mark.parametrize("scripted", (False, True))
157+
@run_if_test_with_prototype
158+
def test_raft(model_builder, scripted):
159+
TM.test_raft(model_builder, scripted)
160+
161+
152162
@pytest.mark.parametrize(
153163
"model_fn",
154164
TM.get_models_from_module(models)
155165
+ TM.get_models_from_module(models.detection)
156166
+ TM.get_models_from_module(models.quantization)
157167
+ TM.get_models_from_module(models.segmentation)
158-
+ TM.get_models_from_module(models.video),
168+
+ TM.get_models_from_module(models.video)
169+
+ TM.get_models_from_module(models.optical_flow),
159170
)
160171
@pytest.mark.parametrize("dev", cpu_and_gpu())
161172
@run_if_test_with_prototype
@@ -177,6 +188,9 @@ def test_old_vs_new_factory(model_fn, dev):
177188
"video": {
178189
"input_shape": (1, 3, 4, 112, 112),
179190
},
191+
"optical_flow": {
192+
"input_shape": (1, 3, 128, 128),
193+
},
180194
}
181195
model_name = model_fn.__name__
182196
module_name = model_fn.__module__.split(".")[-2]

torchvision/models/optical_flow/raft.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
585585
"""
586586

587587
if pretrained:
588-
raise ValueError("Pretrained weights aren't available yet")
588+
raise ValueError("No checkpoint is available for raft_large")
589589

590590
return _raft(
591591
# Feature encoder
@@ -631,7 +631,7 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
631631
"""
632632

633633
if pretrained:
634-
raise ValueError("Pretrained weights aren't available yet")
634+
raise ValueError("No checkpoint is available for raft_small")
635635

636636
return _raft(
637637
# Feature encoder

torchvision/prototype/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .vgg import *
1313
from .vision_transformer import *
1414
from . import detection
15+
from . import optical_flow
1516
from . import quantization
1617
from . import segmentation
1718
from . import video
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .raft import RAFT, raft_large, raft_small, Raft_Large_Weights, Raft_Small_Weights
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
from typing import Optional
2+
3+
from torch.nn.modules.batchnorm import BatchNorm2d
4+
from torch.nn.modules.instancenorm import InstanceNorm2d
5+
from torchvision.models.optical_flow import RAFT
6+
from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock
7+
8+
# from torchvision.prototype.transforms import RaftEval
9+
10+
from .._api import WeightsEnum
11+
12+
# from .._api import Weights
13+
from .._utils import handle_legacy_interface
14+
15+
16+
__all__ = (
17+
"RAFT",
18+
"raft_large",
19+
"raft_small",
20+
"Raft_Large_Weights",
21+
"Raft_Small_Weights",
22+
)
23+
24+
25+
class Raft_Large_Weights(WeightsEnum):
26+
pass
27+
# C_T_V1 = Weights(
28+
# # Chairs + Things
29+
# url="",
30+
# transforms=RaftEval,
31+
# meta={
32+
# "recipe": "",
33+
# "epe": -1234,
34+
# },
35+
# )
36+
37+
# C_T_SKHT_V1 = Weights(
38+
# # Chairs + Things + Sintel fine-tuning, i.e.:
39+
# # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean)
40+
# # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel
41+
# url="",
42+
# transforms=RaftEval,
43+
# meta={
44+
# "recipe": "",
45+
# "epe": -1234,
46+
# },
47+
# )
48+
49+
# C_T_SKHT_K_V1 = Weights(
50+
# # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.:
51+
# # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti
52+
# # Same as CT_SKHT with extra fine-tuning on Kitti
53+
# # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti
54+
# url="",
55+
# transforms=RaftEval,
56+
# meta={
57+
# "recipe": "",
58+
# "epe": -1234,
59+
# },
60+
# )
61+
62+
# default = C_T_V1
63+
64+
65+
class Raft_Small_Weights(WeightsEnum):
66+
pass
67+
# C_T_V1 = Weights(
68+
# url="", # TODO
69+
# transforms=RaftEval,
70+
# meta={
71+
# "recipe": "",
72+
# "epe": -1234,
73+
# },
74+
# )
75+
# default = C_T_V1
76+
77+
78+
@handle_legacy_interface(weights=("pretrained", None))
79+
def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
80+
"""RAFT model from
81+
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
82+
83+
Args:
84+
weights(Raft_Large_weights, optinal): TODO not implemented yet
85+
progress (bool): If True, displays a progress bar of the download to stderr
86+
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
87+
to override any default.
88+
89+
Returns:
90+
nn.Module: The model.
91+
"""
92+
93+
weights = Raft_Large_Weights.verify(weights)
94+
95+
return _raft(
96+
# Feature encoder
97+
feature_encoder_layers=(64, 64, 96, 128, 256),
98+
feature_encoder_block=ResidualBlock,
99+
feature_encoder_norm_layer=InstanceNorm2d,
100+
# Context encoder
101+
context_encoder_layers=(64, 64, 96, 128, 256),
102+
context_encoder_block=ResidualBlock,
103+
context_encoder_norm_layer=BatchNorm2d,
104+
# Correlation block
105+
corr_block_num_levels=4,
106+
corr_block_radius=4,
107+
# Motion encoder
108+
motion_encoder_corr_layers=(256, 192),
109+
motion_encoder_flow_layers=(128, 64),
110+
motion_encoder_out_channels=128,
111+
# Recurrent block
112+
recurrent_block_hidden_state_size=128,
113+
recurrent_block_kernel_size=((1, 5), (5, 1)),
114+
recurrent_block_padding=((0, 2), (2, 0)),
115+
# Flow head
116+
flow_head_hidden_size=256,
117+
# Mask predictor
118+
use_mask_predictor=True,
119+
**kwargs,
120+
)
121+
122+
123+
@handle_legacy_interface(weights=("pretrained", None))
124+
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
125+
"""RAFT "small" model from
126+
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
127+
128+
Args:
129+
weights(Raft_Small_weights, optinal): TODO not implemented yet
130+
progress (bool): If True, displays a progress bar of the download to stderr
131+
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
132+
to override any default.
133+
134+
Returns:
135+
nn.Module: The model.
136+
137+
"""
138+
139+
weights = Raft_Small_Weights.verify(weights)
140+
141+
return _raft(
142+
# Feature encoder
143+
feature_encoder_layers=(32, 32, 64, 96, 128),
144+
feature_encoder_block=BottleneckBlock,
145+
feature_encoder_norm_layer=InstanceNorm2d,
146+
# Context encoder
147+
context_encoder_layers=(32, 32, 64, 96, 160),
148+
context_encoder_block=BottleneckBlock,
149+
context_encoder_norm_layer=None,
150+
# Correlation block
151+
corr_block_num_levels=4,
152+
corr_block_radius=3,
153+
# Motion encoder
154+
motion_encoder_corr_layers=(96,),
155+
motion_encoder_flow_layers=(64, 32),
156+
motion_encoder_out_channels=82,
157+
# Recurrent block
158+
recurrent_block_hidden_state_size=96,
159+
recurrent_block_kernel_size=(3,),
160+
recurrent_block_padding=(1,),
161+
# Flow head
162+
flow_head_hidden_size=128,
163+
# Mask predictor
164+
use_mask_predictor=False,
165+
**kwargs,
166+
)

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33

44
from ._geometry import Resize, RandomResize, HorizontalFlip, Crop, CenterCrop, RandomCrop
55
from ._misc import Identity, Normalize
6-
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval
6+
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval

torchvision/prototype/transforms/_presets.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from ...transforms import functional as F, InterpolationMode
77

88

9-
__all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval"]
9+
__all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval", "RaftEval"]
1010

1111

1212
class CocoEval(nn.Module):
@@ -97,3 +97,38 @@ def forward(self, img: Tensor, target: Optional[Tensor] = None) -> Tuple[Tensor,
9797
target = F.pil_to_tensor(target)
9898
target = target.squeeze(0).to(torch.int64)
9999
return img, target
100+
101+
102+
class RaftEval(nn.Module):
103+
def forward(
104+
self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor]
105+
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
106+
107+
img1, img2, flow, valid_flow_mask = self._pil_or_numpy_to_tensor(img1, img2, flow, valid_flow_mask)
108+
109+
img1 = F.convert_image_dtype(img1, torch.float32)
110+
img2 = F.convert_image_dtype(img2, torch.float32)
111+
112+
# map [0, 1] into [-1, 1]
113+
img1 = F.normalize(img1, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
114+
img2 = F.normalize(img2, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
115+
116+
img1 = img1.contiguous()
117+
img2 = img2.contiguous()
118+
119+
return img1, img2, flow, valid_flow_mask
120+
121+
def _pil_or_numpy_to_tensor(
122+
self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor]
123+
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
124+
if not isinstance(img1, Tensor):
125+
img1 = F.pil_to_tensor(img1)
126+
if not isinstance(img2, Tensor):
127+
img2 = F.pil_to_tensor(img2)
128+
129+
if flow is not None and not isinstance(flow, Tensor):
130+
flow = torch.from_numpy(flow)
131+
if valid_flow_mask is not None and not isinstance(valid_flow_mask, Tensor):
132+
valid_flow_mask = torch.from_numpy(valid_flow_mask)
133+
134+
return img1, img2, flow, valid_flow_mask

0 commit comments

Comments (0)