Skip to content

Commit 26f5318

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] Add SSDlite architecture with MobileNetV3 backbones (#3757)
Summary: * Partial implementation of SSDlite. * Add normal init and BN hyperparams. * Refactor to keep JIT happy * Completed SSDlite. * Fix lint * Update todos * Add expected file in repo. * Use C4 expansion instead of C4 output. * Change scales formula for Default Boxes. * Add cosine annealing on trainer. * Make T_max count epochs. * Fix test and handle corner-case. * Add support of support width_mult * Add ssdlite presets. * Change ReLU6, [-1,1] rescaling, backbone init & no pretraining. * Use _reduced_tail=True. * Add sync BN support. * Adding the best config along with its weights and documentation. * Make mean/std configurable. * Fix not implemented for half exception Reviewed By: cpuhrsch Differential Revision: D28538769 fbshipit-source-id: df6c2e79b76e6d6297aa51ca0ff4535dc59eaf9b
1 parent 870bebf commit 26f5318

File tree

9 files changed

+277
-7
lines changed

9 files changed

+277
-7
lines changed

docs/source/models.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,7 @@ Faster R-CNN MobileNetV3-Large FPN 32.8 - -
427427
Faster R-CNN MobileNetV3-Large 320 FPN 22.8 - -
428428
RetinaNet ResNet-50 FPN 36.4 - -
429429
SSD VGG16 25.1 - -
430+
SSDlite MobileNetV3-Large 21.3 - -
430431
Mask R-CNN ResNet-50 FPN 37.9 34.6 -
431432
====================================== ======= ======== ===========
432433

@@ -486,6 +487,7 @@ Faster R-CNN MobileNetV3-Large FPN 0.1020 0.0415
486487
Faster R-CNN MobileNetV3-Large 320 FPN 0.0978 0.0376 0.6
487488
RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1
488489
SSD VGG16 0.2093 0.0744 1.5
490+
SSDlite MobileNetV3-Large 0.1773 0.0906 1.5
489491
Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4
490492
Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8
491493
====================================== =================== ================== ===========
@@ -511,6 +513,12 @@ SSD
511513
.. autofunction:: torchvision.models.detection.ssd300_vgg16
512514

513515

516+
SSDlite
517+
------------
518+
519+
.. autofunction:: torchvision.models.detection.ssdlite320_mobilenet_v3_large
520+
521+
514522
Mask R-CNN
515523
----------
516524

references/detection/README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
5656
--weight-decay 0.0005 --data-augmentation ssd
5757
```
5858

59+
### SSDlite MobileNetV3-Large
60+
```
61+
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
62+
--dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\
63+
--aspect-ratio-group-factor 3 --lr-scheduler cosineannealinglr --lr 0.15 --batch-size 24\
64+
--weight-decay 0.00004 --data-augmentation ssdlite
65+
```
66+
5967

6068
### Mask R-CNN
6169
```

references/detection/presets.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123., 117., 104.)):
1616
T.RandomHorizontalFlip(p=hflip_prob),
1717
T.ToTensor(),
1818
])
19+
elif data_augmentation == 'ssdlite':
20+
self.transforms = T.Compose([
21+
T.RandomIoUCrop(),
22+
T.RandomHorizontalFlip(p=hflip_prob),
23+
T.ToTensor(),
24+
])
1925
else:
2026
raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
2127

references/detection/train.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,13 @@ def get_args_parser(add_help=True):
7373
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
7474
metavar='W', help='weight decay (default: 1e-4)',
7575
dest='weight_decay')
76-
parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs')
77-
parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int, help='decrease lr every step-size epochs')
78-
parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
76+
parser.add_argument('--lr-scheduler', default="multisteplr", help='the lr scheduler (default: multisteplr)')
77+
parser.add_argument('--lr-step-size', default=8, type=int,
78+
help='decrease lr every step-size epochs (multisteplr scheduler only)')
79+
parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int,
80+
help='decrease lr every step-size epochs (multisteplr scheduler only)')
81+
parser.add_argument('--lr-gamma', default=0.1, type=float,
82+
help='decrease lr by a factor of lr-gamma (multisteplr scheduler only)')
7983
parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
8084
parser.add_argument('--output-dir', default='.', help='path where to save')
8185
parser.add_argument('--resume', default='', help='resume from checkpoint')
@@ -85,6 +89,12 @@ def get_args_parser(add_help=True):
8589
parser.add_argument('--trainable-backbone-layers', default=None, type=int,
8690
help='number of trainable layers of backbone')
8791
parser.add_argument('--data-augmentation', default="hflip", help='data augmentation policy (default: hflip)')
92+
parser.add_argument(
93+
"--sync-bn",
94+
dest="sync_bn",
95+
help="Use sync batch norm",
96+
action="store_true",
97+
)
8898
parser.add_argument(
8999
"--test-only",
90100
dest="test_only",
@@ -156,6 +166,8 @@ def main(args):
156166
model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained,
157167
**kwargs)
158168
model.to(device)
169+
if args.distributed and args.sync_bn:
170+
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
159171

160172
model_without_ddp = model
161173
if args.distributed:
@@ -166,8 +178,14 @@ def main(args):
166178
optimizer = torch.optim.SGD(
167179
params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
168180

169-
# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
170-
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
181+
args.lr_scheduler = args.lr_scheduler.lower()
182+
if args.lr_scheduler == 'multisteplr':
183+
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
184+
elif args.lr_scheduler == 'cosineannealinglr':
185+
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
186+
else:
187+
raise RuntimeError("Invalid lr scheduler '{}'. Only MultiStepLR and CosineAnnealingLR "
188+
"are supported.".format(args.lr_scheduler))
171189

172190
if args.resume:
173191
checkpoint = torch.load(args.resume, map_location='cpu')
Binary file not shown.

test/test_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def get_available_video_models():
4646
"keypointrcnn_resnet50_fpn": lambda x: x[1],
4747
"retinanet_resnet50_fpn": lambda x: x[1],
4848
"ssd300_vgg16": lambda x: x[1],
49+
"ssdlite320_mobilenet_v3_large": lambda x: x[1],
4950
}
5051

5152

torchvision/models/detection/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from .keypoint_rcnn import *
44
from .retinanet import *
55
from .ssd import *
6+
from .ssdlite import *

torchvision/models/detection/anchor_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ def _grid_default_boxes(self, grid_sizes: List[List[int]], image_size: List[int]
206206
else:
207207
y_f_k, x_f_k = f_k
208208

209-
shifts_x = (torch.arange(0, f_k[1], dtype=dtype) + 0.5) / x_f_k
210-
shifts_y = (torch.arange(0, f_k[0], dtype=dtype) + 0.5) / y_f_k
209+
shifts_x = ((torch.arange(0, f_k[1]) + 0.5) / x_f_k).to(dtype=dtype)
210+
shifts_y = ((torch.arange(0, f_k[0]) + 0.5) / y_f_k).to(dtype=dtype)
211211
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
212212
shift_x = shift_x.reshape(-1)
213213
shift_y = shift_y.reshape(-1)
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
import torch
2+
3+
from collections import OrderedDict
4+
from functools import partial
5+
from torch import nn, Tensor
6+
from typing import Any, Callable, Dict, List, Optional, Tuple
7+
8+
from . import _utils as det_utils
9+
from .ssd import SSD, SSDScoringHead
10+
from .anchor_utils import DefaultBoxGenerator
11+
from .backbone_utils import _validate_trainable_layers
12+
from .. import mobilenet
13+
from ..mobilenetv3 import ConvBNActivation
14+
from ..utils import load_state_dict_from_url
15+
16+
17+
__all__ = ['ssdlite320_mobilenet_v3_large']
18+
19+
model_urls = {
20+
'ssdlite320_mobilenet_v3_large_coco':
21+
'https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth'
22+
}
23+
24+
25+
def _prediction_block(in_channels: int, out_channels: int, kernel_size: int,
26+
norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
27+
return nn.Sequential(
28+
# 3x3 depthwise with stride 1 and padding 1
29+
ConvBNActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels,
30+
norm_layer=norm_layer, activation_layer=nn.ReLU6),
31+
32+
# 1x1 projetion to output channels
33+
nn.Conv2d(in_channels, out_channels, 1)
34+
)
35+
36+
37+
def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
38+
activation = nn.ReLU6
39+
intermediate_channels = out_channels // 2
40+
return nn.Sequential(
41+
# 1x1 projection to half output channels
42+
ConvBNActivation(in_channels, intermediate_channels, kernel_size=1,
43+
norm_layer=norm_layer, activation_layer=activation),
44+
45+
# 3x3 depthwise with stride 2 and padding 1
46+
ConvBNActivation(intermediate_channels, intermediate_channels, kernel_size=3, stride=2,
47+
groups=intermediate_channels, norm_layer=norm_layer, activation_layer=activation),
48+
49+
# 1x1 projetion to output channels
50+
ConvBNActivation(intermediate_channels, out_channels, kernel_size=1,
51+
norm_layer=norm_layer, activation_layer=activation),
52+
)
53+
54+
55+
def _normal_init(conv: nn.Module):
56+
for layer in conv.modules():
57+
if isinstance(layer, nn.Conv2d):
58+
torch.nn.init.normal_(layer.weight, mean=0.0, std=0.03)
59+
if layer.bias is not None:
60+
torch.nn.init.constant_(layer.bias, 0.0)
61+
62+
63+
class SSDLiteHead(nn.Module):
64+
def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int,
65+
norm_layer: Callable[..., nn.Module]):
66+
super().__init__()
67+
self.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer)
68+
self.regression_head = SSDLiteRegressionHead(in_channels, num_anchors, norm_layer)
69+
70+
def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
71+
return {
72+
'bbox_regression': self.regression_head(x),
73+
'cls_logits': self.classification_head(x),
74+
}
75+
76+
77+
class SSDLiteClassificationHead(SSDScoringHead):
78+
def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int,
79+
norm_layer: Callable[..., nn.Module]):
80+
cls_logits = nn.ModuleList()
81+
for channels, anchors in zip(in_channels, num_anchors):
82+
cls_logits.append(_prediction_block(channels, num_classes * anchors, 3, norm_layer))
83+
_normal_init(cls_logits)
84+
super().__init__(cls_logits, num_classes)
85+
86+
87+
class SSDLiteRegressionHead(SSDScoringHead):
88+
def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: Callable[..., nn.Module]):
89+
bbox_reg = nn.ModuleList()
90+
for channels, anchors in zip(in_channels, num_anchors):
91+
bbox_reg.append(_prediction_block(channels, 4 * anchors, 3, norm_layer))
92+
_normal_init(bbox_reg)
93+
super().__init__(bbox_reg, 4)
94+
95+
96+
class SSDLiteFeatureExtractorMobileNet(nn.Module):
97+
def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], rescaling: bool,
98+
**kwargs: Any):
99+
super().__init__()
100+
# non-public config parameters
101+
min_depth = kwargs.pop('_min_depth', 16)
102+
width_mult = kwargs.pop('_width_mult', 1.0)
103+
104+
assert not backbone[c4_pos].use_res_connect
105+
self.features = nn.Sequential(
106+
nn.Sequential(*backbone[:c4_pos], backbone[c4_pos].block[0]), # from start until C4 expansion layer
107+
nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1:]), # from C4 depthwise until end
108+
)
109+
110+
get_depth = lambda d: max(min_depth, int(d * width_mult)) # noqa: E731
111+
extra = nn.ModuleList([
112+
_extra_block(backbone[-1].out_channels, get_depth(512), norm_layer),
113+
_extra_block(get_depth(512), get_depth(256), norm_layer),
114+
_extra_block(get_depth(256), get_depth(256), norm_layer),
115+
_extra_block(get_depth(256), get_depth(128), norm_layer),
116+
])
117+
_normal_init(extra)
118+
119+
self.extra = extra
120+
self.rescaling = rescaling
121+
122+
def forward(self, x: Tensor) -> Dict[str, Tensor]:
123+
# Rescale from [0, 1] to [-1, -1]
124+
if self.rescaling:
125+
x = 2.0 * x - 1.0
126+
127+
# Get feature maps from backbone and extra. Can't be refactored due to JIT limitations.
128+
output = []
129+
for block in self.features:
130+
x = block(x)
131+
output.append(x)
132+
133+
for block in self.extra:
134+
x = block(x)
135+
output.append(x)
136+
137+
return OrderedDict([(str(i), v) for i, v in enumerate(output)])
138+
139+
140+
def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, trainable_layers: int,
141+
norm_layer: Callable[..., nn.Module], rescaling: bool, **kwargs: Any):
142+
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, progress=progress,
143+
norm_layer=norm_layer, **kwargs).features
144+
if not pretrained:
145+
# Change the default initialization scheme if not pretrained
146+
_normal_init(backbone)
147+
148+
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
149+
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
150+
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
151+
num_stages = len(stage_indices)
152+
153+
# find the index of the layer from which we wont freeze
154+
assert 0 <= trainable_layers <= num_stages
155+
freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
156+
157+
for b in backbone[:freeze_before]:
158+
for parameter in b.parameters():
159+
parameter.requires_grad_(False)
160+
161+
return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, rescaling, **kwargs)
162+
163+
164+
def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = True, num_classes: int = 91,
165+
pretrained_backbone: bool = False, trainable_backbone_layers: Optional[int] = None,
166+
norm_layer: Optional[Callable[..., nn.Module]] = None,
167+
**kwargs: Any):
168+
"""
169+
Constructs an SSDlite model with a MobileNetV3 Large backbone. See `SSD` for more details.
170+
171+
Example:
172+
173+
>>> model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True)
174+
>>> model.eval()
175+
>>> x = [torch.rand(3, 320, 320), torch.rand(3, 500, 400)]
176+
>>> predictions = model(x)
177+
178+
Args:
179+
norm_layer:
180+
**kwargs:
181+
pretrained (bool): If True, returns a model pre-trained on COCO train2017
182+
progress (bool): If True, displays a progress bar of the download to stderr
183+
num_classes (int): number of output classes of the model (including the background)
184+
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
185+
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
186+
Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
187+
norm_layer (callable, optional): Module specifying the normalization layer to use.
188+
"""
189+
trainable_backbone_layers = _validate_trainable_layers(
190+
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6)
191+
192+
if pretrained:
193+
pretrained_backbone = False
194+
195+
# Enable [-1, 1] rescaling and reduced tail if no pretrained backbone is selected
196+
rescaling = reduce_tail = not pretrained_backbone
197+
198+
if norm_layer is None:
199+
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
200+
201+
backbone = _mobilenet_extractor("mobilenet_v3_large", progress, pretrained_backbone, trainable_backbone_layers,
202+
norm_layer, rescaling, _reduced_tail=reduce_tail, _width_mult=1.0)
203+
204+
size = (320, 320)
205+
anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
206+
out_channels = det_utils.retrieve_out_channels(backbone, size)
207+
num_anchors = anchor_generator.num_anchors_per_location()
208+
assert len(out_channels) == len(anchor_generator.aspect_ratios)
209+
210+
defaults = {
211+
"score_thresh": 0.001,
212+
"nms_thresh": 0.55,
213+
"detections_per_img": 300,
214+
"topk_candidates": 300,
215+
"image_mean": [0., 0., 0.],
216+
"image_std": [1., 1., 1.],
217+
}
218+
kwargs = {**defaults, **kwargs}
219+
model = SSD(backbone, anchor_generator, size, num_classes,
220+
head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer), **kwargs)
221+
222+
if pretrained:
223+
weights_name = 'ssdlite320_mobilenet_v3_large_coco'
224+
if model_urls.get(weights_name, None) is None:
225+
raise ValueError("No checkpoint is available for model {}".format(weights_name))
226+
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
227+
model.load_state_dict(state_dict)
228+
return model

0 commit comments

Comments
 (0)