Skip to content

Clean up SSD and SSDlite implementations #3818

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -426,8 +426,8 @@ Faster R-CNN ResNet-50 FPN 37.0 - -
Faster R-CNN MobileNetV3-Large FPN 32.8 - -
Faster R-CNN MobileNetV3-Large 320 FPN 22.8 - -
RetinaNet ResNet-50 FPN 36.4 - -
SSD VGG16 25.1 - -
SSDlite MobileNetV3-Large 21.3 - -
SSD300 VGG16 25.1 - -
SSDlite320 MobileNetV3-Large 21.3 - -
Mask R-CNN ResNet-50 FPN 37.9 34.6 -
====================================== ======= ======== ===========

Expand Down Expand Up @@ -486,8 +486,8 @@ Faster R-CNN ResNet-50 FPN 0.2288 0.0590
Faster R-CNN MobileNetV3-Large FPN 0.1020 0.0415 1.0
Faster R-CNN MobileNetV3-Large 320 FPN 0.0978 0.0376 0.6
RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1
SSD VGG16 0.2093 0.0744 1.5
SSDlite MobileNetV3-Large 0.1773 0.0906 1.5
SSD300 VGG16 0.2093 0.0744 1.5
SSDlite320 MobileNetV3-Large 0.1773 0.0906 1.5
Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4
Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8
====================================== =================== ================== ===========
Expand All @@ -502,19 +502,19 @@ Faster R-CNN


RetinaNet
------------
---------

.. autofunction:: torchvision.models.detection.retinanet_resnet50_fpn


SSD
------------
---

.. autofunction:: torchvision.models.detection.ssd300_vgg16


SSDlite
------------
-------

.. autofunction:: torchvision.models.detection.ssdlite320_mobilenet_v3_large

Expand Down
4 changes: 2 additions & 2 deletions references/detection/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01
```

### SSD VGG16
### SSD300 VGG16
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--dataset coco --model ssd300_vgg16 --epochs 120\
--lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\
--weight-decay 0.0005 --data-augmentation ssd
```

### SSDlite MobileNetV3-Large
### SSDlite320 MobileNetV3-Large
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\
Expand Down
31 changes: 17 additions & 14 deletions torchvision/models/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def postprocess_detections(self, head_outputs: Dict[str, Tensor], image_anchors:


class SSDFeatureExtractorVGG(nn.Module):
def __init__(self, backbone: nn.Module, highres: bool, rescaling: bool):
def __init__(self, backbone: nn.Module, highres: bool):
super().__init__()

_, _, maxpool3_pos, maxpool4_pos, _ = (i for i, layer in enumerate(backbone) if isinstance(layer, nn.MaxPool2d))
Expand Down Expand Up @@ -476,13 +476,8 @@ def __init__(self, backbone: nn.Module, highres: bool, rescaling: bool):
fc,
))
self.extra = extra
self.rescaling = rescaling

def forward(self, x: Tensor) -> Dict[str, Tensor]:
# Undo the 0-1 scaling of toTensor. Necessary for some backbones.
if self.rescaling:
x *= 255

# L2 regularization + Rescaling of 1st block's feature map
x = self.features(x)
rescaled = self.scale_weight.view(1, -1, 1, 1) * F.normalize(x)
Expand All @@ -496,8 +491,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
return OrderedDict([(str(i), v) for i, v in enumerate(output)])


def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int,
rescaling: bool):
def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int):
if backbone_name in backbone_urls:
# Use custom backbones more appropriate for SSD
arch = backbone_name.split('_')[0]
Expand All @@ -521,19 +515,19 @@ def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained
for parameter in b.parameters():
parameter.requires_grad_(False)

return SSDFeatureExtractorVGG(backbone, highres, rescaling)
return SSDFeatureExtractorVGG(backbone, highres)


def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: int = 91,
pretrained_backbone: bool = True, trainable_backbone_layers: Optional[int] = None, **kwargs: Any):
"""
Constructs an SSD model with a VGG16 backbone. See `SSD` for more details.
Constructs an SSD model with input size 300x300 and a VGG16 backbone. See `SSD` for more details.

Example:

>>> model = torchvision.models.detection.ssd300_vgg16(pretrained=True)
>>> model.eval()
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> x = [torch.rand(3, 300, 300), torch.rand(3, 500, 400)]
>>> predictions = model(x)

Args:
Expand All @@ -544,19 +538,28 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
if "size" in kwargs:
warnings.warn("The size of the model is already fixed; ignoring the argument.")

trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5)

if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False

backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers, True)
backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers)
anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]],
scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05],
steps=[8, 16, 32, 64, 100, 300])
model = SSD(backbone, anchor_generator, (300, 300), num_classes,
image_mean=[0.48235, 0.45882, 0.40784], image_std=[1., 1., 1.], **kwargs)

defaults = {
# Rescale the input in a way compatible to the backbone
"image_mean": [0.48235, 0.45882, 0.40784],
"image_std": [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0], # undo the 0-1 scaling of toTensor
}
kwargs = {**defaults, **kwargs}
model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)
if pretrained:
weights_name = 'ssd300_vgg16_coco'
if model_urls.get(weights_name, None) is None:
Expand Down
30 changes: 15 additions & 15 deletions torchvision/models/detection/ssdlite.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import warnings

from collections import OrderedDict
from functools import partial
Expand Down Expand Up @@ -94,8 +95,7 @@ def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: C


class SSDLiteFeatureExtractorMobileNet(nn.Module):
def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], rescaling: bool,
**kwargs: Any):
def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], **kwargs: Any):
super().__init__()
# non-public config parameters
min_depth = kwargs.pop('_min_depth', 16)
Expand All @@ -117,13 +117,8 @@ def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., n
_normal_init(extra)

self.extra = extra
self.rescaling = rescaling

def forward(self, x: Tensor) -> Dict[str, Tensor]:
# Rescale from [0, 1] to [-1, -1]
if self.rescaling:
x = 2.0 * x - 1.0

# Get feature maps from backbone and extra. Can't be refactored due to JIT limitations.
output = []
for block in self.features:
Expand All @@ -138,7 +133,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:


def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, trainable_layers: int,
norm_layer: Callable[..., nn.Module], rescaling: bool, **kwargs: Any):
norm_layer: Callable[..., nn.Module], **kwargs: Any):
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, progress=progress,
norm_layer=norm_layer, **kwargs).features
if not pretrained:
Expand All @@ -158,15 +153,15 @@ def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, t
for parameter in b.parameters():
parameter.requires_grad_(False)

return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, rescaling, **kwargs)
return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, **kwargs)


def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = True, num_classes: int = 91,
pretrained_backbone: bool = False, trainable_backbone_layers: Optional[int] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any):
"""
Constructs an SSDlite model with a MobileNetV3 Large backbone. See `SSD` for more details.
Constructs an SSDlite model with input size 320x320 and a MobileNetV3 Large backbone. See `SSD` for more details.

Example:

Expand All @@ -186,20 +181,23 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru
Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
norm_layer (callable, optional): Module specifying the normalization layer to use.
"""
if "size" in kwargs:
warnings.warn("The size of the model is already fixed; ignoring the argument.")

trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6)

if pretrained:
pretrained_backbone = False

# Enable [-1, 1] rescaling and reduced tail if no pretrained backbone is selected
rescaling = reduce_tail = not pretrained_backbone
# Enable reduced tail if no pretrained backbone is selected
reduce_tail = not pretrained_backbone

if norm_layer is None:
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)

backbone = _mobilenet_extractor("mobilenet_v3_large", progress, pretrained_backbone, trainable_backbone_layers,
norm_layer, rescaling, _reduced_tail=reduce_tail, _width_mult=1.0)
norm_layer, _reduced_tail=reduce_tail, _width_mult=1.0)

size = (320, 320)
anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
Expand All @@ -212,8 +210,10 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru
"nms_thresh": 0.55,
"detections_per_img": 300,
"topk_candidates": 300,
"image_mean": [0., 0., 0.],
"image_std": [1., 1., 1.],
# Rescale the input in a way compatible to the backbone:
# The following mean/std rescale the data from [0, 1] to [-1, -1]
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5],
}
kwargs = {**defaults, **kwargs}
model = SSD(backbone, anchor_generator, size, num_classes,
Expand Down