Skip to content

Commit d19b182

Browse files
frgfmvfdev-5
authored andcommitted
style: Added annotation typing for alexnet (pytorch#2859)
1 parent 3bba998 commit d19b182

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

torchvision/models/alexnet.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
import torch.nn as nn
33
from .utils import load_state_dict_from_url
4+
from typing import Any
45

56

67
__all__ = ['AlexNet', 'alexnet']
@@ -13,7 +14,7 @@
1314

1415
class AlexNet(nn.Module):
1516

16-
def __init__(self, num_classes=1000):
17+
def __init__(self, num_classes: int = 1000):
1718
super(AlexNet, self).__init__()
1819
self.features = nn.Sequential(
1920
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
@@ -41,15 +42,15 @@ def __init__(self, num_classes=1000):
4142
nn.Linear(4096, num_classes),
4243
)
4344

44-
def forward(self, x):
45+
def forward(self, x: torch.Tensor) -> torch.Tensor:
4546
x = self.features(x)
4647
x = self.avgpool(x)
4748
x = torch.flatten(x, 1)
4849
x = self.classifier(x)
4950
return x
5051

5152

52-
def alexnet(pretrained=False, progress=True, **kwargs):
53+
def alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> AlexNet:
5354
r"""AlexNet model architecture from the
5455
`"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
5556

0 commit comments

Comments
 (0)