
Commit 8cf4cd0

frgfm authored and vfdev-5 committed
Added annotation typing to densenet (pytorch#2860)
* style: Added annotation typing for densenet
* fix: Fixed import
* refactor: Removed un-necessary import
* fix: Fixed constructor typing
* chore: Updated mypy.ini
* fix: Fixed tuple typing
* style: Ignored some mypy errors
* style: Fixed typing
* fix: Added missing constructor typing
1 parent 3bfaf9e commit 8cf4cd0

File tree

2 files changed: 64 additions & 37 deletions


mypy.ini

Lines changed: 0 additions & 4 deletions
@@ -16,10 +16,6 @@ ignore_errors = True
 
 ignore_errors = True
 
-[mypy-torchvision.models.densenet.*]
-
-ignore_errors = True
-
 [mypy-torchvision.models.quantization.*]
 
 ignore_errors = True
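
With this override removed, mypy now type-checks torchvision.models.densenet against the annotations added below. A hypothetical illustration of what that buys (the exact mypy message may differ):

```python
from torchvision.models import densenet121

model = densenet121(pretrained=False)  # arguments checked against the new annotations

# A call like the following would now be rejected by mypy with roughly:
#   error: Argument "pretrained" to "densenet121" has incompatible type "str";
#   expected "bool"
# densenet121(pretrained="yes")
```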

torchvision/models/densenet.py

Lines changed: 64 additions & 33 deletions
@@ -6,7 +6,7 @@
 from collections import OrderedDict
 from .utils import load_state_dict_from_url
 from torch import Tensor
-from torch.jit.annotations import List
+from typing import Any, List, Tuple
 
 
 __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
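
The import swap works because TorchScript understands standard typing annotations, so the torch.jit.annotations re-export is no longer needed. A minimal sketch (hypothetical function, not part of the diff):

```python
from typing import List

import torch
from torch import Tensor


@torch.jit.script
def concat_all(inputs: List[Tensor]) -> Tensor:
    # TorchScript reads the standard typing annotation on `inputs`.
    return torch.cat(inputs, 1)


print(concat_all([torch.ones(2, 3), torch.zeros(2, 3)]).shape)  # torch.Size([2, 6])
```
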
@@ -20,56 +20,64 @@
 
 
 class _DenseLayer(nn.Module):
-    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):
+    def __init__(
+        self,
+        num_input_features: int,
+        growth_rate: int,
+        bn_size: int,
+        drop_rate: float,
+        memory_efficient: bool = False
+    ) -> None:
         super(_DenseLayer, self).__init__()
-        self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
-        self.add_module('relu1', nn.ReLU(inplace=True)),
+        self.norm1: nn.BatchNorm2d
+        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
+        self.relu1: nn.ReLU
+        self.add_module('relu1', nn.ReLU(inplace=True))
+        self.conv1: nn.Conv2d
         self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                                            growth_rate, kernel_size=1, stride=1,
-                                           bias=False)),
-        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
-        self.add_module('relu2', nn.ReLU(inplace=True)),
+                                           bias=False))
+        self.norm2: nn.BatchNorm2d
+        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
+        self.relu2: nn.ReLU
+        self.add_module('relu2', nn.ReLU(inplace=True))
+        self.conv2: nn.Conv2d
         self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                            kernel_size=3, stride=1, padding=1,
-                                           bias=False)),
+                                           bias=False))
         self.drop_rate = float(drop_rate)
         self.memory_efficient = memory_efficient
 
-    def bn_function(self, inputs):
-        # type: (List[Tensor]) -> Tensor
+    def bn_function(self, inputs: List[Tensor]) -> Tensor:
         concated_features = torch.cat(inputs, 1)
         bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
         return bottleneck_output
 
     # todo: rewrite when torchscript supports any
-    def any_requires_grad(self, input):
-        # type: (List[Tensor]) -> bool
+    def any_requires_grad(self, input: List[Tensor]) -> bool:
         for tensor in input:
             if tensor.requires_grad:
                 return True
         return False
 
     @torch.jit.unused  # noqa: T484
-    def call_checkpoint_bottleneck(self, input):
-        # type: (List[Tensor]) -> Tensor
+    def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:
         def closure(*inputs):
             return self.bn_function(inputs)
 
         return cp.checkpoint(closure, *input)
 
     @torch.jit._overload_method  # noqa: F811
-    def forward(self, input):
-        # type: (List[Tensor]) -> (Tensor)
+    def forward(self, input: List[Tensor]) -> Tensor:
         pass
 
-    @torch.jit._overload_method  # noqa: F811
-    def forward(self, input):
-        # type: (Tensor) -> (Tensor)
+    @torch.jit._overload_method  # type: ignore[no-redef]  # noqa: F811
+    def forward(self, input: Tensor) -> Tensor:
         pass
 
     # torchscript does not yet support *args, so we overload method
     # allowing it to take either a List[Tensor] or single Tensor
-    def forward(self, input):  # noqa: F811
+    def forward(self, input: Tensor) -> Tensor:  # type: ignore[no-redef]  # noqa: F811
         if isinstance(input, Tensor):
             prev_features = [input]
         else:
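
The pattern above, declaring each attribute's type before add_module() registers it, is what lets mypy resolve self.norm1 and friends to concrete module classes instead of the loose type returned by nn.Module.__getattr__. A minimal sketch with a hypothetical class:

```python
import torch
from torch import nn


class Block(nn.Module):  # hypothetical class, not part of the diff
    def __init__(self, channels: int) -> None:
        super().__init__()
        # Bare annotation first, then registration: mypy now knows the
        # attribute's concrete type.
        self.norm1: nn.BatchNorm2d
        self.add_module('norm1', nn.BatchNorm2d(channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm1(x)
```

The `# type: ignore[no-redef]` comments serve a related purpose: the three forward definitions are intentional TorchScript overloads, so mypy's redefinition check has to be silenced rather than fixed.
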
@@ -93,7 +101,15 @@ def forward(self, input):  # noqa: F811
 class _DenseBlock(nn.ModuleDict):
     _version = 2
 
-    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
+    def __init__(
+        self,
+        num_layers: int,
+        num_input_features: int,
+        bn_size: int,
+        growth_rate: int,
+        drop_rate: float,
+        memory_efficient: bool = False
+    ) -> None:
         super(_DenseBlock, self).__init__()
         for i in range(num_layers):
             layer = _DenseLayer(
@@ -105,7 +121,7 @@ def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_ra
             )
             self.add_module('denselayer%d' % (i + 1), layer)
 
-    def forward(self, init_features):
+    def forward(self, init_features: Tensor) -> Tensor:  # type: ignore[override]
         features = [init_features]
         for name, layer in self.items():
             new_features = layer(features)
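
The `# type: ignore[override]` is needed because the typed forward no longer matches the broader signature that the base nn.Module declares, so mypy flags the override as incompatible. A sketch of the same situation with a hypothetical class:

```python
from torch import Tensor, nn


class Pipeline(nn.ModuleDict):  # hypothetical class, not part of the diff
    # nn.Module declares forward with a broader signature, so a narrowly
    # typed override trips mypy's override check and must be ignored.
    def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
        for _, module in self.items():
            x = module(x)
        return x
```
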
@@ -114,7 +130,7 @@ def forward(self, init_features):
 
 
 class _Transition(nn.Sequential):
-    def __init__(self, num_input_features, num_output_features):
+    def __init__(self, num_input_features: int, num_output_features: int) -> None:
         super(_Transition, self).__init__()
         self.add_module('norm', nn.BatchNorm2d(num_input_features))
         self.add_module('relu', nn.ReLU(inplace=True))
@@ -139,8 +155,16 @@ class DenseNet(nn.Module):
           but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
     """
 
-    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
-                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False):
+    def __init__(
+        self,
+        growth_rate: int = 32,
+        block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
+        num_init_features: int = 64,
+        bn_size: int = 4,
+        drop_rate: float = 0,
+        num_classes: int = 1000,
+        memory_efficient: bool = False
+    ) -> None:
 
         super(DenseNet, self).__init__()
 
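block_config now uses a fixed-arity tuple type, so mypy enforces that a configuration names exactly four dense blocks. A minimal sketch:

```python
from typing import Tuple

config: Tuple[int, int, int, int] = (6, 12, 24, 16)  # accepted

# A three-element config would be rejected by mypy with roughly:
#   error: Incompatible types in assignment (expression has type
#   "Tuple[int, int, int]", variable has type "Tuple[int, int, int, int]")
# bad: Tuple[int, int, int, int] = (6, 12, 24)
```
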
@@ -188,7 +212,7 @@ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
         elif isinstance(m, nn.Linear):
             nn.init.constant_(m.bias, 0)
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         features = self.features(x)
         out = F.relu(features, inplace=True)
         out = F.adaptive_avg_pool2d(out, (1, 1))
@@ -197,7 +221,7 @@ def forward(self, x):
     return out
 
 
-def _load_state_dict(model, model_url, progress):
+def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None:
     # '.'s are no longer allowed in module names, but previous _DenseLayer
     # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
     # They are also in the checkpoints in model_urls. This pattern is used
@@ -215,15 +239,22 @@ def _load_state_dict(model, model_url, progress):
     model.load_state_dict(state_dict)
 
 
-def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress,
-              **kwargs):
+def _densenet(
+    arch: str,
+    growth_rate: int,
+    block_config: Tuple[int, int, int, int],
+    num_init_features: int,
+    pretrained: bool,
+    progress: bool,
+    **kwargs: Any
+) -> DenseNet:
     model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
     if pretrained:
         _load_state_dict(model, model_urls[arch], progress)
     return model
 
 
-def densenet121(pretrained=False, progress=True, **kwargs):
+def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
     r"""Densenet-121 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
 
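
The public factories now advertise DenseNet as their return type, while extra options still flow through **kwargs: Any into the constructor, accepted but not individually type-checked. A hedged usage sketch:

```python
import torch
from torchvision.models import densenet161

# memory_efficient travels through **kwargs: Any to DenseNet.__init__.
model = densenet161(pretrained=False, memory_efficient=True)  # typed as DenseNet
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
```
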
@@ -237,7 +268,7 @@ def densenet121(pretrained=False, progress=True, **kwargs):
                      **kwargs)
 
 
-def densenet161(pretrained=False, progress=True, **kwargs):
+def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
     r"""Densenet-161 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
 
@@ -251,7 +282,7 @@ def densenet161(pretrained=False, progress=True, **kwargs):
                      **kwargs)
 
 
-def densenet169(pretrained=False, progress=True, **kwargs):
+def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
     r"""Densenet-169 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
 
@@ -265,7 +296,7 @@ def densenet169(pretrained=False, progress=True, **kwargs):
                      **kwargs)
 
 
-def densenet201(pretrained=False, progress=True, **kwargs):
+def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
     r"""Densenet-201 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
 
