Skip to content

Commit 14484de

Browse files
yiwen-songfmassadatumbox
authored andcommitted
[fbsync] Added typing annotations to transforms/autoaugment (#4226)
Summary: * style: Added typing annotations * style: Fixed typing * style: Fixed typing * Remove unnecessary any. * Update mypy.ini Reviewed By: NicolasHug Differential Revision: D30417197 fbshipit-source-id: c801be04c456b4ec6c7794b9b1c89a79fe8773c6 Co-authored-by: Francisco Massa <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 1ca922c commit 14484de

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

mypy.ini

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,6 @@ ignore_errors = True
3636

3737
ignore_errors = True
3838

39-
[mypy-torchvision.transforms.autoaugment.*]
40-
41-
ignore_errors = True
42-
4339
[mypy-PIL.*]
4440

4541
ignore_missing_imports = True

torchvision/transforms/autoaugment.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from enum import Enum
55
from torch import Tensor
6-
from typing import List, Tuple, Optional
6+
from typing import List, Tuple, Optional, Dict
77

88
from . import functional as F, InterpolationMode
99

@@ -19,7 +19,9 @@ class AutoAugmentPolicy(Enum):
1919
SVHN = "svhn"
2020

2121

22-
def _get_transforms(policy: AutoAugmentPolicy):
22+
def _get_transforms( # type: ignore[return]
23+
policy: AutoAugmentPolicy
24+
) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
2325
if policy == AutoAugmentPolicy.IMAGENET:
2426
return [
2527
(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
@@ -106,7 +108,7 @@ def _get_transforms(policy: AutoAugmentPolicy):
106108
]
107109

108110

109-
def _get_magnitudes():
111+
def _get_magnitudes() -> Dict[str, Tuple[Optional[Tensor], Optional[bool]]]:
110112
_BINS = 10
111113
return {
112114
# name: (magnitudes, signed)
@@ -144,8 +146,12 @@ class AutoAugment(torch.nn.Module):
144146
image. If given a number, the value is used for all bands respectively.
145147
"""
146148

147-
def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
148-
interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None):
149+
def __init__(
150+
self,
151+
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
152+
interpolation: InterpolationMode = InterpolationMode.NEAREST,
153+
fill: Optional[List[float]] = None
154+
) -> None:
149155
super().__init__()
150156
self.policy = policy
151157
self.interpolation = interpolation
@@ -163,7 +169,7 @@ def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
163169
Returns:
164170
params required by the autoaugment transformation
165171
"""
166-
policy_id = torch.randint(transform_num, (1,)).item()
172+
policy_id = int(torch.randint(transform_num, (1,)).item())
167173
probs = torch.rand((2,))
168174
signs = torch.randint(2, (2,))
169175

@@ -172,7 +178,7 @@ def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
172178
def _get_op_meta(self, name: str) -> Tuple[Optional[Tensor], Optional[bool]]:
173179
return self._op_meta[name]
174180

175-
def forward(self, img: Tensor):
181+
def forward(self, img: Tensor) -> Tensor:
176182
"""
177183
img (PIL Image or Tensor): Image to be transformed.
178184
@@ -233,5 +239,5 @@ def forward(self, img: Tensor):
233239

234240
return img
235241

236-
def __repr__(self):
242+
def __repr__(self) -> str:
237243
return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)

0 commit comments

Comments
 (0)