3
3
4
4
from enum import Enum
5
5
from torch import Tensor
6
- from typing import List , Tuple , Optional
6
+ from typing import List , Tuple , Optional , Dict
7
7
8
8
from . import functional as F , InterpolationMode
9
9
@@ -19,7 +19,9 @@ class AutoAugmentPolicy(Enum):
19
19
SVHN = "svhn"
20
20
21
21
22
- def _get_transforms (policy : AutoAugmentPolicy ):
22
+ def _get_transforms ( # type: ignore[return]
23
+ policy : AutoAugmentPolicy
24
+ ) -> List [Tuple [Tuple [str , float , Optional [int ]], Tuple [str , float , Optional [int ]]]]:
23
25
if policy == AutoAugmentPolicy .IMAGENET :
24
26
return [
25
27
(("Posterize" , 0.4 , 8 ), ("Rotate" , 0.6 , 9 )),
@@ -106,7 +108,7 @@ def _get_transforms(policy: AutoAugmentPolicy):
106
108
]
107
109
108
110
109
- def _get_magnitudes ():
111
+ def _get_magnitudes () -> Dict [ str , Tuple [ Optional [ Tensor ], Optional [ bool ]]] :
110
112
_BINS = 10
111
113
return {
112
114
# name: (magnitudes, signed)
@@ -144,8 +146,12 @@ class AutoAugment(torch.nn.Module):
144
146
image. If given a number, the value is used for all bands respectively.
145
147
"""
146
148
147
- def __init__ (self , policy : AutoAugmentPolicy = AutoAugmentPolicy .IMAGENET ,
148
- interpolation : InterpolationMode = InterpolationMode .NEAREST , fill : Optional [List [float ]] = None ):
149
+ def __init__ (
150
+ self ,
151
+ policy : AutoAugmentPolicy = AutoAugmentPolicy .IMAGENET ,
152
+ interpolation : InterpolationMode = InterpolationMode .NEAREST ,
153
+ fill : Optional [List [float ]] = None
154
+ ) -> None :
149
155
super ().__init__ ()
150
156
self .policy = policy
151
157
self .interpolation = interpolation
@@ -163,7 +169,7 @@ def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
163
169
Returns:
164
170
params required by the autoaugment transformation
165
171
"""
166
- policy_id = torch .randint (transform_num , (1 ,)).item ()
172
+ policy_id = int ( torch .randint (transform_num , (1 ,)).item () )
167
173
probs = torch .rand ((2 ,))
168
174
signs = torch .randint (2 , (2 ,))
169
175
@@ -172,7 +178,7 @@ def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
172
178
def _get_op_meta (self , name : str ) -> Tuple [Optional [Tensor ], Optional [bool ]]:
173
179
return self ._op_meta [name ]
174
180
175
- def forward (self , img : Tensor ):
181
+ def forward (self , img : Tensor ) -> Tensor :
176
182
"""
177
183
img (PIL Image or Tensor): Image to be transformed.
178
184
@@ -233,5 +239,5 @@ def forward(self, img: Tensor):
233
239
234
240
return img
235
241
236
- def __repr__ (self ):
242
+ def __repr__ (self ) -> str :
237
243
return self .__class__ .__name__ + '(policy={}, fill={})' .format (self .policy , self .fill )
0 commit comments