
Commit 5b5987d

YosuaMichael authored and facebook-github-bot committed
[fbsync] Post-paper Detection Optimizations (#5444)
Summary:

* Use frozen BN only if pre-trained.
* Add LSJ and the ability to train from scratch.
* Fix formatter.
* Add `--opt` and `--norm-weight-decay` support in Detection.
* Fix error message.
* Make ScaleJitter proportional.
* Add more norm layers in split_normalization_params.
* Add FixedSizeCrop.
* Temporary fix for fill values on PIL.
* Fix the bug on fill.
* Add RandomShortestSize.
* Skip resize when an augmentation method is used.
* Multiscale in [480, 800].
* Add missing star.
* Add new RetinaNet variant.
* Add tests.
* Update expected file for old RetinaNet.
* Fix tests.
* Add FrozenBN to RetinaNet v2.
* Fix network initialization issues.
* Add BN support in MaskRCNNHeads and FPN.
* Add support for FasterRCNNHeads.
* Introduce norm_layers in backbone utils.
* Bigger RPN head + 2x R-CNN v2 models.
* Add gIoU support to RetinaNet.
* Fix assert.
* Add back Nesterov momentum.
* Rename and extend `FastRCNNConvFCHead` to support arbitrary FCs.
* Fix linter. (Note: this ignores all push blocking failures!)

Reviewed By: jdsgomes, NicolasHug

Differential Revision: D36095683

fbshipit-source-id: 9105524308694ac8830ed12ba40286bb75c4aa8d
1 parent: d14c03b

11 files changed: +563 -54 lines
3 binary files changed (contents not shown).

test/test_models.py

Lines changed: 35 additions & 0 deletions
@@ -195,11 +195,14 @@ def _check_input_backprop(model, inputs):
     "googlenet": lambda x: x.logits,
     "inception_v3": lambda x: x.logits,
     "fasterrcnn_resnet50_fpn": lambda x: x[1],
+    "fasterrcnn_resnet50_fpn_v2": lambda x: x[1],
     "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
     "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1],
     "maskrcnn_resnet50_fpn": lambda x: x[1],
+    "maskrcnn_resnet50_fpn_v2": lambda x: x[1],
     "keypointrcnn_resnet50_fpn": lambda x: x[1],
     "retinanet_resnet50_fpn": lambda x: x[1],
+    "retinanet_resnet50_fpn_v2": lambda x: x[1],
     "ssd300_vgg16": lambda x: x[1],
     "ssdlite320_mobilenet_v3_large": lambda x: x[1],
     "fcos_resnet50_fpn": lambda x: x[1],
@@ -227,6 +230,7 @@ def _check_input_backprop(model, inputs):
     "fcn_resnet101",
     "lraspp_mobilenet_v3_large",
     "maskrcnn_resnet50_fpn",
+    "maskrcnn_resnet50_fpn_v2",
 )
 
 # The tests for the following quantized models are flaky possibly due to inconsistent
@@ -246,6 +250,13 @@ def _check_input_backprop(model, inputs):
         "max_size": 224,
         "input_shape": (3, 224, 224),
     },
+    "retinanet_resnet50_fpn_v2": {
+        "num_classes": 20,
+        "score_thresh": 0.01,
+        "min_size": 224,
+        "max_size": 224,
+        "input_shape": (3, 224, 224),
+    },
     "keypointrcnn_resnet50_fpn": {
         "num_classes": 2,
         "min_size": 224,
@@ -259,6 +270,12 @@ def _check_input_backprop(model, inputs):
         "max_size": 224,
         "input_shape": (3, 224, 224),
     },
+    "fasterrcnn_resnet50_fpn_v2": {
+        "num_classes": 20,
+        "min_size": 224,
+        "max_size": 224,
+        "input_shape": (3, 224, 224),
+    },
     "fcos_resnet50_fpn": {
         "num_classes": 2,
         "score_thresh": 0.05,
@@ -272,6 +289,12 @@ def _check_input_backprop(model, inputs):
         "max_size": 224,
         "input_shape": (3, 224, 224),
     },
+    "maskrcnn_resnet50_fpn_v2": {
+        "num_classes": 10,
+        "min_size": 224,
+        "max_size": 224,
+        "input_shape": (3, 224, 224),
+    },
     "fasterrcnn_mobilenet_v3_large_fpn": {
         "box_score_thresh": 0.02076,
     },
@@ -311,6 +334,10 @@ def _check_input_backprop(model, inputs):
         "max_trainable": 5,
         "n_trn_params_per_layer": [36, 46, 65, 78, 88, 89],
     },
+    "retinanet_resnet50_fpn_v2": {
+        "max_trainable": 5,
+        "n_trn_params_per_layer": [44, 74, 131, 170, 200, 203],
+    },
     "keypointrcnn_resnet50_fpn": {
         "max_trainable": 5,
         "n_trn_params_per_layer": [48, 58, 77, 90, 100, 101],
@@ -319,10 +346,18 @@ def _check_input_backprop(model, inputs):
         "max_trainable": 5,
         "n_trn_params_per_layer": [30, 40, 59, 72, 82, 83],
     },
+    "fasterrcnn_resnet50_fpn_v2": {
+        "max_trainable": 5,
+        "n_trn_params_per_layer": [50, 80, 137, 176, 206, 209],
+    },
     "maskrcnn_resnet50_fpn": {
         "max_trainable": 5,
         "n_trn_params_per_layer": [42, 52, 71, 84, 94, 95],
     },
+    "maskrcnn_resnet50_fpn_v2": {
+        "max_trainable": 5,
+        "n_trn_params_per_layer": [66, 96, 153, 192, 222, 225],
+    },
     "fasterrcnn_mobilenet_v3_large_fpn": {
         "max_trainable": 6,
         "n_trn_params_per_layer": [22, 23, 44, 70, 91, 97, 100],

torchvision/models/detection/_utils.py

Lines changed: 26 additions & 2 deletions
@@ -1,10 +1,11 @@
 import math
 from collections import OrderedDict
-from typing import List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 from torch import Tensor, nn
-from torchvision.ops.misc import FrozenBatchNorm2d
+from torch.nn import functional as F
+from torchvision.ops import FrozenBatchNorm2d, generalized_box_iou_loss
 
 
 class BalancedPositiveNegativeSampler:
@@ -507,3 +508,26 @@ def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
     axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
     min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
     return _fake_cast_onnx(min_kval)
+
+
+def _box_loss(
+    type: str,
+    box_coder: BoxCoder,
+    anchors_per_image: Tensor,
+    matched_gt_boxes_per_image: Tensor,
+    bbox_regression_per_image: Tensor,
+    cnf: Optional[Dict[str, float]] = None,
+) -> Tensor:
+    torch._assert(type in ["l1", "smooth_l1", "giou"], f"Unsupported loss: {type}")
+
+    if type == "l1":
+        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+        return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
+    elif type == "smooth_l1":
+        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+        beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
+        return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
+    else:  # giou
+        bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
+        eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
+        return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
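
A quick sketch of how the new `_box_loss` helper dispatches. The tensors are made-up toy values, and `_box_loss`/`BoxCoder` are private helpers, imported here only for illustration: "l1"/"smooth_l1" encode the targets and compare in delta space, while "giou" decodes the predictions and compares in box space.

import torch
from torchvision.models.detection._utils import BoxCoder, _box_loss

box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
anchors = torch.tensor([[0.0, 0.0, 10.0, 10.0]] * 4)     # 4 matched anchors
matched_gt = torch.tensor([[1.0, 1.0, 11.0, 11.0]] * 4)  # their ground-truth boxes
bbox_regression = torch.zeros(4, 4)                      # raw network deltas

l1 = _box_loss("l1", box_coder, anchors, matched_gt, bbox_regression)
giou = _box_loss("giou", box_coder, anchors, matched_gt, bbox_regression, cnf={"eps": 1e-7})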

torchvision/models/detection/backbone_utils.py

Lines changed: 11 additions & 2 deletions
@@ -25,6 +25,7 @@ class BackboneWithFPN(nn.Module):
         in_channels_list (List[int]): number of channels for each feature map
             that is returned, in the order they are present in the OrderedDict
         out_channels (int): number of channels in the FPN.
+        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
     Attributes:
         out_channels (int): the number of channels in the FPN
     """
@@ -36,6 +37,7 @@ def __init__(
         in_channels_list: List[int],
         out_channels: int,
         extra_blocks: Optional[ExtraFPNBlock] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
     ) -> None:
         super().__init__()
 
@@ -47,6 +49,7 @@ def __init__(
             in_channels_list=in_channels_list,
             out_channels=out_channels,
             extra_blocks=extra_blocks,
+            norm_layer=norm_layer,
         )
         self.out_channels = out_channels
 
@@ -115,6 +118,7 @@ def _resnet_fpn_extractor(
     trainable_layers: int,
     returned_layers: Optional[List[int]] = None,
     extra_blocks: Optional[ExtraFPNBlock] = None,
+    norm_layer: Optional[Callable[..., nn.Module]] = None,
 ) -> BackboneWithFPN:
 
     # select layers that wont be frozen
@@ -139,7 +143,9 @@ def _resnet_fpn_extractor(
     in_channels_stage2 = backbone.inplanes // 8
     in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
     out_channels = 256
-    return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
+    return BackboneWithFPN(
+        backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
+    )
 
 
 def _validate_trainable_layers(
@@ -194,6 +200,7 @@ def _mobilenet_extractor(
     trainable_layers: int,
     returned_layers: Optional[List[int]] = None,
     extra_blocks: Optional[ExtraFPNBlock] = None,
+    norm_layer: Optional[Callable[..., nn.Module]] = None,
 ) -> nn.Module:
     backbone = backbone.features
     # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
@@ -222,7 +229,9 @@ def _mobilenet_extractor(
         return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}
 
         in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
-        return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
+        return BackboneWithFPN(
+            backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
+        )
     else:
         m = nn.Sequential(
             backbone,
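
The effect of the new `norm_layer` plumbing, as a minimal sketch. It uses the private `_resnet_fpn_extractor`, so this mirrors how the v2 builders call it internally rather than a public API:

import torch
from torch import nn
from torchvision.models import resnet50
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor

# ResNet-50 + FPN whose FPN convs are followed by BatchNorm, as the v2 builders request.
body = resnet50(weights=None)
backbone = _resnet_fpn_extractor(body, trainable_layers=5, norm_layer=nn.BatchNorm2d)

feats = backbone(torch.rand(1, 3, 224, 224))  # OrderedDict of multi-scale feature maps
print({name: f.shape for name, f in feats.items()})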

torchvision/models/detection/faster_rcnn.py

Lines changed: 111 additions & 4 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -24,14 +24,22 @@
 __all__ = [
     "FasterRCNN",
     "FasterRCNN_ResNet50_FPN_Weights",
+    "FasterRCNN_ResNet50_FPN_V2_Weights",
     "FasterRCNN_MobileNet_V3_Large_FPN_Weights",
     "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
     "fasterrcnn_resnet50_fpn",
+    "fasterrcnn_resnet50_fpn_v2",
     "fasterrcnn_mobilenet_v3_large_fpn",
     "fasterrcnn_mobilenet_v3_large_320_fpn",
 ]
 
 
+def _default_anchorgen():
+    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
+    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+    return AnchorGenerator(anchor_sizes, aspect_ratios)
+
+
 class FasterRCNN(GeneralizedRCNN):
     """
     Implements Faster R-CNN.
@@ -216,9 +224,7 @@ def __init__(
         out_channels = backbone.out_channels
 
         if rpn_anchor_generator is None:
-            anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
-            aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
-            rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
+            rpn_anchor_generator = _default_anchorgen()
         if rpn_head is None:
             rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
 
@@ -298,6 +304,43 @@ def forward(self, x):
         return x
 
 
+class FastRCNNConvFCHead(nn.Sequential):
+    def __init__(
+        self,
+        input_size: Tuple[int, int, int],
+        conv_layers: List[int],
+        fc_layers: List[int],
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ):
+        """
+        Args:
+            input_size (Tuple[int, int, int]): the input size in CHW format.
+            conv_layers (list): feature dimensions of each Convolution layer
+            fc_layers (list): feature dimensions of each FCN layer
+            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
+        """
+        in_channels, in_height, in_width = input_size
+
+        blocks = []
+        previous_channels = in_channels
+        for current_channels in conv_layers:
+            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
+            previous_channels = current_channels
+        blocks.append(nn.Flatten())
+        previous_channels = previous_channels * in_height * in_width
+        for current_channels in fc_layers:
+            blocks.append(nn.Linear(previous_channels, current_channels))
+            blocks.append(nn.ReLU(inplace=True))
+            previous_channels = current_channels
+
+        super().__init__(*blocks)
+        for layer in self.modules():
+            if isinstance(layer, nn.Conv2d):
+                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
+                if layer.bias is not None:
+                    nn.init.zeros_(layer.bias)
+
+
 class FastRCNNPredictor(nn.Module):
     """
     Standard classification + bounding box regression layers
@@ -349,6 +392,10 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
     DEFAULT = COCO_V1
 
 
+class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
+    pass
+
+
 class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
@@ -481,6 +528,66 @@ def fasterrcnn_resnet50_fpn(
     return model
 
 
+def fasterrcnn_resnet50_fpn_v2(
+    *,
+    weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[ResNet50_Weights] = None,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    """
+    Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone.
+
+    Reference: `"Benchmarking Detection Transfer Learning with Vision Transformers"
+    <https://arxiv.org/abs/2111.11429>`_.
+
+    See :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more details.
+
+    Args:
+        weights (FasterRCNN_ResNet50_FPN_V2_Weights, optional): The pretrained weights for the model
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
+            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
+            passed (the default) this value is set to 3.
+    """
+    weights = FasterRCNN_ResNet50_FPN_V2_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+
+    backbone = resnet50(weights=weights_backbone, progress=progress)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
+    rpn_anchor_generator = _default_anchorgen()
+    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
+    box_head = FastRCNNConvFCHead(
+        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
+    )
+    model = FasterRCNN(
+        backbone,
+        num_classes=num_classes,
+        rpn_anchor_generator=rpn_anchor_generator,
+        rpn_head=rpn_head,
+        box_head=box_head,
+        **kwargs,
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+
+    return model
+
+
 def _fasterrcnn_mobilenet_v3_large_fpn(
     *,
     weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
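
To see the new pieces end to end, a minimal sketch (shapes and the random inputs are illustrative; 91 classes mirrors the COCO default in the builder):

import torch
from torch import nn
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNConvFCHead

# The box head on its own: four 3x3 convs with BatchNorm, then a 1024-wide FC,
# applied to 256x7x7 RoI features -- the same configuration the v2 builder uses.
box_head = FastRCNNConvFCHead((256, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d)
print(box_head(torch.rand(2, 256, 7, 7)).shape)  # torch.Size([2, 1024])

# The full v2 model built from scratch (the V2 weights enum is still empty in this commit).
model = fasterrcnn_resnet50_fpn_v2(num_classes=91)
model.eval()
predictions = model([torch.rand(3, 224, 224)])  # one dict with boxes/labels/scores per image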
