Skip to content

Commit 8d46ef7

Browse files
authored
Graduate ConvNeXt to main TorchVision area (#5330)
* Graduate ConvNeXt to main TorchVision area. * Linter and all var. * Renaming var and making named params mandatory.
1 parent d5a22a8 commit 8d46ef7

File tree

5 files changed

+299
-176
lines changed

5 files changed

+299
-176
lines changed

docs/source/models.rst

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ You can construct a model with random weights by calling its constructor:
8989
vit_b_32 = models.vit_b_32()
9090
vit_l_16 = models.vit_l_16()
9191
vit_l_32 = models.vit_l_32()
92+
convnext_tiny = models.convnext_tiny()
93+
convnext_small = models.convnext_small()
94+
convnext_base = models.convnext_base()
95+
convnext_large = models.convnext_large()
9296
9397
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
9498
These can be constructed by passing ``pretrained=True``:
@@ -136,6 +140,10 @@ These can be constructed by passing ``pretrained=True``:
136140
vit_b_32 = models.vit_b_32(pretrained=True)
137141
vit_l_16 = models.vit_l_16(pretrained=True)
138142
vit_l_32 = models.vit_l_32(pretrained=True)
143+
convnext_tiny = models.convnext_tiny(pretrained=True)
144+
convnext_small = models.convnext_small(pretrained=True)
145+
convnext_base = models.convnext_base(pretrained=True)
146+
convnext_large = models.convnext_large(pretrained=True)
139147
140148
Instancing a pre-trained model will download its weights to a cache directory.
141149
This directory can be set using the `TORCH_HOME` environment variable. See
@@ -248,10 +256,10 @@ vit_b_16 81.072 95.318
248256
vit_b_32 75.912 92.466
249257
vit_l_16 79.662 94.638
250258
vit_l_32 76.972 93.070
251-
convnext_tiny (prototype) 82.520 96.146
252-
convnext_small (prototype) 83.616 96.650
253-
convnext_base (prototype) 84.062 96.870
254-
convnext_large (prototype) 84.414 96.976
259+
convnext_tiny 82.520 96.146
260+
convnext_small 83.616 96.650
261+
convnext_base 84.062 96.870
262+
convnext_large 84.414 96.976
255263
================================ ============= =============
256264

257265

@@ -467,6 +475,18 @@ VisionTransformer
467475
vit_l_16
468476
vit_l_32
469477

478+
ConvNeXt
479+
--------
480+
481+
.. autosummary::
482+
:toctree: generated/
483+
:template: function.rst
484+
485+
convnext_tiny
486+
convnext_small
487+
convnext_base
488+
convnext_large
489+
470490
Quantized Models
471491
----------------
472492

hubconf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
dependencies = ["torch"]
33

44
from torchvision.models.alexnet import alexnet
5+
from torchvision.models.convnext import convnext_tiny, convnext_small, convnext_base, convnext_large
56
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
67
from torchvision.models.efficientnet import (
78
efficientnet_b0,

torchvision/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .alexnet import *
2+
from .convnext import *
23
from .resnet import *
34
from .vgg import *
45
from .squeezenet import *

torchvision/models/convnext.py

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
from functools import partial
2+
from typing import Any, Callable, Dict, List, Optional, Sequence
3+
4+
import torch
5+
from torch import nn, Tensor
6+
from torch.nn import functional as F
7+
8+
from .._internally_replaced_utils import load_state_dict_from_url
9+
from ..ops.misc import ConvNormActivation
10+
from ..ops.stochastic_depth import StochasticDepth
11+
from ..utils import _log_api_usage_once
12+
13+
14+
__all__ = [
15+
"ConvNeXt",
16+
"convnext_tiny",
17+
"convnext_small",
18+
"convnext_base",
19+
"convnext_large",
20+
]
21+
22+
23+
_MODELS_URLS: Dict[str, Optional[str]] = {
24+
"convnext_tiny": "https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
25+
"convnext_small": "https://download.pytorch.org/models/convnext_small-0c510722.pth",
26+
"convnext_base": "https://download.pytorch.org/models/convnext_base-6075fbad.pth",
27+
"convnext_large": "https://download.pytorch.org/models/convnext_large-ea097f82.pth",
28+
}
29+
30+
31+
class LayerNorm2d(nn.LayerNorm):
32+
def forward(self, x: Tensor) -> Tensor:
33+
x = x.permute(0, 2, 3, 1)
34+
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
35+
x = x.permute(0, 3, 1, 2)
36+
return x
37+
38+
39+
class Permute(nn.Module):
40+
def __init__(self, dims: List[int]):
41+
super().__init__()
42+
self.dims = dims
43+
44+
def forward(self, x):
45+
return torch.permute(x, self.dims)
46+
47+
48+
class CNBlock(nn.Module):
49+
def __init__(
50+
self,
51+
dim,
52+
layer_scale: float,
53+
stochastic_depth_prob: float,
54+
norm_layer: Optional[Callable[..., nn.Module]] = None,
55+
) -> None:
56+
super().__init__()
57+
if norm_layer is None:
58+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
59+
60+
self.block = nn.Sequential(
61+
nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
62+
Permute([0, 2, 3, 1]),
63+
norm_layer(dim),
64+
nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
65+
nn.GELU(),
66+
nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
67+
Permute([0, 3, 1, 2]),
68+
)
69+
self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
70+
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
71+
72+
def forward(self, input: Tensor) -> Tensor:
73+
result = self.layer_scale * self.block(input)
74+
result = self.stochastic_depth(result)
75+
result += input
76+
return result
77+
78+
79+
class CNBlockConfig:
80+
# Stores information listed at Section 3 of the ConvNeXt paper
81+
def __init__(
82+
self,
83+
input_channels: int,
84+
out_channels: Optional[int],
85+
num_layers: int,
86+
) -> None:
87+
self.input_channels = input_channels
88+
self.out_channels = out_channels
89+
self.num_layers = num_layers
90+
91+
def __repr__(self) -> str:
92+
s = self.__class__.__name__ + "("
93+
s += "input_channels={input_channels}"
94+
s += ", out_channels={out_channels}"
95+
s += ", num_layers={num_layers}"
96+
s += ")"
97+
return s.format(**self.__dict__)
98+
99+
100+
class ConvNeXt(nn.Module):
101+
def __init__(
102+
self,
103+
block_setting: List[CNBlockConfig],
104+
stochastic_depth_prob: float = 0.0,
105+
layer_scale: float = 1e-6,
106+
num_classes: int = 1000,
107+
block: Optional[Callable[..., nn.Module]] = None,
108+
norm_layer: Optional[Callable[..., nn.Module]] = None,
109+
**kwargs: Any,
110+
) -> None:
111+
super().__init__()
112+
_log_api_usage_once(self)
113+
114+
if not block_setting:
115+
raise ValueError("The block_setting should not be empty")
116+
elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
117+
raise TypeError("The block_setting should be List[CNBlockConfig]")
118+
119+
if block is None:
120+
block = CNBlock
121+
122+
if norm_layer is None:
123+
norm_layer = partial(LayerNorm2d, eps=1e-6)
124+
125+
layers: List[nn.Module] = []
126+
127+
# Stem
128+
firstconv_output_channels = block_setting[0].input_channels
129+
layers.append(
130+
ConvNormActivation(
131+
3,
132+
firstconv_output_channels,
133+
kernel_size=4,
134+
stride=4,
135+
padding=0,
136+
norm_layer=norm_layer,
137+
activation_layer=None,
138+
bias=True,
139+
)
140+
)
141+
142+
total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
143+
stage_block_id = 0
144+
for cnf in block_setting:
145+
# Bottlenecks
146+
stage: List[nn.Module] = []
147+
for _ in range(cnf.num_layers):
148+
# adjust stochastic depth probability based on the depth of the stage block
149+
sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
150+
stage.append(block(cnf.input_channels, layer_scale, sd_prob))
151+
stage_block_id += 1
152+
layers.append(nn.Sequential(*stage))
153+
if cnf.out_channels is not None:
154+
# Downsampling
155+
layers.append(
156+
nn.Sequential(
157+
norm_layer(cnf.input_channels),
158+
nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
159+
)
160+
)
161+
162+
self.features = nn.Sequential(*layers)
163+
self.avgpool = nn.AdaptiveAvgPool2d(1)
164+
165+
lastblock = block_setting[-1]
166+
lastconv_output_channels = (
167+
lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
168+
)
169+
self.classifier = nn.Sequential(
170+
norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
171+
)
172+
173+
for m in self.modules():
174+
if isinstance(m, (nn.Conv2d, nn.Linear)):
175+
nn.init.trunc_normal_(m.weight, std=0.02)
176+
if m.bias is not None:
177+
nn.init.zeros_(m.bias)
178+
179+
def _forward_impl(self, x: Tensor) -> Tensor:
180+
x = self.features(x)
181+
x = self.avgpool(x)
182+
x = self.classifier(x)
183+
return x
184+
185+
def forward(self, x: Tensor) -> Tensor:
186+
return self._forward_impl(x)
187+
188+
189+
def _convnext(
190+
arch: str,
191+
block_setting: List[CNBlockConfig],
192+
stochastic_depth_prob: float,
193+
pretrained: bool,
194+
progress: bool,
195+
**kwargs: Any,
196+
) -> ConvNeXt:
197+
model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
198+
if pretrained:
199+
if arch not in _MODELS_URLS:
200+
raise ValueError(f"No checkpoint is available for model type {arch}")
201+
state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress)
202+
model.load_state_dict(state_dict)
203+
return model
204+
205+
206+
def convnext_tiny(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt:
207+
r"""ConvNeXt Tiny model architecture from the
208+
`"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.
209+
Args:
210+
pretrained (bool): If True, returns a model pre-trained on ImageNet
211+
progress (bool): If True, displays a progress bar of the download to stderr
212+
"""
213+
block_setting = [
214+
CNBlockConfig(96, 192, 3),
215+
CNBlockConfig(192, 384, 3),
216+
CNBlockConfig(384, 768, 9),
217+
CNBlockConfig(768, None, 3),
218+
]
219+
stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
220+
return _convnext("convnext_tiny", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs)
221+
222+
223+
def convnext_small(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt:
224+
r"""ConvNeXt Small model architecture from the
225+
`"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.
226+
Args:
227+
pretrained (bool): If True, returns a model pre-trained on ImageNet
228+
progress (bool): If True, displays a progress bar of the download to stderr
229+
"""
230+
block_setting = [
231+
CNBlockConfig(96, 192, 3),
232+
CNBlockConfig(192, 384, 3),
233+
CNBlockConfig(384, 768, 27),
234+
CNBlockConfig(768, None, 3),
235+
]
236+
stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
237+
return _convnext("convnext_small", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs)
238+
239+
240+
def convnext_base(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt:
241+
r"""ConvNeXt Base model architecture from the
242+
`"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.
243+
Args:
244+
pretrained (bool): If True, returns a model pre-trained on ImageNet
245+
progress (bool): If True, displays a progress bar of the download to stderr
246+
"""
247+
block_setting = [
248+
CNBlockConfig(128, 256, 3),
249+
CNBlockConfig(256, 512, 3),
250+
CNBlockConfig(512, 1024, 27),
251+
CNBlockConfig(1024, None, 3),
252+
]
253+
stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
254+
return _convnext("convnext_base", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs)
255+
256+
257+
def convnext_large(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt:
258+
r"""ConvNeXt Large model architecture from the
259+
`"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.
260+
Args:
261+
pretrained (bool): If True, returns a model pre-trained on ImageNet
262+
progress (bool): If True, displays a progress bar of the download to stderr
263+
"""
264+
block_setting = [
265+
CNBlockConfig(192, 384, 3),
266+
CNBlockConfig(384, 768, 3),
267+
CNBlockConfig(768, 1536, 27),
268+
CNBlockConfig(1536, None, 3),
269+
]
270+
stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
271+
return _convnext("convnext_large", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs)

0 commit comments

Comments
 (0)