Skip to content

Commit fd08315

Browse files
committed
Add inline documentation for models
Also add pre-trained ResNet-152 model. ResNet-152: Prec@1 78.312 Prec@5 94.046
1 parent a919deb commit fd08315

File tree

4 files changed

+74
-4
lines changed

4 files changed

+74
-4
lines changed

torchvision/models/__init__.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,33 @@
1+
"""The models subpackage contains definitions for the following model
2+
architectures:
3+
4+
- `AlexNet`_
5+
- `VGG`_
6+
- `ResNet`_
7+
8+
You can construct a model with random weights by calling its constructor:
9+
10+
.. code:: python
11+
12+
import torchvision.models as models
13+
resnet18 = models.resnet18()
14+
alexnet = models.alexnet()
15+
16+
We provide pre-trained models for the ResNet variants and AlexNet, using the
17+
PyTorch :mod:`torch.utils.model_zoo`. These can constructed by passing
18+
``pretrained=True``:
19+
20+
.. code:: python
21+
22+
import torchvision.models as models
23+
resnet18 = models.resnet18(pretrained=True)
24+
alexnet = models.alexnet(pretrained=True)
25+
26+
.. _AlexNet: https://arxiv.org/abs/1404.5997
27+
.. _VGG: https://arxiv.org/abs/1409.1556
28+
.. _ResNet: https://arxiv.org/abs/1512.03385
29+
"""
30+
131
from .alexnet import *
232
from .resnet import *
333
from .vgg import *

torchvision/models/alexnet.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,11 @@ def forward(self, x):
4646

4747

4848
def alexnet(pretrained=False):
49-
r"""AlexNet model architecture from the "One weird trick" paper.
50-
https://arxiv.org/abs/1404.5997
49+
r"""AlexNet model architecture from the
50+
`"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
51+
52+
Args:
53+
pretrained (bool): If True, returns a model pre-trained on ImageNet
5154
"""
5255
model = AlexNet()
5356
if pretrained:

torchvision/models/resnet.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth',
1313
'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth',
1414
'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth',
15+
'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth',
1516
}
1617

1718

@@ -152,32 +153,60 @@ def forward(self, x):
152153

153154

154155
def resnet18(pretrained=False):
156+
"""Constructs a ResNet-18 model.
157+
158+
Args:
159+
pretrained (bool): If True, returns a model pre-trained on ImageNet
160+
"""
155161
model = ResNet(BasicBlock, [2, 2, 2, 2])
156162
if pretrained:
157163
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
158164
return model
159165

160166

161167
def resnet34(pretrained=False):
168+
"""Constructs a ResNet-34 model.
169+
170+
Args:
171+
pretrained (bool): If True, returns a model pre-trained on ImageNet
172+
"""
162173
model = ResNet(BasicBlock, [3, 4, 6, 3])
163174
if pretrained:
164175
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
165176
return model
166177

167178

168179
def resnet50(pretrained=False):
180+
"""Constructs a ResNet-50 model.
181+
182+
Args:
183+
pretrained (bool): If True, returns a model pre-trained on ImageNet
184+
"""
169185
model = ResNet(Bottleneck, [3, 4, 6, 3])
170186
if pretrained:
171187
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
172188
return model
173189

174190

175191
def resnet101(pretrained=False):
192+
"""Constructs a ResNet-101 model.
193+
194+
Args:
195+
pretrained (bool): If True, returns a model pre-trained on ImageNet
196+
"""
176197
model = ResNet(Bottleneck, [3, 4, 23, 3])
177198
if pretrained:
178199
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
179200
return model
180201

181202

182-
def resnet152():
183-
return ResNet(Bottleneck, [3, 8, 36, 3])
203+
def resnet152(pretrained=False):
204+
"""Constructs a ResNet-152 model.
205+
206+
Args:
207+
pretrained (bool): If True, returns a model pre-trained on ImageNet
208+
"""
209+
model = ResNet(Bottleneck, [3, 8, 36, 3])
210+
if pretrained:
211+
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
212+
return model

torchvision/models/vgg.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,32 +53,40 @@ def make_layers(cfg, batch_norm=False):
5353

5454

5555
def vgg11():
56+
"""VGG 11-layer model (configuration "A")"""
5657
return VGG(make_layers(cfg['A']))
5758

5859

5960
def vgg11_bn():
61+
"""VGG 11-layer model (configuration "A") with batch normalization"""
6062
return VGG(make_layers(cfg['A'], batch_norm=True))
6163

6264

6365
def vgg13():
66+
"""VGG 13-layer model (configuration "B")"""
6467
return VGG(make_layers(cfg['B']))
6568

6669

6770
def vgg13_bn():
71+
"""VGG 13-layer model (configuration "B") with batch normalization"""
6872
return VGG(make_layers(cfg['B'], batch_norm=True))
6973

7074

7175
def vgg16():
76+
"""VGG 11-layer model (configuration "B")"""
7277
return VGG(make_layers(cfg['D']))
7378

7479

7580
def vgg16_bn():
81+
"""VGG 16-layer model (configuration "D") with batch normalization"""
7682
return VGG(make_layers(cfg['D'], batch_norm=True))
7783

7884

7985
def vgg19():
86+
"""VGG 19-layer model (configuration "D")"""
8087
return VGG(make_layers(cfg['E']))
8188

8289

8390
def vgg19_bn():
91+
"""VGG 19-layer model (configuration 'E') with batch normalization"""
8492
return VGG(make_layers(cfg['E'], batch_norm=True))

0 commit comments

Comments
 (0)