diff --git a/docs/source/ops.rst b/docs/source/ops.rst index 472c45fbab4..7124c85bb79 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -87,8 +87,9 @@ TorchVision provides commonly used building blocks as layers: DeformConv2d DropBlock2d DropBlock3d - MLP FrozenBatchNorm2d + MLP + Permute SqueezeExcitation StochasticDepth diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 4cd75690df4..435789ca0e2 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -5,7 +5,7 @@ from torch import nn, Tensor from torch.nn import functional as F -from ..ops.misc import Conv2dNormActivation +from ..ops.misc import Conv2dNormActivation, Permute from ..ops.stochastic_depth import StochasticDepth from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once @@ -35,15 +35,6 @@ def forward(self, x: Tensor) -> Tensor: return x -class Permute(nn.Module): - def __init__(self, dims: List[int]): - super().__init__() - self.dims = dims - - def forward(self, x): - return torch.permute(x, self.dims) - - class CNBlock(nn.Module): def __init__( self, diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 148bfa1c4a2..25e8900db56 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -5,14 +5,13 @@ import torch.nn.functional as F from torch import nn, Tensor -from ..ops.misc import MLP +from ..ops.misc import MLP, Permute from ..ops.stochastic_depth import StochasticDepth from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param -from .convnext import Permute # TODO: move Permute on ops __all__ = [ diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 333e9246401..5d56f0bca42 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -19,7 +19,7 @@ from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss from .giou_loss import generalized_box_iou_loss -from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation, MLP +from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation, MLP, Permute from .poolers import MultiScaleRoIAlign from .ps_roi_align import ps_roi_align, PSRoIAlign from .ps_roi_pool import ps_roi_pool, PSRoIPool @@ -62,6 +62,7 @@ "Conv3dNormActivation", "SqueezeExcitation", "MLP", + "Permute", "generalized_box_iou_loss", "distance_box_iou_loss", "complete_box_iou_loss", diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 2e4816c9f22..b1463cf315b 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -10,7 +10,6 @@ interpolate = torch.nn.functional.interpolate -# This is not in nn class FrozenBatchNorm2d(torch.nn.Module): """ BatchNorm2d where the batch statistics and the affine parameters are fixed @@ -297,3 +296,18 @@ def __init__( super().__init__(*layers) _log_api_usage_once(self) + + +class Permute(torch.nn.Module): + """This module returns a view of the tensor input with its dimensions permuted. + + Args: + dims (List[int]): The desired ordering of dimensions + """ + + def __init__(self, dims: List[int]): + super().__init__() + self.dims = dims + + def forward(self, x: Tensor) -> Tensor: + return torch.permute(x, self.dims)