@@ -24,8 +24,9 @@ def _fake_cast_onnx(v):
24
24
25
25
26
26
def _resize_image_and_masks (image : Tensor , self_min_size : float , self_max_size : float ,
27
- target : Optional [Dict [str , Tensor ]],
28
- fixed_size : Optional [Tuple [int , int ]]) -> Tuple [Tensor , Optional [Dict [str , Tensor ]]]:
27
+ target : Optional [Dict [str , Tensor ]] = None ,
28
+ fixed_size : Optional [Tuple [int , int ]] = None ,
29
+ ) -> Tuple [Tensor , Optional [Dict [str , Tensor ]]]:
29
30
if torchvision ._is_tracing ():
30
31
im_shape = _get_shape_onnx (image )
31
32
else :
@@ -146,8 +147,10 @@ def torch_choice(self, k):
146
147
index = int (torch .empty (1 ).uniform_ (0. , float (len (k ))).item ())
147
148
return k [index ]
148
149
149
- def resize (self , image , target ):
150
- # type: (Tensor, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
150
+ def resize (self ,
151
+ image : Tensor ,
152
+ target : Optional [Dict [str , Tensor ]] = None ,
153
+ ) -> Tuple [Tensor , Optional [Dict [str , Tensor ]]]:
151
154
h , w = image .shape [- 2 :]
152
155
if self .training :
153
156
size = float (self .torch_choice (self .min_size ))
0 commit comments