
Commit 971c3e4

Type annotations for torchvision.ops (#2331)
* Add type annotations for torchvision.ops
* Fix type annotations for torchvision.ops
* Fix typo in import
* Fix undefined name in FeaturePyramidNetwork
1 parent 67f5fcf commit 971c3e4

11 files changed, +194 lines added, -100 lines removed
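The change pattern is the same in every file below: TorchScript-style signature comments (`# type: (...) -> ...`) are replaced with inline Python 3 annotations, which `torch.jit.script` accepts equally well. A minimal sketch with a toy function (not taken from this commit) showing the two equivalent spellings:

    import torch
    from torch import Tensor

    def scale_old(boxes, factor=2.0):
        # type: (Tensor, float) -> Tensor
        return boxes * factor

    def scale_new(boxes: Tensor, factor: float = 2.0) -> Tensor:
        return boxes * factor

    # Both script to the same signature: (Tensor, float) -> Tensor
    print(torch.jit.script(scale_old).graph)
    print(torch.jit.script(scale_new).graph)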

torchvision/ops/_utils.py

Lines changed: 4 additions & 6 deletions
@@ -1,10 +1,9 @@
 import torch
 from torch import Tensor
-from torch.jit.annotations import List
+from torch.jit.annotations import List, Tuple


-def _cat(tensors, dim=0):
-    # type: (List[Tensor], int) -> Tensor
+def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
     """
     Efficient version of torch.cat that avoids a copy if there is only a single element in a list
     """
@@ -15,8 +14,7 @@ def _cat(tensors, dim=0):
     return torch.cat(tensors, dim)


-def convert_boxes_to_roi_format(boxes):
-    # type: (List[Tensor]) -> Tensor
+def convert_boxes_to_roi_format(boxes: List[Tensor]) -> Tensor:
     concat_boxes = _cat([b for b in boxes], dim=0)
     temp = []
     for i, b in enumerate(boxes):
@@ -26,7 +24,7 @@ def convert_boxes_to_roi_format(boxes):
     return rois


-def check_roi_boxes_shape(boxes):
+def check_roi_boxes_shape(boxes: Tensor):
     if isinstance(boxes, (list, tuple)):
         for _tensor in boxes:
             assert _tensor.size(1) == 4, \
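For reference, `convert_boxes_to_roi_format` turns a per-image list of `Nx4` box tensors into the single `Kx5` layout (image index followed by coordinates) that the RoI ops expect. A small usage sketch, assuming the usual `(x1, y1, x2, y2)` box layout; the helper lives in a private module, so the import path is an implementation detail:

    import torch
    from torchvision.ops._utils import convert_boxes_to_roi_format

    boxes = [
        torch.tensor([[0., 0., 10., 10.], [5., 5., 20., 20.]]),  # image 0: two boxes
        torch.tensor([[1., 1., 4., 4.]]),                        # image 1: one box
    ]
    rois = convert_boxes_to_roi_format(boxes)
    print(rois.shape)   # torch.Size([3, 5]); column 0 holds the image index
    print(rois[:, 0])   # tensor([0., 0., 1.])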

torchvision/ops/boxes.py

Lines changed: 11 additions & 10 deletions
@@ -4,8 +4,7 @@
 import torchvision


-def nms(boxes, scores, iou_threshold):
-    # type: (Tensor, Tensor, float) -> Tensor
+def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
     """
     Performs non-maximum suppression (NMS) on the boxes according
     to their intersection-over-union (IoU).
@@ -41,8 +40,12 @@ def nms(boxes, scores, iou_threshold):


 @torch.jit._script_if_tracing
-def batched_nms(boxes, scores, idxs, iou_threshold):
-    # type: (Tensor, Tensor, Tensor, float) -> Tensor
+def batched_nms(
+    boxes: Tensor,
+    scores: Tensor,
+    idxs: Tensor,
+    iou_threshold: float,
+) -> Tensor:
     """
     Performs non-maximum suppression in a batched fashion.

@@ -83,8 +86,7 @@ def batched_nms(boxes, scores, idxs, iou_threshold):
     return keep


-def remove_small_boxes(boxes, min_size):
-    # type: (Tensor, float) -> Tensor
+def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
     """
     Remove boxes which contains at least one side smaller than min_size.

@@ -102,8 +104,7 @@ def remove_small_boxes(boxes, min_size):
     return keep


-def clip_boxes_to_image(boxes, size):
-    # type: (Tensor, Tuple[int, int]) -> Tensor
+def clip_boxes_to_image(boxes: Tensor, size: Tuple[int, int]) -> Tensor:
     """
     Clip boxes so that they lie inside an image of size `size`.

@@ -132,7 +133,7 @@ def clip_boxes_to_image(boxes, size):
     return clipped_boxes.reshape(boxes.shape)


-def box_area(boxes):
+def box_area(boxes: Tensor) -> Tensor:
     """
     Computes the area of a set of bounding boxes, which are specified by its
     (x1, y1, x2, y2) coordinates.
@@ -149,7 +150,7 @@ def box_area(boxes):

 # implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
 # with slight modifications
-def box_iou(boxes1, boxes2):
+def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
     """
     Return intersection-over-union (Jaccard index) of boxes.

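These box utilities are all exported from `torchvision.ops`. A short usage sketch for `nms`, `batched_nms`, and `box_iou` with hand-picked values; the printed results are what these particular inputs should give, not output copied from a run:

    import torch
    from torchvision.ops import nms, batched_nms, box_iou

    boxes = torch.tensor([[0., 0., 10., 10.],
                          [1., 1., 11., 11.],
                          [50., 50., 60., 60.]])
    scores = torch.tensor([0.9, 0.8, 0.7])

    keep = nms(boxes, scores, iou_threshold=0.5)
    print(keep)          # tensor([0, 2]): box 1 overlaps box 0 above the threshold

    # batched_nms only suppresses boxes that share the same idxs value
    idxs = torch.tensor([0, 1, 0])
    print(batched_nms(boxes, scores, idxs, 0.5))   # tensor([0, 1, 2])

    print(box_iou(boxes, boxes).shape)             # torch.Size([3, 3])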

torchvision/ops/deform_conv.py

Lines changed: 23 additions & 7 deletions
@@ -8,8 +8,15 @@
 from torch.jit.annotations import Optional, Tuple


-def deform_conv2d(input, offset, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1)):
-    # type: (Tensor, Tensor, Tensor, Optional[Tensor], Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
+def deform_conv2d(
+    input: Tensor,
+    offset: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    stride: Tuple[int, int] = (1, 1),
+    padding: Tuple[int, int] = (0, 0),
+    dilation: Tuple[int, int] = (1, 1),
+) -> Tensor:
     """
     Performs Deformable Convolution, described in Deformable Convolutional Networks

@@ -80,8 +87,17 @@ class DeformConv2d(nn.Module):
     """
     See deform_conv2d
     """
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
-                 dilation=1, groups=1, bias=True):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        padding: int = 0,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+    ):
         super(DeformConv2d, self).__init__()

         if in_channels % groups != 0:
@@ -107,14 +123,14 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,

         self.reset_parameters()

-    def reset_parameters(self):
+    def reset_parameters(self) -> None:
         init.kaiming_uniform_(self.weight, a=math.sqrt(5))
         if self.bias is not None:
             fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
             bound = 1 / math.sqrt(fan_in)
             init.uniform_(self.bias, -bound, bound)

-    def forward(self, input, offset):
+    def forward(self, input: Tensor, offset: Tensor) -> Tensor:
         """
         Arguments:
             input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
@@ -125,7 +141,7 @@ def forward(self, input, offset):
         return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride,
                              padding=self.padding, dilation=self.dilation)

-    def __repr__(self):
+    def __repr__(self) -> str:
         s = self.__class__.__name__ + '('
         s += '{in_channels}'
         s += ', {out_channels}'
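For the module form, the caller supplies the sampling offsets: one `(dy, dx)` pair per kernel element per output location, so a 3x3 kernel with a single offset group needs `2 * 3 * 3 = 18` offset channels at the output resolution. A usage sketch with made-up sizes; zero offsets reduce it to an ordinary convolution:

    import torch
    from torchvision.ops import DeformConv2d

    x = torch.rand(1, 4, 16, 16)
    conv = DeformConv2d(in_channels=4, out_channels=8, kernel_size=3, padding=1)

    # offsets: [batch, 2 * kh * kw, out_h, out_w]; stride=1 and padding=1 keep 16x16
    offset = torch.zeros(1, 18, 16, 16)

    out = conv(x, offset)
    print(out.shape)     # torch.Size([1, 8, 16, 16])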

torchvision/ops/feature_pyramid_network.py

Lines changed: 47 additions & 31 deletions
@@ -4,7 +4,31 @@
 import torch.nn.functional as F
 from torch import nn, Tensor

-from torch.jit.annotations import Tuple, List, Dict
+from torch.jit.annotations import Tuple, List, Dict, Optional
+
+
+class ExtraFPNBlock(nn.Module):
+    """
+    Base class for the extra block in the FPN.
+
+    Arguments:
+        results (List[Tensor]): the result of the FPN
+        x (List[Tensor]): the original feature maps
+        names (List[str]): the names for each one of the
+            original feature maps
+
+    Returns:
+        results (List[Tensor]): the extended set of results
+            of the FPN
+        names (List[str]): the extended set of names for the results
+    """
+    def forward(
+        self,
+        results: List[Tensor],
+        x: List[Tensor],
+        names: List[str],
+    ) -> Tuple[List[Tensor], List[str]]:
+        pass


 class FeaturePyramidNetwork(nn.Module):
@@ -44,7 +68,12 @@ class FeaturePyramidNetwork(nn.Module):
         >>> ('feat3', torch.Size([1, 5, 8, 8]))]

     """
-    def __init__(self, in_channels_list, out_channels, extra_blocks=None):
+    def __init__(
+        self,
+        in_channels_list: List[int],
+        out_channels: int,
+        extra_blocks: Optional[ExtraFPNBlock] = None,
+    ):
         super(FeaturePyramidNetwork, self).__init__()
         self.inner_blocks = nn.ModuleList()
         self.layer_blocks = nn.ModuleList()
@@ -66,8 +95,7 @@ def __init__(self, in_channels_list, out_channels, extra_blocks=None):
             assert isinstance(extra_blocks, ExtraFPNBlock)
         self.extra_blocks = extra_blocks

-    def get_result_from_inner_blocks(self, x, idx):
-        # type: (Tensor, int) -> Tensor
+    def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
         """
         This is equivalent to self.inner_blocks[idx](x),
         but torchscript doesn't support this yet
@@ -85,8 +113,7 @@ def get_result_from_inner_blocks(self, x, idx):
             i += 1
         return out

-    def get_result_from_layer_blocks(self, x, idx):
-        # type: (Tensor, int) -> Tensor
+    def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor:
         """
         This is equivalent to self.layer_blocks[idx](x),
         but torchscript doesn't support this yet
@@ -104,8 +131,7 @@ def get_result_from_layer_blocks(self, x, idx):
             i += 1
         return out

-    def forward(self, x):
-        # type: (Dict[str, Tensor]) -> Dict[str, Tensor]
+    def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """
         Computes the FPN for a set of feature maps.

@@ -140,31 +166,16 @@ def forward(self, x):
         return out


-class ExtraFPNBlock(nn.Module):
-    """
-    Base class for the extra block in the FPN.
-
-    Arguments:
-        results (List[Tensor]): the result of the FPN
-        x (List[Tensor]): the original feature maps
-        names (List[str]): the names for each one of the
-            original feature maps
-
-    Returns:
-        results (List[Tensor]): the extended set of results
-            of the FPN
-        names (List[str]): the extended set of names for the results
-    """
-    def forward(self, results, x, names):
-        pass
-
-
 class LastLevelMaxPool(ExtraFPNBlock):
     """
     Applies a max_pool2d on top of the last feature map
     """
-    def forward(self, x, y, names):
-        # type: (List[Tensor], List[Tensor], List[str]) -> Tuple[List[Tensor], List[str]]
+    def forward(
+        self,
+        x: List[Tensor],
+        y: List[Tensor],
+        names: List[str],
+    ) -> Tuple[List[Tensor], List[str]]:
         names.append("pool")
         x.append(F.max_pool2d(x[-1], 1, 2, 0))
         return x, names
@@ -174,7 +185,7 @@ class LastLevelP6P7(ExtraFPNBlock):
     """
     This module is used in RetinaNet to generate extra layers, P6 and P7.
     """
-    def __init__(self, in_channels, out_channels):
+    def __init__(self, in_channels: int, out_channels: int):
         super(LastLevelP6P7, self).__init__()
         self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
         self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
@@ -183,7 +194,12 @@ def __init__(self, in_channels, out_channels):
             nn.init.constant_(module.bias, 0)
         self.use_P5 = in_channels == out_channels

-    def forward(self, p, c, names):
+    def forward(
+        self,
+        p: List[Tensor],
+        c: List[Tensor],
+        names: List[str],
+    ) -> Tuple[List[Tensor], List[str]]:
         p5, c5 = p[-1], c[-1]
         x = p5 if self.use_P5 else c5
         p6 = self.p6(x)
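The class docstring (partially visible in the hunk above) already carries a doctest; spelled out as a runnable sketch, with arbitrary channel counts and spatial sizes, it looks like this:

    import torch
    from collections import OrderedDict
    from torchvision.ops import FeaturePyramidNetwork

    fpn = FeaturePyramidNetwork(in_channels_list=[10, 20, 30], out_channels=5)

    x = OrderedDict()
    x['feat0'] = torch.rand(1, 10, 64, 64)
    x['feat2'] = torch.rand(1, 20, 16, 16)
    x['feat3'] = torch.rand(1, 30, 8, 8)

    out = fpn(x)
    print([(k, v.shape) for k, v in out.items()])
    # [('feat0', torch.Size([1, 5, 64, 64])),
    #  ('feat2', torch.Size([1, 5, 16, 16])),
    #  ('feat3', torch.Size([1, 5, 8, 8]))]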

torchvision/ops/misc.py

Lines changed: 20 additions & 5 deletions
@@ -10,6 +10,8 @@

 import warnings
 import torch
+from torch import Tensor, Size
+from torch.jit.annotations import List, Optional, Tuple


 class Conv2d(torch.nn.Conv2d):
@@ -46,7 +48,12 @@ class FrozenBatchNorm2d(torch.nn.Module):
     are fixed
     """

-    def __init__(self, num_features, eps=0., n=None):
+    def __init__(
+        self,
+        num_features: Tuple[int, ...],
+        eps: float = 0.,
+        n: Optional[Tuple[int, ...]] = None,
+    ):
         # n=None for backward-compatibility
         if n is not None:
             warnings.warn("`n` argument is deprecated and has been renamed `num_features`",
@@ -59,8 +66,16 @@ def __init__(self, num_features, eps=0., n=None):
         self.register_buffer("running_mean", torch.zeros(num_features))
         self.register_buffer("running_var", torch.ones(num_features))

-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
+    def _load_from_state_dict(
+        self,
+        state_dict: dict,
+        prefix: str,
+        local_metadata: dict,
+        strict: bool,
+        missing_keys: List[str],
+        unexpected_keys: List[str],
+        error_msgs: List[str],
+    ):
         num_batches_tracked_key = prefix + 'num_batches_tracked'
         if num_batches_tracked_key in state_dict:
             del state_dict[num_batches_tracked_key]
@@ -69,7 +84,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
             state_dict, prefix, local_metadata, strict,
             missing_keys, unexpected_keys, error_msgs)

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         # move reshapes to the beginning
         # to make it fuser-friendly
         w = self.weight.reshape(1, -1, 1, 1)
@@ -80,5 +95,5 @@ def forward(self, x):
         bias = b - rm * scale
         return x * scale + bias

-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self.weight.shape[0]})"
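`FrozenBatchNorm2d` behaves like `BatchNorm2d` in eval mode: the affine parameters and running statistics are registered as buffers, so nothing is trainable and the statistics never update. A small sketch (in practice `num_features` is passed as a plain int, just as with `BatchNorm2d`):

    import torch
    from torchvision.ops.misc import FrozenBatchNorm2d

    bn = FrozenBatchNorm2d(16)
    x = torch.rand(2, 16, 8, 8)

    print(bn(x).shape)            # torch.Size([2, 16, 8, 8])
    print(list(bn.parameters()))  # []  (weight/bias/running stats are buffers)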

torchvision/ops/new_empty_tensor.py

Lines changed: 1 addition & 2 deletions
@@ -3,8 +3,7 @@
 from torch import Tensor


-def _new_empty_tensor(x, shape):
-    # type: (Tensor, List[int]) -> Tensor
+def _new_empty_tensor(x: Tensor, shape: List[int]) -> Tensor:
     """
     Arguments:
         input (Tensor): input tensor
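`_new_empty_tensor` is a private helper used elsewhere in the code base to keep ops well defined on zero-sized inputs; as far as this diff shows, it returns an uninitialized tensor of the requested shape matching the input's dtype and device. A hedged sketch of a direct call (private API, so the path and behaviour may change between releases):

    import torch
    from torchvision.ops.new_empty_tensor import _new_empty_tensor

    x = torch.rand(0, 3)                      # an empty batch
    out = _new_empty_tensor(x, [0, 5, 7, 7])
    print(out.shape, out.dtype)               # torch.Size([0, 5, 7, 7]) torch.float32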
