Skip to content

Commit 085c9cc

Browse files
frgfmvfdev-5
authored andcommitted
Added annotation typing to inception (pytorch#2857)
* style: Added annotation typing for inception * refactor: Moved factory function after class definition * style: Changed attribute setting for type hinting * refactor: Removed un-necessary import * fix: Fixed typing in constructors * fix: Fixed kwargs typing * style: Fixed lint * refactor: Moved helpers function back and quote typed it
1 parent 2e17fe1 commit 085c9cc

File tree

1 file changed

+67
-29
lines changed

1 file changed

+67
-29
lines changed

torchvision/models/inception.py

Lines changed: 67 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import torch
44
import torch.nn as nn
55
import torch.nn.functional as F
6-
from torch.jit.annotations import Optional
76
from torch import Tensor
87
from .utils import load_state_dict_from_url
8+
from typing import Callable, Any, Optional, Tuple, List
99

1010

1111
__all__ = ['Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs']
@@ -24,7 +24,7 @@
2424
_InceptionOutputs = InceptionOutputs
2525

2626

27-
def inception_v3(pretrained=False, progress=True, **kwargs):
27+
def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3":
2828
r"""Inception v3 model architecture from
2929
`"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
3030
@@ -63,8 +63,14 @@ def inception_v3(pretrained=False, progress=True, **kwargs):
6363

6464
class Inception3(nn.Module):
6565

66-
def __init__(self, num_classes=1000, aux_logits=True, transform_input=False,
67-
inception_blocks=None, init_weights=None):
66+
def __init__(
67+
self,
68+
num_classes: int = 1000,
69+
aux_logits: bool = True,
70+
transform_input: bool = False,
71+
inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
72+
init_weights: Optional[bool] = None
73+
) -> None:
6874
super(Inception3, self).__init__()
6975
if inception_blocks is None:
7076
inception_blocks = [
@@ -124,15 +130,15 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False,
124130
nn.init.constant_(m.weight, 1)
125131
nn.init.constant_(m.bias, 0)
126132

127-
def _transform_input(self, x):
133+
def _transform_input(self, x: Tensor) -> Tensor:
128134
if self.transform_input:
129135
x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
130136
x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
131137
x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
132138
x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
133139
return x
134140

135-
def _forward(self, x):
141+
def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
136142
# N x 3 x 299 x 299
137143
x = self.Conv2d_1a_3x3(x)
138144
# N x 32 x 149 x 149
@@ -188,13 +194,13 @@ def _forward(self, x):
188194
return x, aux
189195

190196
@torch.jit.unused
191-
def eager_outputs(self, x: torch.Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
197+
def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
192198
if self.training and self.aux_logits:
193199
return InceptionOutputs(x, aux)
194200
else:
195201
return x # type: ignore[return-value]
196202

197-
def forward(self, x):
203+
def forward(self, x: Tensor) -> InceptionOutputs:
198204
x = self._transform_input(x)
199205
x, aux = self._forward(x)
200206
aux_defined = self.training and self.aux_logits
@@ -208,7 +214,12 @@ def forward(self, x):
208214

209215
class InceptionA(nn.Module):
210216

211-
def __init__(self, in_channels, pool_features, conv_block=None):
217+
def __init__(
218+
self,
219+
in_channels: int,
220+
pool_features: int,
221+
conv_block: Optional[Callable[..., nn.Module]] = None
222+
) -> None:
212223
super(InceptionA, self).__init__()
213224
if conv_block is None:
214225
conv_block = BasicConv2d
@@ -223,7 +234,7 @@ def __init__(self, in_channels, pool_features, conv_block=None):
223234

224235
self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
225236

226-
def _forward(self, x):
237+
def _forward(self, x: Tensor) -> List[Tensor]:
227238
branch1x1 = self.branch1x1(x)
228239

229240
branch5x5 = self.branch5x5_1(x)
@@ -239,14 +250,18 @@ def _forward(self, x):
239250
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
240251
return outputs
241252

242-
def forward(self, x):
253+
def forward(self, x: Tensor) -> Tensor:
243254
outputs = self._forward(x)
244255
return torch.cat(outputs, 1)
245256

246257

247258
class InceptionB(nn.Module):
248259

249-
def __init__(self, in_channels, conv_block=None):
260+
def __init__(
261+
self,
262+
in_channels: int,
263+
conv_block: Optional[Callable[..., nn.Module]] = None
264+
) -> None:
250265
super(InceptionB, self).__init__()
251266
if conv_block is None:
252267
conv_block = BasicConv2d
@@ -256,7 +271,7 @@ def __init__(self, in_channels, conv_block=None):
256271
self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
257272
self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)
258273

259-
def _forward(self, x):
274+
def _forward(self, x: Tensor) -> List[Tensor]:
260275
branch3x3 = self.branch3x3(x)
261276

262277
branch3x3dbl = self.branch3x3dbl_1(x)
@@ -268,14 +283,19 @@ def _forward(self, x):
268283
outputs = [branch3x3, branch3x3dbl, branch_pool]
269284
return outputs
270285

271-
def forward(self, x):
286+
def forward(self, x: Tensor) -> Tensor:
272287
outputs = self._forward(x)
273288
return torch.cat(outputs, 1)
274289

275290

276291
class InceptionC(nn.Module):
277292

278-
def __init__(self, in_channels, channels_7x7, conv_block=None):
293+
def __init__(
294+
self,
295+
in_channels: int,
296+
channels_7x7: int,
297+
conv_block: Optional[Callable[..., nn.Module]] = None
298+
) -> None:
279299
super(InceptionC, self).__init__()
280300
if conv_block is None:
281301
conv_block = BasicConv2d
@@ -294,7 +314,7 @@ def __init__(self, in_channels, channels_7x7, conv_block=None):
294314

295315
self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
296316

297-
def _forward(self, x):
317+
def _forward(self, x: Tensor) -> List[Tensor]:
298318
branch1x1 = self.branch1x1(x)
299319

300320
branch7x7 = self.branch7x7_1(x)
@@ -313,14 +333,18 @@ def _forward(self, x):
313333
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
314334
return outputs
315335

316-
def forward(self, x):
336+
def forward(self, x: Tensor) -> Tensor:
317337
outputs = self._forward(x)
318338
return torch.cat(outputs, 1)
319339

320340

321341
class InceptionD(nn.Module):
322342

323-
def __init__(self, in_channels, conv_block=None):
343+
def __init__(
344+
self,
345+
in_channels: int,
346+
conv_block: Optional[Callable[..., nn.Module]] = None
347+
) -> None:
324348
super(InceptionD, self).__init__()
325349
if conv_block is None:
326350
conv_block = BasicConv2d
@@ -332,7 +356,7 @@ def __init__(self, in_channels, conv_block=None):
332356
self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
333357
self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)
334358

335-
def _forward(self, x):
359+
def _forward(self, x: Tensor) -> List[Tensor]:
336360
branch3x3 = self.branch3x3_1(x)
337361
branch3x3 = self.branch3x3_2(branch3x3)
338362

@@ -345,14 +369,18 @@ def _forward(self, x):
345369
outputs = [branch3x3, branch7x7x3, branch_pool]
346370
return outputs
347371

348-
def forward(self, x):
372+
def forward(self, x: Tensor) -> Tensor:
349373
outputs = self._forward(x)
350374
return torch.cat(outputs, 1)
351375

352376

353377
class InceptionE(nn.Module):
354378

355-
def __init__(self, in_channels, conv_block=None):
379+
def __init__(
380+
self,
381+
in_channels: int,
382+
conv_block: Optional[Callable[..., nn.Module]] = None
383+
) -> None:
356384
super(InceptionE, self).__init__()
357385
if conv_block is None:
358386
conv_block = BasicConv2d
@@ -369,7 +397,7 @@ def __init__(self, in_channels, conv_block=None):
369397

370398
self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
371399

372-
def _forward(self, x):
400+
def _forward(self, x: Tensor) -> List[Tensor]:
373401
branch1x1 = self.branch1x1(x)
374402

375403
branch3x3 = self.branch3x3_1(x)
@@ -393,24 +421,29 @@ def _forward(self, x):
393421
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
394422
return outputs
395423

396-
def forward(self, x):
424+
def forward(self, x: Tensor) -> Tensor:
397425
outputs = self._forward(x)
398426
return torch.cat(outputs, 1)
399427

400428

401429
class InceptionAux(nn.Module):
402430

403-
def __init__(self, in_channels, num_classes, conv_block=None):
431+
def __init__(
432+
self,
433+
in_channels: int,
434+
num_classes: int,
435+
conv_block: Optional[Callable[..., nn.Module]] = None
436+
) -> None:
404437
super(InceptionAux, self).__init__()
405438
if conv_block is None:
406439
conv_block = BasicConv2d
407440
self.conv0 = conv_block(in_channels, 128, kernel_size=1)
408441
self.conv1 = conv_block(128, 768, kernel_size=5)
409-
self.conv1.stddev = 0.01
442+
self.conv1.stddev = 0.01 # type: ignore[assignment]
410443
self.fc = nn.Linear(768, num_classes)
411-
self.fc.stddev = 0.001
444+
self.fc.stddev = 0.001 # type: ignore[assignment]
412445

413-
def forward(self, x):
446+
def forward(self, x: Tensor) -> Tensor:
414447
# N x 768 x 17 x 17
415448
x = F.avg_pool2d(x, kernel_size=5, stride=3)
416449
# N x 768 x 5 x 5
@@ -430,12 +463,17 @@ def forward(self, x):
430463

431464
class BasicConv2d(nn.Module):
432465

433-
def __init__(self, in_channels, out_channels, **kwargs):
466+
def __init__(
467+
self,
468+
in_channels: int,
469+
out_channels: int,
470+
**kwargs: Any
471+
) -> None:
434472
super(BasicConv2d, self).__init__()
435473
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
436474
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
437475

438-
def forward(self, x):
476+
def forward(self, x: Tensor) -> Tensor:
439477
x = self.conv(x)
440478
x = self.bn(x)
441479
return F.relu(x, inplace=True)

0 commit comments

Comments
 (0)