@@ -6,7 +6,7 @@
 from collections import OrderedDict
 from .utils import load_state_dict_from_url
 from torch import Tensor
-from torch.jit.annotations import List
+from typing import Any, List, Tuple


 __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
@@ -20,56 +20,64 @@


 class _DenseLayer(nn.Module):
-    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):
+    def __init__(
+        self,
+        num_input_features: int,
+        growth_rate: int,
+        bn_size: int,
+        drop_rate: float,
+        memory_efficient: bool = False
+    ) -> None:
         super(_DenseLayer, self).__init__()
-        self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
-        self.add_module('relu1', nn.ReLU(inplace=True)),
+        self.norm1: nn.BatchNorm2d
+        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
+        self.relu1: nn.ReLU
+        self.add_module('relu1', nn.ReLU(inplace=True))
+        self.conv1: nn.Conv2d
         self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                                            growth_rate, kernel_size=1, stride=1,
-                                           bias=False)),
-        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
-        self.add_module('relu2', nn.ReLU(inplace=True)),
+                                           bias=False))
+        self.norm2: nn.BatchNorm2d
+        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
+        self.relu2: nn.ReLU
+        self.add_module('relu2', nn.ReLU(inplace=True))
+        self.conv2: nn.Conv2d
         self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                            kernel_size=3, stride=1, padding=1,
-                                           bias=False)),
+                                           bias=False))
         self.drop_rate = float(drop_rate)
         self.memory_efficient = memory_efficient

-    def bn_function(self, inputs):
-        # type: (List[Tensor]) -> Tensor
+    def bn_function(self, inputs: List[Tensor]) -> Tensor:
         concated_features = torch.cat(inputs, 1)
         bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
         return bottleneck_output

     # todo: rewrite when torchscript supports any
-    def any_requires_grad(self, input):
-        # type: (List[Tensor]) -> bool
+    def any_requires_grad(self, input: List[Tensor]) -> bool:
         for tensor in input:
             if tensor.requires_grad:
                 return True
         return False

     @torch.jit.unused  # noqa: T484
-    def call_checkpoint_bottleneck(self, input):
-        # type: (List[Tensor]) -> Tensor
+    def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:
         def closure(*inputs):
             return self.bn_function(inputs)

         return cp.checkpoint(closure, *input)

     @torch.jit._overload_method  # noqa: F811
-    def forward(self, input):
-        # type: (List[Tensor]) -> (Tensor)
+    def forward(self, input: List[Tensor]) -> Tensor:
         pass

-    @torch.jit._overload_method  # noqa: F811
-    def forward(self, input):
-        # type: (Tensor) -> (Tensor)
+    @torch.jit._overload_method  # type: ignore[no-redef]  # noqa: F811
+    def forward(self, input: Tensor) -> Tensor:
         pass

     # torchscript does not yet support *args, so we overload method
     # allowing it to take either a List[Tensor] or single Tensor
-    def forward(self, input):  # noqa: F811
+    def forward(self, input: Tensor) -> Tensor:  # type: ignore[no-redef]  # noqa: F811
         if isinstance(input, Tensor):
             prev_features = [input]
         else:
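Note on the hunk above: TorchScript cannot dispatch on *args, so _DenseLayer declares two @torch.jit._overload_method stubs that carry the accepted signatures, while the last def carries the body. A minimal sketch of the same pattern on a hypothetical toy module (Concat, its shapes, and the prints are illustrative, not part of this diff):

import torch
from torch import Tensor, nn
from typing import List


class Concat(nn.Module):
    # Stub overloads: TorchScript reads these signatures; the bodies are ignored.
    @torch.jit._overload_method  # noqa: F811
    def forward(self, x: List[Tensor]) -> Tensor:
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, x: Tensor) -> Tensor:
        pass

    # Single implementation dispatching on the runtime type, as in _DenseLayer;
    # left unannotated here so both overloads share one body.
    def forward(self, x):  # noqa: F811
        if isinstance(x, Tensor):
            return x
        return torch.cat(x, 1)


m = torch.jit.script(Concat())
t = torch.ones(1, 2)
print(m(t).shape)       # torch.Size([1, 2])
print(m([t, t]).shape)  # torch.Size([1, 4])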
@@ -93,7 +101,15 @@ def forward(self, input):  # noqa: F811
 class _DenseBlock(nn.ModuleDict):
     _version = 2

-    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
+    def __init__(
+        self,
+        num_layers: int,
+        num_input_features: int,
+        bn_size: int,
+        growth_rate: int,
+        drop_rate: float,
+        memory_efficient: bool = False
+    ) -> None:
         super(_DenseBlock, self).__init__()
         for i in range(num_layers):
             layer = _DenseLayer(
@@ -105,7 +121,7 @@ def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_ra
             )
             self.add_module('denselayer%d' % (i + 1), layer)

-    def forward(self, init_features):
+    def forward(self, init_features: Tensor) -> Tensor:  # type: ignore[override]
         features = [init_features]
         for name, layer in self.items():
             new_features = layer(features)
@@ -114,7 +130,7 @@ def forward(self, init_features):


 class _Transition(nn.Sequential):
-    def __init__(self, num_input_features, num_output_features):
+    def __init__(self, num_input_features: int, num_output_features: int) -> None:
         super(_Transition, self).__init__()
         self.add_module('norm', nn.BatchNorm2d(num_input_features))
         self.add_module('relu', nn.ReLU(inplace=True))
@@ -139,8 +155,16 @@ class DenseNet(nn.Module):
           but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
     """

-    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
-                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False):
+    def __init__(
+        self,
+        growth_rate: int = 32,
+        block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
+        num_init_features: int = 64,
+        bn_size: int = 4,
+        drop_rate: float = 0,
+        num_classes: int = 1000,
+        memory_efficient: bool = False
+    ) -> None:

         super(DenseNet, self).__init__()

@@ -188,7 +212,7 @@ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
             elif isinstance(m, nn.Linear):
                 nn.init.constant_(m.bias, 0)

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         features = self.features(x)
         out = F.relu(features, inplace=True)
         out = F.adaptive_avg_pool2d(out, (1, 1))
@@ -197,7 +221,7 @@ def forward(self, x):
         return out


-def _load_state_dict(model, model_url, progress):
+def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None:
     # '.'s are no longer allowed in module names, but previous _DenseLayer
     # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
     # They are also in the checkpoints in model_urls. This pattern is used
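The key renaming those comments describe amounts to collapsing dotted checkpoint keys like 'norm.1' into the new 'norm1' module names before load_state_dict is called. A minimal sketch of that rewrite (the sample keys are hypothetical, and the regex follows the pattern the comments describe rather than being quoted from this diff):

import re

# Hypothetical checkpoint keys in the old '.1'/'.2' naming scheme.
state_dict = {
    'features.denseblock1.denselayer1.norm.1.weight': 'w1',
    'features.denseblock1.denselayer1.conv.2.weight': 'w2',
}

pattern = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

for key in list(state_dict):
    res = pattern.match(key)
    if res:
        # 'norm.1.weight' -> 'norm1.weight': drop the dot between name and index.
        state_dict[res.group(1) + res.group(2)] = state_dict.pop(key)

print(sorted(state_dict))
# ['features.denseblock1.denselayer1.conv2.weight',
#  'features.denseblock1.denselayer1.norm1.weight']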
@@ -215,15 +239,22 @@ def _load_state_dict(model, model_url, progress):
     model.load_state_dict(state_dict)


-def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress,
-              **kwargs):
+def _densenet(
+    arch: str,
+    growth_rate: int,
+    block_config: Tuple[int, int, int, int],
+    num_init_features: int,
+    pretrained: bool,
+    progress: bool,
+    **kwargs: Any
+) -> DenseNet:
     model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
     if pretrained:
         _load_state_dict(model, model_urls[arch], progress)
     return model


-def densenet121(pretrained=False, progress=True, **kwargs):
+def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
     r"""Densenet-121 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

@@ -237,7 +268,7 @@ def densenet121(pretrained=False, progress=True, **kwargs):
                      **kwargs)


-def densenet161(pretrained=False, progress=True, **kwargs):
+def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
     r"""Densenet-161 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

@@ -251,7 +282,7 @@ def densenet161(pretrained=False, progress=True, **kwargs):
                      **kwargs)


-def densenet169(pretrained=False, progress=True, **kwargs):
+def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
     r"""Densenet-169 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

@@ -265,7 +296,7 @@ def densenet169(pretrained=False, progress=True, **kwargs):
                      **kwargs)


-def densenet201(pretrained=False, progress=True, **kwargs):
+def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
     r"""Densenet-201 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

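With the annotations applied, the densenet entry points are checkable by mypy and the models still script. A minimal sketch of exercising the annotated API (the input shape and pretrained=False choice are illustrative, not part of this diff):

import torch
from torchvision.models import densenet121

model = densenet121(pretrained=False)  # now annotated as -> DenseNet
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out = model(x)  # DenseNet.forward(self, x: Tensor) -> Tensor
print(out.shape)  # torch.Size([1, 1000])

# The overload stubs on _DenseLayer.forward keep the model scriptable:
scripted = torch.jit.script(model)
with torch.no_grad():
    print(scripted(x).shape)  # torch.Size([1, 1000])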