
Commit 59c9742

Added annotation typing to googlenet (#2858)
* style: Added annotation typing for googlenet
* fix: Removed duplicate typing
* refactor: Moved factory function after class definition to fix typing
* fix: Fixed annotation typing
* refactor: Removed unnecessary import
* fix: Fixed typing
* refactor: Moved back up helper function and quote typed it
1 parent 8263c8a commit 59c9742
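The "quote typed" helper mentioned in the last bullet is a forward reference: the googlenet factory sits above the GoogLeNet class in the module, so its return type has to be written as the string "GoogLeNet". A minimal sketch of the pattern, using an illustrative class name that is not from the diff:

from typing import Any


def make_widget(**kwargs: Any) -> "Widget":
    # The quoted name is a forward reference: Widget is only defined
    # below, but string annotations are not evaluated at function
    # definition time, so this type-checks and runs fine.
    return Widget(**kwargs)


class Widget:
    def __init__(self, **kwargs: Any) -> None:
        self.config = kwargs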

1 file changed: +45 −23 lines changed

torchvision/models/googlenet.py

Lines changed: 45 additions & 23 deletions
@@ -3,9 +3,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.jit.annotations import Optional, Tuple
 from torch import Tensor
 from .utils import load_state_dict_from_url
+from typing import Optional, Tuple, List, Callable, Any

 __all__ = ['GoogLeNet', 'googlenet', "GoogLeNetOutputs", "_GoogLeNetOutputs"]


@@ -23,7 +23,7 @@
 _GoogLeNetOutputs = GoogLeNetOutputs


-def googlenet(pretrained=False, progress=True, **kwargs):
+def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "GoogLeNet":
     r"""GoogLeNet (Inception v1) model architecture from
     `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.

@@ -52,8 +52,8 @@ def googlenet(pretrained=False, progress=True, **kwargs):
         model.load_state_dict(state_dict)
         if not original_aux_logits:
             model.aux_logits = False
-            model.aux1 = None
-            model.aux2 = None
+            model.aux1 = None  # type: ignore[assignment]
+            model.aux2 = None  # type: ignore[assignment]
         return model

     return GoogLeNet(**kwargs)
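With the annotated signature, a type checker can validate calls against the factory and knows the result is a GoogLeNet instance. A usage sketch (assumes a torchvision build containing this commit; the model is put in eval mode, so the plain Tensor output path is taken):

import torch
from torchvision.models import googlenet

model = googlenet(pretrained=False, aux_logits=True)  # checked as -> "GoogLeNet"
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])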
@@ -62,8 +62,14 @@ def googlenet(pretrained=False, progress=True, **kwargs):
 class GoogLeNet(nn.Module):
     __constants__ = ['aux_logits', 'transform_input']

-    def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=None,
-                 blocks=None):
+    def __init__(
+        self,
+        num_classes: int = 1000,
+        aux_logits: bool = True,
+        transform_input: bool = False,
+        init_weights: Optional[bool] = None,
+        blocks: Optional[List[Callable[..., nn.Module]]] = None
+    ) -> None:
         super(GoogLeNet, self).__init__()
         if blocks is None:
             blocks = [BasicConv2d, Inception, InceptionAux]
@@ -104,8 +110,8 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, ini
             self.aux1 = inception_aux_block(512, num_classes)
             self.aux2 = inception_aux_block(528, num_classes)
         else:
-            self.aux1 = None
-            self.aux2 = None
+            self.aux1 = None  # type: ignore[assignment]
+            self.aux2 = None  # type: ignore[assignment]

         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
         self.dropout = nn.Dropout(0.2)
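The type: ignore[assignment] comments are needed because mypy infers self.aux1 and self.aux2 as InceptionAux from the branch that constructs them, so assigning None in the else branch is flagged as an incompatible assignment. A minimal reproduction of the pattern, with hypothetical names:

import torch.nn as nn


class Head(nn.Module):
    def __init__(self, with_aux: bool) -> None:
        super().__init__()
        if with_aux:
            self.aux = nn.Linear(512, 10)
        else:
            # mypy infers self.aux as nn.Linear from the branch above,
            # so assigning None is an [assignment] error without the ignore.
            self.aux = None  # type: ignore[assignment]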
@@ -114,7 +120,7 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, ini
         if init_weights:
             self._initialize_weights()

-    def _initialize_weights(self):
+    def _initialize_weights(self) -> None:
         for m in self.modules():
             if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                 import scipy.stats as stats
@@ -127,17 +133,15 @@ def _initialize_weights(self):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)

-    def _transform_input(self, x):
-        # type: (Tensor) -> Tensor
+    def _transform_input(self, x: Tensor) -> Tensor:
         if self.transform_input:
             x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
             x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
             x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
             x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
         return x

-    def _forward(self, x):
-        # type: (Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
+    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
         # N x 3 x 224 x 224
         x = self.conv1(x)
         # N x 64 x 112 x 112
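The two hunks above migrate Py2-style type comments to inline annotations; TorchScript accepts both forms, but inline annotations are also visible to mypy and IDEs. A side-by-side sketch of the pattern being applied:

from typing import Optional, Tuple

from torch import Tensor


# Old style: the signature is typed via a comment.
def forward_old(x):
    # type: (Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
    return x, None, None


# New style: the same contract, expressed as inline annotations.
def forward_new(x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    return x, None, None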
@@ -199,8 +203,7 @@ def eager_outputs(self, x: Tensor, aux2: Tensor, aux1: Optional[Tensor]) -> Goog
         else:
             return x  # type: ignore[return-value]

-    def forward(self, x):
-        # type: (Tensor) -> GoogLeNetOutputs
+    def forward(self, x: Tensor) -> GoogLeNetOutputs:
         x = self._transform_input(x)
         x, aux1, aux2 = self._forward(x)
         aux_defined = self.training and self.aux_logits
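Annotating forward as returning GoogLeNetOutputs is what lets torch.jit.script resolve a single return type, while eager mode still returns a bare Tensor when the aux heads are inactive (hence the ignore in eager_outputs). A hedged scripting sketch, assuming the aux heads are enabled and the model is in train mode so the NamedTuple path is taken:

import torch
from torchvision.models import googlenet

model = googlenet(aux_logits=True, init_weights=False)
scripted = torch.jit.script(model)  # return type comes from the annotation
scripted.train()

out = scripted(torch.randn(2, 3, 224, 224))
# Under TorchScript in train mode, out is the GoogLeNetOutputs NamedTuple.
print(out.logits.shape, out.aux_logits1.shape, out.aux_logits2.shape)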
@@ -214,8 +217,17 @@ def forward(self, x):

 class Inception(nn.Module):

-    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj,
-                 conv_block=None):
+    def __init__(
+        self,
+        in_channels: int,
+        ch1x1: int,
+        ch3x3red: int,
+        ch3x3: int,
+        ch5x5red: int,
+        ch5x5: int,
+        pool_proj: int,
+        conv_block: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
         super(Inception, self).__init__()
         if conv_block is None:
             conv_block = BasicConv2d
@@ -238,7 +250,7 @@ def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_pr
             conv_block(in_channels, pool_proj, kernel_size=1)
         )

-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> List[Tensor]:
         branch1 = self.branch1(x)
         branch2 = self.branch2(x)
         branch3 = self.branch3(x)
@@ -247,14 +259,19 @@ def _forward(self, x):
         outputs = [branch1, branch2, branch3, branch4]
         return outputs

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         outputs = self._forward(x)
         return torch.cat(outputs, 1)


 class InceptionAux(nn.Module):

-    def __init__(self, in_channels, num_classes, conv_block=None):
+    def __init__(
+        self,
+        in_channels: int,
+        num_classes: int,
+        conv_block: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
         super(InceptionAux, self).__init__()
         if conv_block is None:
             conv_block = BasicConv2d
@@ -263,7 +280,7 @@ def __init__(self, in_channels, num_classes, conv_block=None):
         self.fc1 = nn.Linear(2048, 1024)
         self.fc2 = nn.Linear(1024, num_classes)

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
         x = F.adaptive_avg_pool2d(x, (4, 4))
         # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
@@ -283,12 +300,17 @@ def forward(self, x):

 class BasicConv2d(nn.Module):

-    def __init__(self, in_channels, out_channels, **kwargs):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        **kwargs: Any
+    ) -> None:
         super(BasicConv2d, self).__init__()
         self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
         self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.conv(x)
         x = self.bn(x)
         return F.relu(x, inplace=True)
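Throughout the diff, conv_block: Optional[Callable[..., nn.Module]] documents the injection point: any callable with BasicConv2d's constructor shape can stand in for it. A hypothetical sketch that swaps in a batchnorm-free block (PlainConv2d is made up for illustration):

from typing import Any

import torch
import torch.nn as nn
from torchvision.models.googlenet import Inception


class PlainConv2d(nn.Module):
    # Drop-in for BasicConv2d: conv + ReLU, without the BatchNorm.
    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.conv(x))


# The class itself serves as the Callable[..., nn.Module].
block = Inception(192, 64, 96, 128, 16, 32, 32, conv_block=PlainConv2d)
out = block(torch.randn(1, 192, 28, 28))
print(out.shape)  # torch.Size([1, 256, 28, 28]) -- 64 + 128 + 32 + 32 channels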
