Skip to content

Add typing annotations to detection/backbone_utils #4603

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 19, 2021
4 changes: 0 additions & 4 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@ ignore_errors=True

ignore_errors = True

[mypy-torchvision.models.detection.backbone_utils]

ignore_errors = True

[mypy-torchvision.models.detection.image_list]

ignore_errors = True
Expand Down
16 changes: 9 additions & 7 deletions torchvision/models/detection/backbone_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from typing import Callable

from torch import nn
from torch import nn, Tensor
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool

Expand Down Expand Up @@ -42,7 +43,7 @@ def __init__(self, backbone, return_layers, in_channels_list, out_channels, extr
)
self.out_channels = out_channels

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
x = self.body(x)
x = self.fpn(x)
return x
Expand All @@ -51,11 +52,11 @@ def forward(self, x):
def resnet_fpn_backbone(
backbone_name,
pretrained,
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=3,
returned_layers=None,
extra_blocks=None,
):
) -> BackboneWithFPN:
"""
Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.

Expand Down Expand Up @@ -137,12 +138,13 @@ def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value,
def mobilenet_backbone(
backbone_name,
pretrained,
fpn,
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
fpn: bool,
norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=2,
returned_layers=None,
extra_blocks=None,
):
) -> nn.Module:

backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features

# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
Expand Down