diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index a0126312d1a..e2ec53d9a82 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn from .utils import load_state_dict_from_url +from typing import Any __all__ = ['AlexNet', 'alexnet'] @@ -13,7 +14,7 @@ class AlexNet(nn.Module): - def __init__(self, num_classes=1000): + def __init__(self, num_classes: int = 1000): super(AlexNet, self).__init__() self.features = nn.Sequential( nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), @@ -41,7 +42,7 @@ def __init__(self, num_classes=1000): nn.Linear(4096, num_classes), ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.features(x) x = self.avgpool(x) x = torch.flatten(x, 1) @@ -49,7 +50,7 @@ def forward(self, x): return x -def alexnet(pretrained=False, progress=True, **kwargs): +def alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> AlexNet: r"""AlexNet model architecture from the `"One weird trick..." `_ paper.