@@ -10,15 +10,13 @@


 @torch.jit.unused
-def _get_shape_onnx(image):
-    # type: (Tensor) -> Tensor
+def _get_shape_onnx(image: Tensor) -> Tensor:
     from torch.onnx import operators
     return operators.shape_as_tensor(image)[-2:]


 @torch.jit.unused
-def _fake_cast_onnx(v):
-    # type: (Tensor) -> float
+def _fake_cast_onnx(v: Tensor) -> float:
     # ONNX requires a tensor but here we fake its type for JIT.
     return v

@@ -74,7 +72,8 @@ class GeneralizedRCNNTransform(nn.Module):
     It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets
     """

-    def __init__(self, min_size, max_size, image_mean, image_std, size_divisible=32, fixed_size=None):
+    def __init__(self, min_size: int, max_size: int, image_mean: List[float], image_std: List[float],
+                 size_divisible: int = 32, fixed_size: Optional[Tuple[int, int]] = None):
         super(GeneralizedRCNNTransform, self).__init__()
         if not isinstance(min_size, (list, tuple)):
             min_size = (min_size,)
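For readers skimming the new signature, a minimal usage sketch. The values below are the usual torchvision detection defaults, shown here purely as an illustration; nothing in this diff changes them:

import torch
from torchvision.models.detection.transform import GeneralizedRCNNTransform

# Illustrative values only (typical defaults for torchvision detection models).
transform = GeneralizedRCNNTransform(
    min_size=800,                      # int, per the new annotation
    max_size=1333,
    image_mean=[0.485, 0.456, 0.406],  # List[float]
    image_std=[0.229, 0.224, 0.225],   # List[float]
)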
@@ -86,10 +85,9 @@ def __init__(self, min_size, max_size, image_mean, image_std, size_divisible=32,
         self.fixed_size = fixed_size

     def forward(self,
-                images,       # type: List[Tensor]
-                targets=None  # type: Optional[List[Dict[str, Tensor]]]
-                ):
-        # type: (...) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]
+                images: List[Tensor],
+                targets: Optional[List[Dict[str, Tensor]]] = None
+                ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
         images = [img for img in images]
         if targets is not None:
             # make a copy of targets to avoid modifying it in-place
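Since the point of replacing `# type:` comments with real annotations is TorchScript compatibility, a quick smoke test (hypothetical, reusing the `transform` instance from the sketch above): torch.jit.script reads the annotations on forward just as it previously read the type comments.

import torch

scripted = torch.jit.script(transform)
images = [torch.rand(3, 480, 640)]
image_list, _ = scripted(images)   # targets defaults to None
print(image_list.tensors.shape)    # e.g. torch.Size([1, 3, 800, 1088])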
@@ -126,7 +124,7 @@ def forward(self,
         image_list = ImageList(images, image_sizes_list)
         return image_list, targets

-    def normalize(self, image):
+    def normalize(self, image: Tensor) -> Tensor:
         if not image.is_floating_point():
             raise TypeError(
                 f"Expected input images to be of floating type (in range [0, 1]), "
@@ -137,8 +135,7 @@ def normalize(self, image):
         std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
         return (image - mean[:, None, None]) / std[:, None, None]

-    def torch_choice(self, k):
-        # type: (List[int]) -> int
+    def torch_choice(self, k: List[int]) -> int:
         """
         Implements `random.choice` via torch ops so it can be compiled with
         TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
@@ -175,8 +172,7 @@ def resize(self,
     # _onnx_batch_images() is an implementation of
     # batch_images() that is supported by ONNX tracing.
     @torch.jit.unused
-    def _onnx_batch_images(self, images, size_divisible=32):
-        # type: (List[Tensor], int) -> Tensor
+    def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
         max_size = []
         for i in range(images[0].dim()):
             max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
@@ -197,16 +193,14 @@ def _onnx_batch_images(self, images, size_divisible=32):

         return torch.stack(padded_imgs)

-    def max_by_axis(self, the_list):
-        # type: (List[List[int]]) -> List[int]
+    def max_by_axis(self, the_list: List[List[int]]) -> List[int]:
         maxes = the_list[0]
         for sublist in the_list[1:]:
             for index, item in enumerate(sublist):
                 maxes[index] = max(maxes[index], item)
         return maxes

-    def batch_images(self, images, size_divisible=32):
-        # type: (List[Tensor], int) -> Tensor
+    def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
         if torchvision._is_tracing():
             # batch_images() does not export well to ONNX
             # call _onnx_batch_images() instead
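As a reminder of what the (unchanged) size_divisible parameter does: batch_images pads every image to the per-axis maximum over the batch, rounded up to a multiple of size_divisible. A small worked example, assuming CHW image tensors:

import math
import torch

imgs = [torch.rand(3, 500, 700), torch.rand(3, 600, 400)]
stride = 32
# Per-axis maxima over the batch: (3, 600, 700).
# Height and width are rounded up to multiples of `stride`:
h = int(math.ceil(600 / stride) * stride)  # 608
w = int(math.ceil(700 / stride) * stride)  # 704
# batch_images(imgs) would therefore produce a (2, 3, 608, 704) tensor,
# each image copied into the top-left corner with zero padding elsewhere.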
@@ -226,11 +220,10 @@ def batch_images(self, images, size_divisible=32):
         return batched_imgs

     def postprocess(self,
-                    result,               # type: List[Dict[str, Tensor]]
-                    image_shapes,         # type: List[Tuple[int, int]]
-                    original_image_sizes  # type: List[Tuple[int, int]]
-                    ):
-        # type: (...) -> List[Dict[str, Tensor]]
+                    result: List[Dict[str, Tensor]],
+                    image_shapes: List[Tuple[int, int]],
+                    original_image_sizes: List[Tuple[int, int]]
+                    ) -> List[Dict[str, Tensor]]:
         if self.training:
             return result
         for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
@@ -247,7 +240,7 @@ def postprocess(self,
             result[i]["keypoints"] = keypoints
         return result

-    def __repr__(self):
+    def __repr__(self) -> str:
         format_string = self.__class__.__name__ + '('
         _indent = '\n    '
         format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std)
@@ -257,8 +250,7 @@ def __repr__(self):
         return format_string


-def resize_keypoints(keypoints, original_size, new_size):
-    # type: (Tensor, List[int], List[int]) -> Tensor
+def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
     ratios = [
         torch.tensor(s, dtype=torch.float32, device=keypoints.device) /
         torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
@@ -276,8 +268,7 @@ def resize_keypoints(keypoints, original_size, new_size):
     return resized_data


-def resize_boxes(boxes, original_size, new_size):
-    # type: (Tensor, List[int], List[int]) -> Tensor
+def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
     ratios = [
         torch.tensor(s, dtype=torch.float32, device=boxes.device) /
         torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
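For context on the annotated helpers: resize_boxes is a pure per-axis rescale by the (height, width) ratios between new_size and original_size. A minimal sketch of the equivalent computation, with made-up sizes:

import torch

boxes = torch.tensor([[10., 20., 110., 220.]])  # (xmin, ymin, xmax, ymax)
original_size = [400, 600]  # (height, width)
new_size = [800, 1200]

ratio_h = new_size[0] / original_size[0]  # 2.0
ratio_w = new_size[1] / original_size[1]  # 2.0
xmin, ymin, xmax, ymax = boxes.unbind(1)
resized = torch.stack(
    (xmin * ratio_w, ymin * ratio_h, xmax * ratio_w, ymax * ratio_h), dim=1
)
print(resized)  # tensor([[ 20.,  40., 220., 440.]])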