Skip to content

Commit 496fb36

Browse files
prabhat00155facebook-github-bot
authored andcommitted
[fbsync] Adding multiweight support to FasterRCNN (#4847)
Summary: * Aligning exception with all other models. * Adding prototype preprocessing on video references. * Adding the rest of model builders on faster_rcnn. Reviewed By: kazhang Differential Revision: D32216664 fbshipit-source-id: 6998018eebe4381b9bff9e9290061d35e42faa21
1 parent 274b271 commit 496fb36

File tree

4 files changed

+191
-14
lines changed

4 files changed

+191
-14
lines changed

references/detection/train.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@
3333
from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
3434

3535

36+
try:
37+
from torchvision.prototype import models as PM
38+
except ImportError:
39+
PM = None
40+
41+
3642
def get_dataset(name, image_set, transform, data_path):
3743
paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)}
3844
p, ds_fn, num_classes = paths[name]
@@ -41,8 +47,15 @@ def get_dataset(name, image_set, transform, data_path):
4147
return ds, num_classes
4248

4349

44-
def get_transform(train, data_augmentation):
45-
return presets.DetectionPresetTrain(data_augmentation) if train else presets.DetectionPresetEval()
50+
def get_transform(train, args):
51+
if train:
52+
return presets.DetectionPresetTrain(args.data_augmentation)
53+
elif not args.weights:
54+
return presets.DetectionPresetEval()
55+
else:
56+
fn = PM.detection.__dict__[args.model]
57+
weights = PM._api.get_weight(fn, args.weights)
58+
return weights.transforms()
4659

4760

4861
def get_args_parser(add_help=True):
@@ -128,6 +141,9 @@ def get_args_parser(add_help=True):
128141
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
129142
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
130143

144+
# Prototype models only
145+
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
146+
131147
return parser
132148

133149

@@ -143,10 +159,8 @@ def main(args):
143159
# Data loading code
144160
print("Loading data")
145161

146-
dataset, num_classes = get_dataset(
147-
args.dataset, "train", get_transform(True, args.data_augmentation), args.data_path
148-
)
149-
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args.data_augmentation), args.data_path)
162+
dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path)
163+
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path)
150164

151165
print("Creating data loaders")
152166
if args.distributed:
@@ -175,9 +189,14 @@ def main(args):
175189
if "rcnn" in args.model:
176190
if args.rpn_score_thresh is not None:
177191
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
178-
model = torchvision.models.detection.__dict__[args.model](
179-
num_classes=num_classes, pretrained=args.pretrained, **kwargs
180-
)
192+
if not args.weights:
193+
model = torchvision.models.detection.__dict__[args.model](
194+
pretrained=args.pretrained, num_classes=num_classes, **kwargs
195+
)
196+
else:
197+
if PM is None:
198+
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
199+
model = PM.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs)
181200
model.to(device)
182201
if args.distributed and args.sync_bn:
183202
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

test/test_prototype_models.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ def test_classification_model(model_fn, dev):
4848
TM.test_classification_model(model_fn, dev)
4949

5050

51+
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.detection))
52+
@pytest.mark.parametrize("dev", cpu_and_gpu())
53+
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
54+
def test_detection_model(model_fn, dev):
55+
TM.test_detection_model(model_fn, dev)
56+
57+
5158
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.quantization))
5259
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
5360
def test_quantized_classification_model(model_fn):
@@ -71,6 +78,7 @@ def test_video_model(model_fn, dev):
7178
@pytest.mark.parametrize(
7279
"model_fn, module_name",
7380
get_models_with_module_names(models)
81+
+ get_models_with_module_names(models.detection)
7482
+ get_models_with_module_names(models.quantization)
7583
+ get_models_with_module_names(models.segmentation)
7684
+ get_models_with_module_names(models.video),
@@ -82,6 +90,9 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
8290
"models": {
8391
"input_shape": (1, 3, 224, 224),
8492
},
93+
"detection": {
94+
"input_shape": (3, 300, 300),
95+
},
8596
"quantization": {
8697
"input_shape": (1, 3, 224, 224),
8798
},
@@ -95,7 +106,10 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
95106
model_name = model_fn.__name__
96107
kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})}
97108
input_shape = kwargs.pop("input_shape")
109+
kwargs.pop("num_classes", None) # ignore this as it's an incompatible speed optimization for pre-trained models
98110
x = torch.rand(input_shape).to(device=dev)
111+
if module_name == "detection":
112+
x = [x]
99113

100114
# compare with new model builder parameterized in the old fashion way
101115
model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)

torchvision/models/shufflenetv2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwa
162162
if pretrained:
163163
model_url = model_urls[arch]
164164
if model_url is None:
165-
raise NotImplementedError(f"pretrained {arch} is not supported as of now")
165+
raise ValueError(f"No checkpoint is available for model type {arch}")
166166
else:
167167
state_dict = load_state_dict_from_url(model_url, progress=progress)
168168
model.load_state_dict(state_dict)

torchvision/prototype/models/detection/faster_rcnn.py

Lines changed: 148 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,72 @@
11
import warnings
2-
from typing import Any, Optional
2+
from typing import Any, Optional, Union
33

44
from ....models.detection.faster_rcnn import (
5-
_validate_trainable_layers,
5+
_mobilenet_extractor,
66
_resnet_fpn_extractor,
7+
_validate_trainable_layers,
8+
AnchorGenerator,
79
FasterRCNN,
810
misc_nn_ops,
911
overwrite_eps,
1012
)
1113
from ...transforms.presets import CocoEval
1214
from .._api import Weights, WeightEntry
1315
from .._meta import _COCO_CATEGORIES
16+
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
1417
from ..resnet import ResNet50Weights, resnet50
1518

1619

17-
__all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"]
20+
__all__ = [
21+
"FasterRCNN",
22+
"FasterRCNNResNet50FPNWeights",
23+
"FasterRCNNMobileNetV3LargeFPNWeights",
24+
"FasterRCNNMobileNetV3Large320FPNWeights",
25+
"fasterrcnn_resnet50_fpn",
26+
"fasterrcnn_mobilenet_v3_large_fpn",
27+
"fasterrcnn_mobilenet_v3_large_320_fpn",
28+
]
29+
30+
31+
_common_meta = {"categories": _COCO_CATEGORIES}
1832

1933

2034
class FasterRCNNResNet50FPNWeights(Weights):
2135
Coco_RefV1 = WeightEntry(
2236
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
2337
transforms=CocoEval,
2438
meta={
25-
"categories": _COCO_CATEGORIES,
39+
**_common_meta,
2640
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
2741
"map": 37.0,
2842
},
2943
)
3044

3145

46+
class FasterRCNNMobileNetV3LargeFPNWeights(Weights):
47+
Coco_RefV1 = WeightEntry(
48+
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
49+
transforms=CocoEval,
50+
meta={
51+
**_common_meta,
52+
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
53+
"map": 32.8,
54+
},
55+
)
56+
57+
58+
class FasterRCNNMobileNetV3Large320FPNWeights(Weights):
59+
Coco_RefV1 = WeightEntry(
60+
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
61+
transforms=CocoEval,
62+
meta={
63+
**_common_meta,
64+
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
65+
"map": 22.8,
66+
},
67+
)
68+
69+
3270
def fasterrcnn_resnet50_fpn(
3371
weights: Optional[FasterRCNNResNet50FPNWeights] = None,
3472
weights_backbone: Optional[ResNet50Weights] = None,
@@ -64,3 +102,109 @@ def fasterrcnn_resnet50_fpn(
64102
overwrite_eps(model, 0.0)
65103

66104
return model
105+
106+
107+
def _fasterrcnn_mobilenet_v3_large_fpn(
108+
weights: Optional[Union[FasterRCNNMobileNetV3LargeFPNWeights, FasterRCNNMobileNetV3Large320FPNWeights]] = None,
109+
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
110+
progress: bool = True,
111+
num_classes: int = 91,
112+
trainable_backbone_layers: Optional[int] = None,
113+
**kwargs: Any,
114+
) -> FasterRCNN:
115+
if weights is not None:
116+
weights_backbone = None
117+
num_classes = len(weights.meta["categories"])
118+
119+
trainable_backbone_layers = _validate_trainable_layers(
120+
weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 3
121+
)
122+
123+
backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
124+
backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
125+
anchor_sizes = (
126+
(
127+
32,
128+
64,
129+
128,
130+
256,
131+
512,
132+
),
133+
) * 3
134+
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
135+
model = FasterRCNN(
136+
backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
137+
)
138+
139+
if weights is not None:
140+
model.load_state_dict(weights.state_dict(progress=progress))
141+
142+
return model
143+
144+
145+
def fasterrcnn_mobilenet_v3_large_fpn(
146+
weights: Optional[FasterRCNNMobileNetV3LargeFPNWeights] = None,
147+
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
148+
progress: bool = True,
149+
num_classes: int = 91,
150+
trainable_backbone_layers: Optional[int] = None,
151+
**kwargs: Any,
152+
) -> FasterRCNN:
153+
if "pretrained" in kwargs:
154+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
155+
weights = FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
156+
weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights)
157+
if "pretrained_backbone" in kwargs:
158+
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
159+
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
160+
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
161+
162+
defaults = {
163+
"rpn_score_thresh": 0.05,
164+
}
165+
166+
kwargs = {**defaults, **kwargs}
167+
return _fasterrcnn_mobilenet_v3_large_fpn(
168+
weights,
169+
weights_backbone,
170+
progress,
171+
num_classes,
172+
trainable_backbone_layers,
173+
**kwargs,
174+
)
175+
176+
177+
def fasterrcnn_mobilenet_v3_large_320_fpn(
178+
weights: Optional[FasterRCNNMobileNetV3Large320FPNWeights] = None,
179+
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
180+
progress: bool = True,
181+
num_classes: int = 91,
182+
trainable_backbone_layers: Optional[int] = None,
183+
**kwargs: Any,
184+
) -> FasterRCNN:
185+
if "pretrained" in kwargs:
186+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
187+
weights = FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
188+
weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights)
189+
if "pretrained_backbone" in kwargs:
190+
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
191+
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
192+
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
193+
194+
defaults = {
195+
"min_size": 320,
196+
"max_size": 640,
197+
"rpn_pre_nms_top_n_test": 150,
198+
"rpn_post_nms_top_n_test": 150,
199+
"rpn_score_thresh": 0.05,
200+
}
201+
202+
kwargs = {**defaults, **kwargs}
203+
return _fasterrcnn_mobilenet_v3_large_fpn(
204+
weights,
205+
weights_backbone,
206+
progress,
207+
num_classes,
208+
trainable_backbone_layers,
209+
**kwargs,
210+
)

0 commit comments

Comments
 (0)