Skip to content

Commit d8e0dfd

Browse files
fmassadatumboxkhushi-411
authored andcommitted
[fbsync] Add typing annotations to detection/backbone_utils (#4603)
Summary: * Start adding types * add typing * Type prototype models * fix optional type bug * transient import * Fix weights type * fix import * Apply suggestions from code review Address nits Reviewed By: datumbox Differential Revision: D31898218 fbshipit-source-id: e52ab136755ed0d0cc975a2ec914bd285f3d0674 Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Khushi Agrawal <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 08fce7c commit d8e0dfd

File tree

3 files changed

+52
-36
lines changed

3 files changed

+52
-36
lines changed

mypy.ini

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,6 @@ ignore_errors=True
2121

2222
ignore_errors = True
2323

24-
[mypy-torchvision.models.detection.backbone_utils]
25-
26-
ignore_errors = True
27-
2824
[mypy-torchvision.models.detection.transform]
2925

3026
ignore_errors = True

torchvision/models/detection/backbone_utils.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import warnings
2-
from typing import List, Optional
2+
from typing import Callable, Dict, Optional, List
33

4-
from torch import nn
4+
from torch import nn, Tensor
55
from torchvision.ops import misc as misc_nn_ops
66
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool, ExtraFPNBlock
77

@@ -29,7 +29,14 @@ class BackboneWithFPN(nn.Module):
2929
out_channels (int): the number of channels in the FPN
3030
"""
3131

32-
def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
32+
def __init__(
33+
self,
34+
backbone: nn.Module,
35+
return_layers: Dict[str, str],
36+
in_channels_list: List[int],
37+
out_channels: int,
38+
extra_blocks: Optional[ExtraFPNBlock] = None,
39+
) -> None:
3340
super(BackboneWithFPN, self).__init__()
3441

3542
if extra_blocks is None:
@@ -43,20 +50,20 @@ def __init__(self, backbone, return_layers, in_channels_list, out_channels, extr
4350
)
4451
self.out_channels = out_channels
4552

46-
def forward(self, x):
53+
def forward(self, x: Tensor) -> Dict[str, Tensor]:
4754
x = self.body(x)
4855
x = self.fpn(x)
4956
return x
5057

5158

5259
def resnet_fpn_backbone(
53-
backbone_name,
54-
pretrained,
55-
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
56-
trainable_layers=3,
57-
returned_layers=None,
58-
extra_blocks=None,
59-
):
60+
backbone_name: str,
61+
pretrained: bool,
62+
norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
63+
trainable_layers: int = 3,
64+
returned_layers: Optional[List[int]] = None,
65+
extra_blocks: Optional[ExtraFPNBlock] = None,
66+
) -> BackboneWithFPN:
6067
"""
6168
Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
6269
@@ -80,7 +87,7 @@ def resnet_fpn_backbone(
8087
backbone_name (string): resnet architecture. Possible values are 'ResNet', 'resnet18', 'resnet34', 'resnet50',
8188
'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
8289
pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet
83-
norm_layer (torchvision.ops): it is recommended to use the default value. For details visit:
90+
norm_layer (callable): it is recommended to use the default value. For details visit:
8491
(https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
8592
trainable_layers (int): number of trainable (not frozen) resnet layers starting from final block.
8693
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
@@ -101,7 +108,8 @@ def _resnet_backbone_config(
101108
trainable_layers: int,
102109
returned_layers: Optional[List[int]],
103110
extra_blocks: Optional[ExtraFPNBlock],
104-
):
111+
) -> BackboneWithFPN:
112+
105113
# select layers that wont be frozen
106114
assert 0 <= trainable_layers <= 5
107115
layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
@@ -125,8 +133,13 @@ def _resnet_backbone_config(
125133
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
126134

127135

128-
def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, default_value):
129-
# dont freeze any layers if pretrained model or backbone is not used
136+
def _validate_trainable_layers(
137+
pretrained: bool,
138+
trainable_backbone_layers: Optional[int],
139+
max_value: int,
140+
default_value: int,
141+
) -> int:
142+
# don't freeze any layers if pretrained model or backbone is not used
130143
if not pretrained:
131144
if trainable_backbone_layers is not None:
132145
warnings.warn(
@@ -144,14 +157,15 @@ def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value,
144157

145158

146159
def mobilenet_backbone(
147-
backbone_name,
148-
pretrained,
149-
fpn,
150-
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
151-
trainable_layers=2,
152-
returned_layers=None,
153-
extra_blocks=None,
154-
):
160+
backbone_name: str,
161+
pretrained: bool,
162+
fpn: bool,
163+
norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
164+
trainable_layers: int = 2,
165+
returned_layers: Optional[List[int]] = None,
166+
extra_blocks: Optional[ExtraFPNBlock] = None,
167+
) -> nn.Module:
168+
155169
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features
156170

157171
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
@@ -185,5 +199,5 @@ def mobilenet_backbone(
185199
# depthwise linear combination of channels to reduce their size
186200
nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
187201
)
188-
m.out_channels = out_channels
202+
m.out_channels = out_channels # type: ignore[assignment]
189203
return m
Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
1-
from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config
1+
from typing import Callable, Optional, List
2+
3+
from torch import nn
4+
5+
from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config, BackboneWithFPN, ExtraFPNBlock
26
from .. import resnet
7+
from .._api import Weights
38

49

510
def resnet_fpn_backbone(
6-
backbone_name,
7-
weights,
8-
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
9-
trainable_layers=3,
10-
returned_layers=None,
11-
extra_blocks=None,
12-
):
11+
backbone_name: str,
12+
weights: Optional[Weights],
13+
norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
14+
trainable_layers: int = 3,
15+
returned_layers: Optional[List[int]] = None,
16+
extra_blocks: Optional[ExtraFPNBlock] = None,
17+
) -> BackboneWithFPN:
18+
1319
backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
1420
return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)

0 commit comments

Comments
 (0)