Commit 0d599ba

yiwen-song authored and facebook-github-bot committed
[fbsync] Adding more ConvNeXt variants + Speed optimizations (#5253)
Summary:
* Refactor the model builder.
* Add 3 more ConvNeXt variants.
* Add weights for convnext_small.
* Fix a minor bug.
* Fix the number of parameters for the small model.
* Add weights for the base variant.
* Add weights for the large variant.
* Simplify the LayerNorm2d implementation.
* Optimize the speed of CNBlock.
* Repackage the weights.

Reviewed By: kazhang
Differential Revision: D33927490
fbshipit-source-id: 569d9f752b1c5d5ba6f9a8f9721b4f91fac6663d
1 parent 1af32fe commit 0d599ba
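The new builders follow the same prototype weights API as the existing `convnext_tiny`. A minimal usage sketch, assuming the `torchvision.prototype.models` namespace as of this commit (the weight download needs network access):

```
import torch
from torchvision.prototype import models

# Build one of the new variants with its pretrained ImageNet-1k weights.
model = models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1)
model.eval()

# The weights entry carries its matching evaluation preset (crop 224, resize 230).
preprocess = models.ConvNeXt_Small_Weights.IMAGENET1K_V1.transforms()
batch = preprocess(torch.rand(1, 3, 256, 256))
logits = model(batch)
print(logits.shape)  # torch.Size([1, 1000])
```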

File tree

6 files changed: +164 -48 lines changed

docs/source/models.rst

Lines changed: 3 additions & 0 deletions

@@ -249,6 +249,9 @@ vit_b_32 75.912 92.466
 vit_l_16 79.662 94.638
 vit_l_32 76.972 93.070
 convnext_tiny (prototype) 82.520 96.146
+convnext_small (prototype) 83.616 96.650
+convnext_base (prototype) 84.062 96.870
+convnext_large (prototype) 84.414 96.976
 ================================ ============= =============
references/classification/README.md

Lines changed: 3 additions & 2 deletions

@@ -201,11 +201,12 @@ and `--batch_size 64`.
 ### ConvNeXt
 ```
 torchrun --nproc_per_node=8 train.py\
---model convnext_tiny --batch-size 128 --opt adamw --lr 1e-3 --lr-scheduler cosineannealinglr \
+--model $MODEL --batch-size 128 --opt adamw --lr 1e-3 --lr-scheduler cosineannealinglr \
 --lr-warmup-epochs 5 --lr-warmup-method linear --auto-augment ta_wide --epochs 600 --random-erase 0.1 \
 --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --weight-decay 0.05 --norm-weight-decay 0.0 \
---train-crop-size 176 --model-ema --val-resize-size 236 --ra-sampler --ra-reps 4
+--train-crop-size 176 --model-ema --val-resize-size 232 --ra-sampler --ra-reps 4
 ```
+Here `$MODEL` is one of `convnext_tiny`, `convnext_small`, `convnext_base` and `convnext_large`. Note that each variant had its `--val-resize-size` optimized in a post-training step; see its `Weights` entry for the exact value.
 
 Note that the above command corresponds to training on a single node with 8 GPUs.
 For generating the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs),
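The tuned `--val-resize-size` for each variant is recorded in the weights definitions added further down in this commit. A hedged sketch of how to read it back, assuming `Weights.transforms` stores the `functools.partial` shown in the diff:

```
from torchvision.prototype import models

# Per the weights definitions in this commit: tiny=236, small=230, base=232, large=232.
for weights in [
    models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1,
    models.ConvNeXt_Small_Weights.IMAGENET1K_V1,
    models.ConvNeXt_Base_Weights.IMAGENET1K_V1,
    models.ConvNeXt_Large_Weights.IMAGENET1K_V1,
]:
    preset = weights.transforms  # functools.partial(ImageNetEval, crop_size=224, resize_size=...)
    print(preset.keywords["resize_size"])
```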
3 binary files changed (not shown).

torchvision/prototype/models/convnext.py

Lines changed: 158 additions & 46 deletions
@@ -15,47 +15,56 @@
 from ._utils import handle_legacy_interface, _ovewrite_named_param
 
 
-__all__ = ["ConvNeXt", "ConvNeXt_Tiny_Weights", "convnext_tiny"]
+__all__ = [
+    "ConvNeXt",
+    "ConvNeXt_Tiny_Weights",
+    "ConvNeXt_Small_Weights",
+    "ConvNeXt_Base_Weights",
+    "ConvNeXt_Large_Weights",
+    "convnext_tiny",
+    "convnext_small",
+    "convnext_base",
+    "convnext_large",
+]
 
 
 class LayerNorm2d(nn.LayerNorm):
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        self.channels_last = kwargs.pop("channels_last", False)
-        super().__init__(*args, **kwargs)
-
     def forward(self, x: Tensor) -> Tensor:
-        # TODO: Benchmark this against the approach described at https://github.com/pytorch/vision/pull/5197#discussion_r786251298
-        if not self.channels_last:
-            x = x.permute(0, 2, 3, 1)
+        x = x.permute(0, 2, 3, 1)
         x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
-        if not self.channels_last:
-            x = x.permute(0, 3, 1, 2)
+        x = x.permute(0, 3, 1, 2)
         return x
 
 
+class Permute(nn.Module):
+    def __init__(self, dims: List[int]):
+        super().__init__()
+        self.dims = dims
+
+    def forward(self, x):
+        return torch.permute(x, self.dims)
+
+
 class CNBlock(nn.Module):
     def __init__(
-        self, dim, layer_scale: float, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module]
+        self,
+        dim,
+        layer_scale: float,
+        stochastic_depth_prob: float,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
     ) -> None:
         super().__init__()
+        if norm_layer is None:
+            norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
         self.block = nn.Sequential(
-            ConvNormActivation(
-                dim,
-                dim,
-                kernel_size=7,
-                groups=dim,
-                norm_layer=norm_layer,
-                activation_layer=None,
-                bias=True,
-            ),
-            ConvNormActivation(dim, 4 * dim, kernel_size=1, norm_layer=None, activation_layer=nn.GELU, inplace=None),
-            ConvNormActivation(
-                4 * dim,
-                dim,
-                kernel_size=1,
-                norm_layer=None,
-                activation_layer=None,
-            ),
+            nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
+            Permute([0, 2, 3, 1]),
+            norm_layer(dim),
+            nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
+            nn.GELU(),
+            nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
+            Permute([0, 3, 1, 2]),
         )
         self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
         self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
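This hunk contains both speed changes: `LayerNorm2d` now always permutes to channels-last (the `channels_last` flag is gone), and `CNBlock` replaces its 1x1 `ConvNormActivation` layers with `Permute` + `nn.Linear`, so the pointwise MLP runs on the channel axis in channels-last layout. A small self-contained sketch of why the Linear swap is behavior-preserving (hypothetical shapes; a 1x1 convolution is a per-pixel matrix multiply over channels):

```
import torch
from torch import nn

dim = 8
conv = nn.Conv2d(dim, 4 * dim, kernel_size=1, bias=True)
linear = nn.Linear(dim, 4 * dim, bias=True)

# Copy the 1x1 conv weights into the Linear layer; after an NCHW -> NHWC
# permute, Linear applies exactly the same affine map on the channel axis.
with torch.no_grad():
    linear.weight.copy_(conv.weight.flatten(1))
    linear.bias.copy_(conv.bias)

x = torch.randn(2, dim, 14, 14)
out_conv = conv(x)
out_linear = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(out_conv, out_linear, atol=1e-6))  # True
```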
@@ -138,7 +147,7 @@ def __init__(
             for _ in range(cnf.num_layers):
                 # adjust stochastic depth probability based on the depth of the stage block
                 sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
-                stage.append(block(cnf.input_channels, layer_scale, sd_prob, norm_layer))
+                stage.append(block(cnf.input_channels, layer_scale, sd_prob))
                 stage_block_id += 1
             layers.append(nn.Sequential(*stage))
             if cnf.out_channels is not None:
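The `sd_prob` formula in this hunk (unchanged by the commit) scales the stochastic-depth drop probability linearly with block depth, from 0 on the first block to the full `stochastic_depth_prob` on the last. A tiny illustration using `convnext_tiny`'s 3 + 3 + 9 + 3 = 18 blocks:

```
stochastic_depth_prob = 0.1  # convnext_tiny default in this commit
total_stage_blocks = 3 + 3 + 9 + 3
probs = [
    stochastic_depth_prob * block_id / (total_stage_blocks - 1.0)
    for block_id in range(total_stage_blocks)
]
print(round(probs[0], 3), round(probs[-1], 3))  # 0.0 0.1
```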
@@ -177,30 +186,95 @@ def forward(self, x: Tensor) -> Tensor:
         return self._forward_impl(x)
 
 
+def _convnext(
+    block_setting: List[CNBlockConfig],
+    stochastic_depth_prob: float,
+    weights: Optional[WeightsEnum],
+    progress: bool,
+    **kwargs: Any,
+) -> ConvNeXt:
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+
+    return model
+
+
+_COMMON_META = {
+    "task": "image_classification",
+    "architecture": "ConvNeXt",
+    "publication_year": 2022,
+    "size": (224, 224),
+    "min_size": (32, 32),
+    "categories": _IMAGENET_CATEGORIES,
+    "interpolation": InterpolationMode.BILINEAR,
+    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
+}
+
+
 class ConvNeXt_Tiny_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
-        url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth",
+        url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
         transforms=partial(ImageNetEval, crop_size=224, resize_size=236),
         meta={
-            "task": "image_classification",
-            "architecture": "ConvNeXt",
-            "publication_year": 2022,
+            **_COMMON_META,
             "num_params": 28589128,
-            "size": (224, 224),
-            "min_size": (32, 32),
-            "categories": _IMAGENET_CATEGORIES,
-            "interpolation": InterpolationMode.BILINEAR,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
             "acc@1": 82.520,
             "acc@5": 96.146,
         },
     )
     DEFAULT = IMAGENET1K_V1
 
 
+class ConvNeXt_Small_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
+        transforms=partial(ImageNetEval, crop_size=224, resize_size=230),
+        meta={
+            **_COMMON_META,
+            "num_params": 50223688,
+            "acc@1": 83.616,
+            "acc@5": 96.650,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class ConvNeXt_Base_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
+        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "num_params": 88591464,
+            "acc@1": 84.062,
+            "acc@5": 96.870,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class ConvNeXt_Large_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
+        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "num_params": 197767336,
+            "acc@1": 84.414,
+            "acc@5": 96.976,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
 @handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
 def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
-    r"""ConvNeXt model architecture from the
+    r"""ConvNeXt Tiny model architecture from the
     `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.
 
     Args:
@@ -209,19 +283,57 @@ def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress:
     """
     weights = ConvNeXt_Tiny_Weights.verify(weights)
 
-    if weights is not None:
-        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
-
     block_setting = [
         CNBlockConfig(96, 192, 3),
         CNBlockConfig(192, 384, 3),
         CNBlockConfig(384, 768, 9),
         CNBlockConfig(768, None, 3),
     ]
     stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
-    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
+    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
 
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
 
-    return model
+@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
+def convnext_small(
+    *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
+) -> ConvNeXt:
+    weights = ConvNeXt_Small_Weights.verify(weights)
+
+    block_setting = [
+        CNBlockConfig(96, 192, 3),
+        CNBlockConfig(192, 384, 3),
+        CNBlockConfig(384, 768, 27),
+        CNBlockConfig(768, None, 3),
+    ]
+    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
+    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
+def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
+    weights = ConvNeXt_Base_Weights.verify(weights)
+
+    block_setting = [
+        CNBlockConfig(128, 256, 3),
+        CNBlockConfig(256, 512, 3),
+        CNBlockConfig(512, 1024, 27),
+        CNBlockConfig(1024, None, 3),
+    ]
+    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
+    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
+def convnext_large(
+    *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
+) -> ConvNeXt:
+    weights = ConvNeXt_Large_Weights.verify(weights)
+
+    block_setting = [
+        CNBlockConfig(192, 384, 3),
+        CNBlockConfig(384, 768, 3),
+        CNBlockConfig(768, 1536, 27),
+        CNBlockConfig(1536, None, 3),
+    ]
+    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
+    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
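The `num_params` values recorded in the new metadata can be sanity-checked against freshly built models without downloading any weights; a sketch under the same prototype-namespace assumption:

```
from torchvision.prototype import models

checks = [
    (models.convnext_small, models.ConvNeXt_Small_Weights.IMAGENET1K_V1),
    (models.convnext_base, models.ConvNeXt_Base_Weights.IMAGENET1K_V1),
    (models.convnext_large, models.ConvNeXt_Large_Weights.IMAGENET1K_V1),
]
for builder, weights in checks:
    model = builder(weights=None)  # random init, no download
    num_params = sum(p.numel() for p in model.parameters())
    assert num_params == weights.meta["num_params"], builder.__name__
print("all parameter counts match")
```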
