Skip to content

Commit 77cad12

Browse files
authored
Adding multi-layer perceptron in ops (#6053)
* Adding an MLP block. * Adding documentation * Update typos. * Fix inplace for Dropout. * Apply recommendations from code review. * Making changes on pre-trained models. * Fix linter
1 parent e65372e commit 77cad12

File tree

5 files changed

+101
-22
lines changed

5 files changed

+101
-22
lines changed

docs/source/ops.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ TorchVision provides commonly used building blocks as layers:
8787
DeformConv2d
8888
DropBlock2d
8989
DropBlock3d
90+
MLP
9091
FrozenBatchNorm2d
9192
SqueezeExcitation
9293
StochasticDepth

torchvision/models/swin_transformer.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
import torch.nn.functional as F
66
from torch import nn, Tensor
77

8+
from ..ops.misc import MLP
89
from ..ops.stochastic_depth import StochasticDepth
910
from ..transforms._presets import ImageClassification, InterpolationMode
1011
from ..utils import _log_api_usage_once
1112
from ._api import WeightsEnum, Weights
1213
from ._meta import _IMAGENET_CATEGORIES
1314
from ._utils import _ovewrite_named_param
14-
from .convnext import Permute
15-
from .vision_transformer import MLPBlock
15+
from .convnext import Permute # TODO: move Permute to ops
1616

1717

1818
__all__ = [
@@ -263,7 +263,13 @@ def __init__(
263263
)
264264
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
265265
self.norm2 = norm_layer(dim)
266-
self.mlp = MLPBlock(dim, int(dim * mlp_ratio), dropout)
266+
self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)
267+
268+
for m in self.mlp.modules():
269+
if isinstance(m, nn.Linear):
270+
nn.init.xavier_uniform_(m.weight)
271+
if m.bias is not None:
272+
nn.init.normal_(m.bias, std=1e-6)
267273

268274
def forward(self, x: Tensor):
269275
x = x + self.stochastic_depth(self.attn(self.norm1(x)))
@@ -412,7 +418,7 @@ def _swin_transformer(
412418

413419
class Swin_T_Weights(WeightsEnum):
414420
IMAGENET1K_V1 = Weights(
415-
url="https://download.pytorch.org/models/swin_t-4c37bd06.pth",
421+
url="https://download.pytorch.org/models/swin_t-704ceda3.pth",
416422
transforms=partial(
417423
ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC
418424
),
@@ -435,7 +441,7 @@ class Swin_T_Weights(WeightsEnum):
435441

436442
class Swin_S_Weights(WeightsEnum):
437443
IMAGENET1K_V1 = Weights(
438-
url="https://download.pytorch.org/models/swin_s-30134662.pth",
444+
url="https://download.pytorch.org/models/swin_s-5e29d889.pth",
439445
transforms=partial(
440446
ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC
441447
),
@@ -458,7 +464,7 @@ class Swin_S_Weights(WeightsEnum):
458464

459465
class Swin_B_Weights(WeightsEnum):
460466
IMAGENET1K_V1 = Weights(
461-
url="https://download.pytorch.org/models/swin_b-1f1feb5c.pth",
467+
url="https://download.pytorch.org/models/swin_b-68c6b09e.pth",
462468
transforms=partial(
463469
ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
464470
),

torchvision/models/vision_transformer.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77
import torch.nn as nn
88

9-
from ..ops.misc import Conv2dNormActivation
9+
from ..ops.misc import Conv2dNormActivation, MLP
1010
from ..transforms._presets import ImageClassification, InterpolationMode
1111
from ..utils import _log_api_usage_once
1212
from ._api import WeightsEnum, Weights
@@ -37,21 +37,48 @@ class ConvStemConfig(NamedTuple):
3737
activation_layer: Callable[..., nn.Module] = nn.ReLU
3838

3939

40-
class MLPBlock(nn.Sequential):
40+
class MLPBlock(MLP):
4141
"""Transformer MLP block."""
4242

4343
def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
44-
super().__init__()
45-
self.linear_1 = nn.Linear(in_dim, mlp_dim)
46-
self.act = nn.GELU()
47-
self.dropout_1 = nn.Dropout(dropout)
48-
self.linear_2 = nn.Linear(mlp_dim, in_dim)
49-
self.dropout_2 = nn.Dropout(dropout)
50-
51-
nn.init.xavier_uniform_(self.linear_1.weight)
52-
nn.init.xavier_uniform_(self.linear_2.weight)
53-
nn.init.normal_(self.linear_1.bias, std=1e-6)
54-
nn.init.normal_(self.linear_2.bias, std=1e-6)
44+
super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)
45+
46+
for m in self.modules():
47+
if isinstance(m, nn.Linear):
48+
nn.init.xavier_uniform_(m.weight)
49+
if m.bias is not None:
50+
nn.init.normal_(m.bias, std=1e-6)
51+
52+
def _load_from_state_dict(
    self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
    """Load parameters, translating legacy ``linear_1``/``linear_2`` keys first.

    When the stored module version is missing or lower than 2, the entries
    ``{prefix}linear_1.*`` and ``{prefix}linear_2.*`` are renamed in place to
    ``{prefix}0.*`` and ``{prefix}3.*`` respectively; keys that are absent are
    skipped. The (possibly rewritten) ``state_dict`` is then handed to the
    parent implementation together with all other arguments unchanged.
    """
    stored_version = local_metadata.get("version", None)
    uses_legacy_keys = stored_version is None or stored_version < 2

    if uses_legacy_keys:
        # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
        for layer_idx in (0, 1):
            for suffix in ("weight", "bias"):
                legacy_key = f"{prefix}linear_{layer_idx+1}.{suffix}"
                sequential_key = f"{prefix}{3*layer_idx}.{suffix}"
                if legacy_key not in state_dict:
                    continue
                # Note: this mutates the caller's state_dict.
                state_dict[sequential_key] = state_dict.pop(legacy_key)

    super()._load_from_state_dict(
        state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    )
5582

5683

5784
class EncoderBlock(nn.Module):

torchvision/ops/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .feature_pyramid_network import FeaturePyramidNetwork
2020
from .focal_loss import sigmoid_focal_loss
2121
from .giou_loss import generalized_box_iou_loss
22-
from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation
22+
from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation, MLP
2323
from .poolers import MultiScaleRoIAlign
2424
from .ps_roi_align import ps_roi_align, PSRoIAlign
2525
from .ps_roi_pool import ps_roi_pool, PSRoIPool
@@ -61,6 +61,7 @@
6161
"Conv2dNormActivation",
6262
"Conv3dNormActivation",
6363
"SqueezeExcitation",
64+
"MLP",
6465
"generalized_box_iou_loss",
6566
"distance_box_iou_loss",
6667
"complete_box_iou_loss",

torchvision/ops/misc.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ class Conv2dNormActivation(ConvNormActivation):
129129
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
130130
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
131131
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
132-
activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
132+
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
133133
dilation (int): Spacing between kernel elements. Default: 1
134134
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
135135
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
@@ -179,7 +179,7 @@ class Conv3dNormActivation(ConvNormActivation):
179179
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
180180
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
181181
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm3d``
182-
activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
182+
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
183183
dilation (int): Spacing between kernel elements. Default: 1
184184
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
185185
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
@@ -253,3 +253,47 @@ def _scale(self, input: Tensor) -> Tensor:
253253
def forward(self, input: Tensor) -> Tensor:
254254
scale = self._scale(input)
255255
return scale * input
256+
257+
258+
class MLP(torch.nn.Sequential):
    """This block implements the multi-layer perceptron (MLP) module.

    For every entry of ``hidden_channels`` except the last one the block stacks
    ``Linear -> norm_layer (optional) -> activation_layer (optional) -> Dropout``.
    The last entry only produces ``Linear -> Dropout``.

    Args:
        in_channels (int): Number of channels of the input
        hidden_channels (List[int]): List of the hidden channel dimensions. The last entry is the output dimension of the block; the list must contain at least one entry.
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        inplace (bool, optional): Value passed as ``inplace`` to both the activation layer and the Dropout layer. If ``None`` the argument is not passed at all, so each layer uses its own default. Default ``True``
        bias (bool): Whether to use bias in the linear layer. Default ``True``
        dropout (float): The probability for the dropout layer. Default: 0.0
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        inplace: Optional[bool] = True,
        bias: bool = True,
        dropout: float = 0.0,
    ):
        # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
        # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py

        # `inplace=None` means "do not pass the keyword", which is required for
        # activations such as GELU whose constructor has no `inplace` argument.
        params = {} if inplace is None else {"inplace": inplace}

        layers = []
        in_dim = in_channels
        for hidden_dim in hidden_channels[:-1]:
            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
            if norm_layer is not None:
                layers.append(norm_layer(hidden_dim))
            # Honour the documented contract: ``None`` disables the activation
            # instead of failing with "'NoneType' object is not callable".
            if activation_layer is not None:
                layers.append(activation_layer(**params))
            layers.append(torch.nn.Dropout(dropout, **params))
            in_dim = hidden_dim

        # Output projection: no normalization or activation after the last linear layer.
        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
        layers.append(torch.nn.Dropout(dropout, **params))

        super().__init__(*layers)
        _log_api_usage_once(self)

0 commit comments

Comments
 (0)