diff --git a/docs/source/ops.rst b/docs/source/ops.rst index a50e202e00b..3e996ecd6c8 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -20,6 +20,8 @@ Operators box_iou clip_boxes_to_image deform_conv2d + drop_block2d + drop_block3d generalized_box_iou generalized_box_iou_loss masks_to_boxes @@ -48,3 +50,5 @@ Operators Conv2dNormActivation Conv3dNormActivation SqueezeExcitation + DropBlock2d + DropBlock3d diff --git a/test/test_ops.py b/test/test_ops.py index 6b35b4f0091..ad9aaefee52 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2,6 +2,7 @@ import os from abc import ABC, abstractmethod from functools import lru_cache +from itertools import product from typing import Callable, List, Tuple import numpy as np @@ -57,6 +58,16 @@ def forward(self, a): self.layer(a) +class DropBlockWrapper(nn.Module): + def __init__(self, obj): + super().__init__() + self.layer = obj + self.n_inputs = 1 + + def forward(self, a): + self.layer(a) + + class RoIOpTester(ABC): dtype = torch.float64 @@ -1357,5 +1368,87 @@ def test_split_normalization_params(self, norm_layer): assert len(params[1]) == 82 +class TestDropBlock: + @pytest.mark.parametrize("seed", range(10)) + @pytest.mark.parametrize("dim", [2, 3]) + @pytest.mark.parametrize("p", [0, 0.5]) + @pytest.mark.parametrize("block_size", [5, 11]) + @pytest.mark.parametrize("inplace", [True, False]) + def test_drop_block(self, seed, dim, p, block_size, inplace): + torch.manual_seed(seed) + batch_size = 5 + channels = 3 + height = 11 + width = height + depth = height + if dim == 2: + x = torch.ones(size=(batch_size, channels, height, width)) + layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace) + feature_size = height * width + elif dim == 3: + x = torch.ones(size=(batch_size, channels, depth, height, width)) + layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace) + feature_size = depth * height * width + layer.__repr__() + + out = layer(x) + if p == 0: + assert out.equal(x) + if block_size == height: + for b, c in product(range(batch_size), range(channels)): + assert out[b, c].count_nonzero() in (0, feature_size) + + @pytest.mark.parametrize("seed", range(10)) + @pytest.mark.parametrize("dim", [2, 3]) + @pytest.mark.parametrize("p", [0.1, 0.2]) + @pytest.mark.parametrize("block_size", [3]) + @pytest.mark.parametrize("inplace", [False]) + def test_drop_block_random(self, seed, dim, p, block_size, inplace): + torch.manual_seed(seed) + batch_size = 5 + channels = 3 + height = 11 + width = height + depth = height + if dim == 2: + x = torch.ones(size=(batch_size, channels, height, width)) + layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace) + elif dim == 3: + x = torch.ones(size=(batch_size, channels, depth, height, width)) + layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace) + + trials = 250 + num_samples = 0 + counts = 0 + cell_numel = torch.tensor(x.shape).prod() + for _ in range(trials): + with torch.no_grad(): + out = layer(x) + non_zero_count = out.nonzero().size(0) + counts += cell_numel - non_zero_count + num_samples += cell_numel + + assert abs(p - counts / num_samples) / p < 0.15 + + def make_obj(self, dim, p, block_size, inplace, wrap=False): + if dim == 2: + obj = ops.DropBlock2d(p, block_size, inplace) + elif dim == 3: + obj = ops.DropBlock3d(p, block_size, inplace) + return DropBlockWrapper(obj) if wrap else obj + + @pytest.mark.parametrize("dim", (2, 3)) + @pytest.mark.parametrize("p", [0, 1]) + @pytest.mark.parametrize("block_size", [5, 7]) + @pytest.mark.parametrize("inplace", [True, False]) + def test_is_leaf_node(self, dim, p, block_size, inplace): + op_obj = self.make_obj(dim, p, block_size, inplace, wrap=True) + graph_node_names = get_graph_node_names(op_obj) + + assert len(graph_node_names) == 2 + assert len(graph_node_names[0]) == len(graph_node_names[1]) + assert len(graph_node_names[0]) == 1 + op_obj.n_inputs + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 9da336764d3..ceb78250415 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -11,6 +11,7 @@ ) from .boxes import box_convert from .deform_conv import deform_conv2d, DeformConv2d +from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss from .giou_loss import generalized_box_iou_loss @@ -55,4 +56,8 @@ "Conv3dNormActivation", "SqueezeExcitation", "generalized_box_iou_loss", + "drop_block2d", + "DropBlock2d", + "drop_block3d", + "DropBlock3d", ] diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py new file mode 100644 index 00000000000..a798677f60f --- /dev/null +++ b/torchvision/ops/drop_block.py @@ -0,0 +1,155 @@ +import torch +import torch.fx +import torch.nn.functional as F +from torch import nn, Tensor + +from ..utils import _log_api_usage_once + + +def drop_block2d( + input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True +) -> Tensor: + """ + Implements DropBlock2d from `"DropBlock: A regularization method for convolutional networks" + `. + + Args: + input (Tensor[N, C, H, W]): The input tensor or 4-dimensions with the first one + being its batch i.e. a batch with ``N`` rows. + p (float): Probability of an element to be dropped. + block_size (int): Size of the block to drop. + inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``. + eps (float): A value added to the denominator for numerical stability. Default: 1e-6. + training (bool): apply dropblock if is ``True``. Default: ``True``. + + Returns: + Tensor[N, C, H, W]: The randomly zeroed tensor after dropblock. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(drop_block2d) + if p < 0.0 or p > 1.0: + raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.") + if input.ndim != 4: + raise ValueError(f"input should be 4 dimensional. Got {input.ndim} dimensions.") + if not training or p == 0.0: + return input + + N, C, H, W = input.size() + block_size = min(block_size, W, H) + # compute the gamma of Bernoulli distribution + gamma = (p * H * W) / ((block_size ** 2) * ((H - block_size + 1) * (W - block_size + 1))) + noise = torch.empty((N, C, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device) + noise.bernoulli_(gamma) + + noise = F.pad(noise, [block_size // 2] * 4, value=0) + noise = F.max_pool2d(noise, stride=(1, 1), kernel_size=(block_size, block_size), padding=block_size // 2) + noise = 1 - noise + normalize_scale = noise.numel() / (eps + noise.sum()) + if inplace: + input.mul_(noise).mul_(normalize_scale) + else: + input = input * noise * normalize_scale + return input + + +def drop_block3d( + input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True +) -> Tensor: + """ + Implements DropBlock3d from `"DropBlock: A regularization method for convolutional networks" + `. + + Args: + input (Tensor[N, C, D, H, W]): The input tensor or 5-dimensions with the first one + being its batch i.e. a batch with ``N`` rows. + p (float): Probability of an element to be dropped. + block_size (int): Size of the block to drop. + inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``. + eps (float): A value added to the denominator for numerical stability. Default: 1e-6. + training (bool): apply dropblock if is ``True``. Default: ``True``. + + Returns: + Tensor[N, C, D, H, W]: The randomly zeroed tensor after dropblock. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(drop_block3d) + if p < 0.0 or p > 1.0: + raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.") + if input.ndim != 5: + raise ValueError(f"input should be 5 dimensional. Got {input.ndim} dimensions.") + if not training or p == 0.0: + return input + + N, C, D, H, W = input.size() + block_size = min(block_size, D, H, W) + # compute the gamma of Bernoulli distribution + gamma = (p * D * H * W) / ((block_size ** 3) * ((D - block_size + 1) * (H - block_size + 1) * (W - block_size + 1))) + noise = torch.empty( + (N, C, D - block_size + 1, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device + ) + noise.bernoulli_(gamma) + + noise = F.pad(noise, [block_size // 2] * 6, value=0) + noise = F.max_pool3d( + noise, stride=(1, 1, 1), kernel_size=(block_size, block_size, block_size), padding=block_size // 2 + ) + noise = 1 - noise + normalize_scale = noise.numel() / (eps + noise.sum()) + if inplace: + input.mul_(noise).mul_(normalize_scale) + else: + input = input * noise * normalize_scale + return input + + +torch.fx.wrap("drop_block2d") + + +class DropBlock2d(nn.Module): + """ + See :func:`drop_block2d`. + """ + + def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None: + super().__init__() + + self.p = p + self.block_size = block_size + self.inplace = inplace + self.eps = eps + + def forward(self, input: Tensor) -> Tensor: + """ + Args: + input (Tensor): Input feature map on which some areas will be randomly + dropped. + Returns: + Tensor: The tensor after DropBlock layer. + """ + return drop_block2d(input, self.p, self.block_size, self.inplace, self.eps, self.training) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}(p={self.p}, block_size={self.block_size}, inplace={self.inplace})" + return s + + +torch.fx.wrap("drop_block3d") + + +class DropBlock3d(DropBlock2d): + """ + See :func:`drop_block3d`. + """ + + def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None: + super().__init__(p, block_size, inplace, eps) + + def forward(self, input: Tensor) -> Tensor: + """ + Args: + input (Tensor): Input feature map on which some areas will be randomly + dropped. + Returns: + Tensor: The tensor after DropBlock layer. + """ + return drop_block3d(input, self.p, self.block_size, self.inplace, self.eps, self.training)