
Document backbone output stride and minimum input size for DeepLabV3 #7955

@sudotensor

Description

📚 The doc issue

Currently, the documentation for DeepLabV3 (segmentation) states neither the minimum input size [1] nor the output stride of the backbone network, as configured by the original contributor in #820. Both pieces of information are interlinked with the choice of atrous rates (see Background) and vital for end-users configuring and effectively using this class of models for their specific needs.

Many end-users pick up off-the-shelf models without concerning themselves with the details of the underlying architecture. For some, this works well, but there are cases where such details (including implementation caveats) help end-users make well-informed decisions that impact their results; one such case from medical imaging is detailed later in this issue. The issue covers a fair few technical details to explain why the aforementioned pieces of information are vital; perhaps these can make their way into the documentation where deemed relevant.

Background

Output Stride

The term output stride originates from the DeepLabV3 paper (Link). The authors note that traditional deep convolutional networks aggressively decimate valuable, detailed information in feature maps through repeated striding or pooling operations. To quantify this decimation, they introduce the notion of output stride, defined as the ratio of the input spatial resolution to the final backbone-encoded feature map resolution; in Fig. 1, the output stride is 1024/64 = 16. Models with a low output stride tend to produce higher-quality, finely detailed segmentation masks, as there's more information available for reconstruction (the upsampling stage). However, such models are also more resource-intensive due to the comparatively larger intermediate activations.
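As a minimal illustration of the definition (assuming a recent torchvision release; the truncation below is purely illustrative), a standard, non-dilated ResNet-50 decimates its input down to an output stride of 32:

```python
import torch
from torchvision.models import resnet50

# Keep everything up to and including layer4; drop the average pool and the
# fully connected classifier so the output is the final feature map.
backbone = torch.nn.Sequential(*list(resnet50(weights=None).children())[:-2])
backbone.eval()

x = torch.randn(1, 3, 1024, 1024)
with torch.no_grad():
    feat = backbone(x)

print(tuple(feat.shape[-2:]))           # (32, 32)
print(x.shape[-1] // feat.shape[-1])    # output stride = 1024 / 32 = 32
```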

Atrous (Dilated) Convolutions

[Figure: atrous (dilated) convolution, with the receptive field demarcated by dotted vertical markers]

Instead of using strided transposed convolutions for reconstruction (like U-Nets), DeepLabV3 uses an atrous spatial pyramid pooling (ASPP) module, at the core of which lie three atrous convolutions configured with different atrous rates; the convolutional layers in question form an inverted triangle in the architecture diagram visualised in Fig. 1. Atrous convolutions provide a larger receptive field without aggressively downsampling the feature map (unlike striding) by poking holes through convolutional kernels to inflate their size, as visualised in the figure above; the receptive field is demarcated by dotted vertical markers. Varying the atrous (dilation) rate allows the capture of information from a larger spatial context. The ASPP module is thus architected to learn multi-scale contextual information, and the choice of atrous rates impacts the module's effectiveness in capturing it.
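A small sketch of the mechanism (channel counts and rates chosen only for illustration): a 3×3 convolution with dilation rate r spans an effective extent of 3 + 2(r − 1), yet, with matching padding, leaves the feature map resolution untouched.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 256, 64, 64)
for rate in (1, 6, 12, 18):
    # padding = dilation keeps a 3x3 convolution resolution-preserving.
    conv = nn.Conv2d(256, 256, kernel_size=3, dilation=rate, padding=rate)
    y = conv(x)
    effective_extent = 3 + 2 * (rate - 1)
    print(rate, effective_extent, tuple(y.shape[-2:]))
# rate=18 covers a 37x37 neighbourhood while the output stays 64x64.
```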

[Figure: a 6-padded, 6-dilated 3×3 convolution on a 5×5 feature map, degenerating to a pointwise convolution]

DeepLabV3, as initially proposed, uses atrous rates 6, 12, and 18 when the output stride is 16; these rates are doubled for an output stride of 8 [2]. The authors note that it's crucial to choose atrous rates according to the output stride and feature map resolution. There's a good reason for this: as the dilation (atrous) rate grows, the kernel parameters, except the central one(s), spend a growing fraction of their time "observing" the zero-padded region rather than the feature region. In the extreme case where the rate grows to or beyond the feature map size, the convolution degenerates to a pointwise one, as visualised in the figure above (a 6-padded, 6-dilated 3×3 convolution on a 5×5 map).
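This degenerate case is easy to reproduce (a toy, single-channel sketch): on a 5×5 map, a 6-dilated 3×3 convolution yields exactly the same output as a 1×1 convolution built from its central weight.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# 6-padded, 6-dilated 3x3 convolution on a 5x5 feature map: all non-central
# taps land in the zero-padded border, so only the central weight matters.
atrous = nn.Conv2d(1, 1, kernel_size=3, dilation=6, padding=6, bias=False)

# Pointwise convolution that reuses only the central weight of the kernel.
pointwise = nn.Conv2d(1, 1, kernel_size=1, bias=False)
with torch.no_grad():
    pointwise.weight.copy_(atrous.weight[:, :, 1:2, 1:2])

x = torch.randn(1, 1, 5, 5)
print(torch.allclose(atrous(x), pointwise(x)))  # True: it degenerated
```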

Figure 1

[DeepLabV3 architecture with the ASPP module; the output stride is 1024/64 = 16]

Now, DeepLabV3 (Google Research) was originally written in TensorFlow (Link), and the Torchvision port (Link) configures the network with hard-coded atrous rates 12, 24, and 36. Following the guideline noted in the paper (Link), the backbone networks should be configured with an output stride of 8 [3]. So, the minimum input size that prevents the atrous convolutions from degenerating into pointwise ones is 36 × 8 = 288 [2].
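The hard-coded rates, and the resulting minimum, can be read straight off the model with a sketch along these lines (assuming the head layout in recent torchvision releases, where model.classifier[0] is the ASPP block):

```python
import torch
from torchvision.models.segmentation import deeplabv3_resnet50

model = deeplabv3_resnet50(weights=None, weights_backbone=None)

# The atrous branches are the only dilated Conv2d modules inside the ASPP.
aspp = model.classifier[0]
rates = [m.dilation[0] for m in aspp.modules()
         if isinstance(m, torch.nn.Conv2d) and m.dilation[0] > 1]

print(rates)            # [12, 24, 36]
print(max(rates) * 8)   # 288: minimum input size given an output stride of 8
```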

In echonet/dynamic and echonet/lvh, the authors use DeepLabV3 (ResNet-50 backbone) via Torchvision with input images at a spatial resolution of 112×112, less than half the minimum. With an output stride of 8, the backbone-encoded feature map resolution is 14×14. Given the chosen rates, all three atrous convolutions in the ASPP module, each with 4.71 million parameters and together accounting for 35.7% of the model's total parameter count (39.6M), are completely ineffective! This is backed up by a low-magnitude weight-pruning experiment (see the sidetrack below).
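The arithmetic behind those figures works out as follows (a sketch assuming ResNet-50's 2048-channel backbone output, 256 ASPP channels, and a single-class head with no auxiliary classifier):

```python
from torchvision.models.segmentation import deeplabv3_resnet50

# Each atrous branch is a bias-free 3x3 convolution from 2048 to 256 channels.
per_branch = 2048 * 256 * 3 * 3                 # 4,718,592 ≈ 4.71M parameters

model = deeplabv3_resnet50(weights=None, weights_backbone=None, num_classes=1)
total = sum(p.numel() for p in model.parameters())

print(total)                                    # ≈ 39.6M
print(3 * per_branch / total)                   # ≈ 0.357, i.e. 35.7%
```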

Furthermore, DeepLabV3 was designed for segmenting objects at multiple scales (e.g., cars in traffic). As such, it's inherently ill-suited to echocardiograms (which are fairly standardised), where the segmentation target is a single, continuous blob in roughly the same position for every sample, not differing much in scale. A researcher who has neither the time to dive deep into architectural details nor the experience to navigate intricate codebases could easily miss these critical requirements.

Sidetrack

As nearly all of the kernel parameters of the three atrous convolutions only ever observe zeros, their weights are bound to be tiny and much less likely to influence the output. It's easy to verify this using global L1-norm unstructured weight pruning. Figure 2 shows parameter sparsity diagrams of the entire network after pruning at different sparsity targets; these diagrams were obtained by flattening all the network parameters via the following and reshaping the resulting vector into a square-ish matrix:

```python
torch.nn.utils.parameters_to_vector(model.parameters(recurse=True)).detach()
```
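For completeness, a minimal sketch of the pruning step itself (the choice to prune only convolutional weights here is illustrative; prune.remove just bakes the masks in so the zeros show up in the flattened vector):

```python
import torch
import torch.nn.utils.prune as prune
from torchvision.models.segmentation import deeplabv3_resnet50

model = deeplabv3_resnet50(weights=None, weights_backbone=None, num_classes=1)

# Global L1-norm unstructured pruning at a 50% sparsity target.
to_prune = [(m, "weight") for m in model.modules()
            if isinstance(m, torch.nn.Conv2d)]
prune.global_unstructured(to_prune, pruning_method=prune.L1Unstructured, amount=0.5)
for module, name in to_prune:
    prune.remove(module, name)  # make the pruning permanent

vec = torch.nn.utils.parameters_to_vector(model.parameters(recurse=True)).detach()
print((vec == 0).float().mean())  # fraction of exactly-zero parameters
```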

Figure 2

[Parameter sparsity diagrams after global L1-norm unstructured pruning at different sparsity targets: default atrous rates 12, 24, 36 (left) and atrous rates 2, 4, 6 (right)]

The 50% sparse case is the most interesting. Its distinctive characteristic is a missing band of parameters, which, upon further investigation, was found to belong to the three atrous convolutions in question. Manually changing the atrous rates to 2, 4, and 6, in an effort to adapt to the low backbone-encoded feature map resolution [4], results in a much denser band for the same sparsity target, as seen in the figure on the right. However, presumably due to the aforementioned lack of variety in scale, there wasn't a notable performance difference.
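Since Torchvision doesn't expose the rates, one way to reproduce such a change is to swap out the ASPP block (a sketch, assuming ResNet-50's 2048-channel backbone output and the head layout in recent torchvision releases):

```python
import torch
from torchvision.models.segmentation import deeplabv3_resnet50
from torchvision.models.segmentation.deeplabv3 import ASPP

model = deeplabv3_resnet50(weights=None, weights_backbone=None, num_classes=1)

# Replace the default ASPP (rates 12, 24, 36) with one using rates 2, 4, 6,
# better matched to a 14x14 backbone-encoded feature map.
model.classifier[0] = ASPP(2048, atrous_rates=[2, 4, 6])
model.eval()

x = torch.randn(1, 3, 112, 112)    # echonet-sized input
with torch.no_grad():
    print(model(x)["out"].shape)   # torch.Size([1, 1, 112, 112])
```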

Unrelated yet interesting is the compression rate. The 98.44% sparse case is within 1% of the baseline's Dice score while having only 633K nonzero (NNZ) parameters. Applying channel pruning via Microsoft's NNI results in a model with 34.5K parameters (1148× smaller) that's within 4% of the baseline. Simply choosing an off-the-shelf model without considering task complexity is analogous to using a truck to transport a single grape.

Suggest a potential alternative/fix

The proposal is to update the documentation with a callout card that makes the requirements for effective use clear. It would also be great if the documentation highlighted relevant characteristics of the underlying architecture (e.g., "works best for inputs where the segmentation target(s) differ in position and scale"), serving as a concise guideline for end-users.

Footnotes

  1. The Torch hub model card (Link) makes a remark on this. However, not only is it easy to miss, but there also isn't any reasoning to explain why the height and width of the input are expected to be at least 224.

  2. The DeepLabV3 authors assume that with a low output stride of 8, the backbone-encoded feature map resolution is high enough to use the larger atrous rates 12, 24, and 36. This expectation sets a minimum constraint on the resolution of the input image: 36 × 8 = 288.

  3. Torchvision supports three backbone networks: ResNet-50, ResNet-101, and MobileNet-V3 (Large). While the ResNets are correctly configured with an output stride of 8, MobileNet-V3 (Large) is configured with an output stride of 16 (see the sketch after these footnotes). This looks to be an implementation error and will be raised shortly via another issue.

  4. It might be useful to allow end-users to configure the atrous rates for added flexibility in use cases such as the one presented, where the input resolution is much smaller than the minimum determined by the default hard-coded rates. The converse also applies when the input is very large; gigapixel images are common with certain imaging modalities in healthcare settings.
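A quick way to check the output strides referenced in footnote 3 (a sketch assuming a recent torchvision release; no pretrained weights are downloaded):

```python
import torch
from torchvision.models.segmentation import (
    deeplabv3_mobilenet_v3_large,
    deeplabv3_resnet50,
)

x = torch.randn(1, 3, 512, 512)
for builder in (deeplabv3_resnet50, deeplabv3_mobilenet_v3_large):
    model = builder(weights=None, weights_backbone=None).eval()
    with torch.no_grad():
        feat = model.backbone(x)["out"]
    # output stride = input resolution / backbone feature map resolution
    print(builder.__name__, x.shape[-1] // feat.shape[-1])
# Expected per footnote 3: 8 for ResNet-50 but 16 for MobileNet-V3 (Large).
```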
