
Commit cc72b8f

datumbox and fmassa authored and committed
[fbsync] Add typing in GeneralizedRCNNTransform (#4369)
Summary:
* Add types in transform.
* Trace on eval mode.

Reviewed By: fmassa

Differential Revision: D30793317

fbshipit-source-id: 4751969a060d909c59908aa37a7fb809dc6e19f7

Co-authored-by: Francisco Massa <[email protected]>
1 parent c23c2c4 commit cc72b8f
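
The change itself is mechanical: each Python 2-style "# type:" comment on a signature in transform.py is replaced by inline Python 3 annotations, which both static checkers and TorchScript understand. A minimal sketch of the pattern, using a made-up helper rather than code from the commit:

    from torch import Tensor

    # before: the signature is typed through a comment
    def scale(image, factor):
        # type: (Tensor, float) -> Tensor
        return image * factor

    # after: the same information as inline annotations
    def scale(image: Tensor, factor: float) -> Tensor:
        return image * factor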

File tree: 1 file changed (+19, -28 lines)


torchvision/models/detection/transform.py

Lines changed: 19 additions & 28 deletions
@@ -10,15 +10,13 @@


 @torch.jit.unused
-def _get_shape_onnx(image):
-    # type: (Tensor) -> Tensor
+def _get_shape_onnx(image: Tensor) -> Tensor:
     from torch.onnx import operators
     return operators.shape_as_tensor(image)[-2:]


 @torch.jit.unused
-def _fake_cast_onnx(v):
-    # type: (Tensor) -> float
+def _fake_cast_onnx(v: Tensor) -> float:
     # ONNX requires a tensor but here we fake its type for JIT.
     return v

@@ -74,7 +72,8 @@ class GeneralizedRCNNTransform(nn.Module):
     It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets
     """

-    def __init__(self, min_size, max_size, image_mean, image_std, size_divisible=32, fixed_size=None):
+    def __init__(self, min_size: int, max_size: int, image_mean: List[float], image_std: List[float],
+                 size_divisible: int = 32, fixed_size: Optional[Tuple[int, int]] = None):
         super(GeneralizedRCNNTransform, self).__init__()
         if not isinstance(min_size, (list, tuple)):
             min_size = (min_size,)
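
With the annotated constructor, the expected argument types are visible at the call site. A rough usage sketch; the numeric values below are illustrative defaults, not something taken from this commit:

    from torchvision.models.detection.transform import GeneralizedRCNNTransform

    transform = GeneralizedRCNNTransform(
        min_size=800,                         # target for the shorter image side, in pixels
        max_size=1333,                        # cap on the longer side
        image_mean=[0.485, 0.456, 0.406],     # per-channel normalization mean
        image_std=[0.229, 0.224, 0.225],      # per-channel normalization std
        size_divisible=32,
        fixed_size=None,
    )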
@@ -86,10 +85,9 @@ def __init__(self, min_size, max_size, image_mean, image_std, size_divisible=32,
         self.fixed_size = fixed_size

     def forward(self,
-                images,       # type: List[Tensor]
-                targets=None  # type: Optional[List[Dict[str, Tensor]]]
-                ):
-        # type: (...) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]
+                images: List[Tensor],
+                targets: Optional[List[Dict[str, Tensor]]] = None
+                ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
         images = [img for img in images]
         if targets is not None:
             # make a copy of targets to avoid modifying it in-place
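
The annotated forward spells out the contract: it consumes a list of CHW image tensors plus optional targets and returns an ImageList together with the (possibly resized) targets. A hedged call sketch with arbitrary shapes, reusing the constructor arguments from the example above:

    import torch
    from torchvision.models.detection.transform import GeneralizedRCNNTransform

    transform = GeneralizedRCNNTransform(800, 1333, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    transform.eval()                              # deterministic resize path

    images = [torch.rand(3, 480, 640), torch.rand(3, 600, 800)]  # two CHW images
    image_list, targets = transform(images)       # targets stays None here
    print(image_list.tensors.shape)               # one padded, batched tensor
    print(image_list.image_sizes)                 # per-image (h, w) after resizing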
@@ -126,7 +124,7 @@ def forward(self,
         image_list = ImageList(images, image_sizes_list)
         return image_list, targets

-    def normalize(self, image):
+    def normalize(self, image: Tensor) -> Tensor:
         if not image.is_floating_point():
             raise TypeError(
                 f"Expected input images to be of floating type (in range [0, 1]), "
@@ -137,8 +135,7 @@ def normalize(self, image):
         std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
         return (image - mean[:, None, None]) / std[:, None, None]

-    def torch_choice(self, k):
-        # type: (List[int]) -> int
+    def torch_choice(self, k: List[int]) -> int:
         """
         Implements `random.choice` via torch ops so it can be compiled with
         TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
@@ -175,8 +172,7 @@ def resize(self,
     # _onnx_batch_images() is an implementation of
     # batch_images() that is supported by ONNX tracing.
     @torch.jit.unused
-    def _onnx_batch_images(self, images, size_divisible=32):
-        # type: (List[Tensor], int) -> Tensor
+    def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
         max_size = []
         for i in range(images[0].dim()):
             max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
@@ -197,16 +193,14 @@ def _onnx_batch_images(self, images, size_divisible=32):

         return torch.stack(padded_imgs)

-    def max_by_axis(self, the_list):
-        # type: (List[List[int]]) -> List[int]
+    def max_by_axis(self, the_list: List[List[int]]) -> List[int]:
         maxes = the_list[0]
         for sublist in the_list[1:]:
             for index, item in enumerate(sublist):
                 maxes[index] = max(maxes[index], item)
         return maxes

-    def batch_images(self, images, size_divisible=32):
-        # type: (List[Tensor], int) -> Tensor
+    def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
        if torchvision._is_tracing():
             # batch_images() does not export well to ONNX
             # call _onnx_batch_images() instead
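
max_by_axis takes the element-wise maximum over the per-image shape lists, and batch_images then pads every image up to that common shape, rounded up to a multiple of size_divisible, before stacking them into a single tensor. A small worked example with made-up shapes:

    shapes = [[3, 480, 640], [3, 600, 512]]
    # element-wise maximum             -> [3, 600, 640]
    # rounded up to size_divisible=32  -> batch tensor of shape (2, 3, 608, 640)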
@@ -226,11 +220,10 @@ def batch_images(self, images, size_divisible=32):
         return batched_imgs

     def postprocess(self,
-                    result,               # type: List[Dict[str, Tensor]]
-                    image_shapes,         # type: List[Tuple[int, int]]
-                    original_image_sizes  # type: List[Tuple[int, int]]
-                    ):
-        # type: (...) -> List[Dict[str, Tensor]]
+                    result: List[Dict[str, Tensor]],
+                    image_shapes: List[Tuple[int, int]],
+                    original_image_sizes: List[Tuple[int, int]]
+                    ) -> List[Dict[str, Tensor]]:
         if self.training:
             return result
         for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
@@ -247,7 +240,7 @@ def postprocess(self,
                 result[i]["keypoints"] = keypoints
         return result

-    def __repr__(self):
+    def __repr__(self) -> str:
         format_string = self.__class__.__name__ + '('
         _indent = '\n    '
         format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std)
@@ -257,8 +250,7 @@ def __repr__(self):
         return format_string


-def resize_keypoints(keypoints, original_size, new_size):
-    # type: (Tensor, List[int], List[int]) -> Tensor
+def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
     ratios = [
         torch.tensor(s, dtype=torch.float32, device=keypoints.device) /
         torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
@@ -276,8 +268,7 @@ def resize_keypoints(keypoints, original_size, new_size):
     return resized_data


-def resize_boxes(boxes, original_size, new_size):
-    # type: (Tensor, List[int], List[int]) -> Tensor
+def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
     ratios = [
         torch.tensor(s, dtype=torch.float32, device=boxes.device) /
         torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
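
The "Trace on eval mode" note in the summary presumably refers to exercising the transform in eval() mode, where the random min_size selection in torch_choice is bypassed in favor of the last entry, keeping tracing and scripting deterministic. A hedged sketch of scripting the annotated module; illustrative only, not a snippet from this commit:

    import torch
    from torchvision.models.detection.transform import GeneralizedRCNNTransform

    transform = GeneralizedRCNNTransform(800, 1333, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    transform.eval()                          # use the deterministic resize path
    scripted = torch.jit.script(transform)    # TorchScript now reads the inline annotations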
