-
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Add MAnet #310
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add MAnet #310
Changes from 3 commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .model import MAnet |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from ..base import modules as md | ||
|
||
|
||
class PAB(nn.Module): | ||
def __init__(self, in_channels, out_channels, im_channels=64): | ||
super(PAB, self).__init__() | ||
# Series of 1x1 conv to generate attention feature maps | ||
self.im_channels = im_channels | ||
self.in_channels = in_channels | ||
self.top_conv = nn.Conv2d(in_channels, im_channels, kernel_size=1) | ||
self.center_conv = nn.Conv2d(in_channels, im_channels, kernel_size=1) | ||
self.bottom_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) | ||
self.map_softmax = nn.Softmax(dim=1) | ||
self.out_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) | ||
|
||
def forward(self, x): | ||
bsize = x.size()[0] | ||
h = x.size()[2] | ||
w = x.size()[3] | ||
x_top = self.top_conv(x) | ||
x_center = self.center_conv(x) | ||
x_bottom = self.bottom_conv(x) | ||
|
||
x_top = x_top.flatten(2) | ||
x_center = x_center.flatten(2).transpose(1, 2) | ||
x_bottom = x_bottom.flatten(2).transpose(1, 2) | ||
|
||
sp_map = torch.matmul(x_center, x_top) | ||
sp_map = self.map_softmax(sp_map.view(bsize, -1)).view(bsize, h*w, h*w) | ||
sp_map = torch.matmul(sp_map, x_bottom) | ||
sp_map = sp_map.reshape(bsize, self.in_channels, h, w) | ||
x = x + sp_map | ||
x = self.out_conv(x) | ||
return x | ||
|
||
|
||
class MFAB(nn.Module): | ||
def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16): | ||
# MFAB is just a modified version of SE-blocks, one for skip, one for input | ||
super(MFAB, self).__init__() | ||
self.hl_conv = nn.Sequential( | ||
md.Conv2dReLU( | ||
in_channels, | ||
in_channels, | ||
kernel_size=3, | ||
padding=1, | ||
use_batchnorm=use_batchnorm, | ||
), | ||
md.Conv2dReLU( | ||
in_channels, | ||
skip_channels, | ||
kernel_size=1, | ||
use_batchnorm=use_batchnorm, | ||
) | ||
) | ||
self.SE_ll = nn.Sequential( | ||
nn.AdaptiveAvgPool2d(1), | ||
nn.Conv2d(skip_channels, skip_channels // reduction, 1), | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d(skip_channels // reduction, skip_channels, 1), | ||
nn.Sigmoid(), | ||
) | ||
self.SE_hl = nn.Sequential( | ||
nn.AdaptiveAvgPool2d(1), | ||
nn.Conv2d(skip_channels, skip_channels // reduction, 1), | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d(skip_channels // reduction, skip_channels, 1), | ||
nn.Sigmoid(), | ||
) | ||
self.conv1 = md.Conv2dReLU( | ||
skip_channels + skip_channels, # we transform C-prime form high level to C from skip connection | ||
out_channels, | ||
kernel_size=3, | ||
padding=1, | ||
use_batchnorm=use_batchnorm, | ||
) | ||
self.conv2 = md.Conv2dReLU( | ||
out_channels, | ||
out_channels, | ||
kernel_size=3, | ||
padding=1, | ||
use_batchnorm=use_batchnorm, | ||
) | ||
|
||
def forward(self, x, skip=None): | ||
x = self.hl_conv(x) | ||
x = F.interpolate(x, scale_factor=2, mode="nearest") | ||
attention_hl = self.SE_hl(x) | ||
if skip is not None: | ||
attention_ll = self.SE_ll(skip) | ||
attention_hl = attention_hl + attention_ll | ||
x = x * attention_hl | ||
x = torch.cat([x, skip], dim=1) | ||
x = self.conv1(x) | ||
x = self.conv2(x) | ||
return x | ||
|
||
|
||
class DecoderBlock(nn.Module): | ||
def __init__( | ||
self, | ||
in_channels, | ||
skip_channels, | ||
out_channels, | ||
use_batchnorm=True | ||
): | ||
super().__init__() | ||
self.conv1 = md.Conv2dReLU( | ||
in_channels + skip_channels, | ||
out_channels, | ||
kernel_size=3, | ||
padding=1, | ||
use_batchnorm=use_batchnorm, | ||
) | ||
self.conv2 = md.Conv2dReLU( | ||
out_channels, | ||
out_channels, | ||
kernel_size=3, | ||
padding=1, | ||
use_batchnorm=use_batchnorm, | ||
) | ||
|
||
def forward(self, x, skip=None): | ||
x = F.interpolate(x, scale_factor=2, mode="nearest") | ||
if skip is not None: | ||
x = torch.cat([x, skip], dim=1) | ||
x = self.conv1(x) | ||
x = self.conv2(x) | ||
return x | ||
|
||
|
||
class MAnetDecoder(nn.Module): | ||
def __init__( | ||
self, | ||
encoder_channels, | ||
decoder_channels, | ||
n_blocks=5, | ||
reduction=16, | ||
use_batchnorm=True, | ||
im_channels=64 | ||
): | ||
super().__init__() | ||
|
||
if n_blocks != len(decoder_channels): | ||
raise ValueError( | ||
"Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( | ||
n_blocks, len(decoder_channels) | ||
) | ||
) | ||
|
||
encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution | ||
encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder | ||
|
||
# computing blocks input and output channels | ||
head_channels = encoder_channels[0] | ||
in_channels = [head_channels] + list(decoder_channels[:-1]) | ||
skip_channels = list(encoder_channels[1:]) + [0] | ||
out_channels = decoder_channels | ||
|
||
self.center = PAB(head_channels, head_channels, im_channels=im_channels) | ||
|
||
# combine decoder keyword arguments | ||
kwargs = dict(use_batchnorm=use_batchnorm) # no attention type here | ||
blocks = [ | ||
MFAB(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs) if skip_ch > 0 else | ||
DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) | ||
for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) | ||
] | ||
# for the last we dont have skip connection -> use simple decoder block | ||
self.blocks = nn.ModuleList(blocks) | ||
|
||
def forward(self, *features): | ||
|
||
features = features[1:] # remove first skip with same spatial resolution | ||
features = features[::-1] # reverse channels to start from head of encoder | ||
|
||
head = features[0] | ||
skips = features[1:] | ||
|
||
x = self.center(head) | ||
for i, decoder_block in enumerate(self.blocks): | ||
skip = skips[i] if i < len(skips) else None | ||
x = decoder_block(x, skip) | ||
|
||
return x |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
from typing import Optional, Union, List | ||
from .decoder import MAnetDecoder | ||
from ..encoders import get_encoder | ||
from ..base import SegmentationModel | ||
from ..base import SegmentationHead, ClassificationHead | ||
|
||
|
||
class MAnet(SegmentationModel): | ||
"""MAnet_ : Multi-scale Attention Net. | ||
The MA-Net can capture rich contextual dependencies based on the attention mechanism, using two blocks: | ||
Position-wise Attention Block (PAB, which captures the spatial dependencies between pixels in a global view) | ||
and Multi-scale Fusion Attention Block (MFAB, which captures the channel dependencies between any feature map by | ||
multi-scale semantic feature fusion) | ||
Args: | ||
qubvel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
encoder_name: name of classification model (without last dense layers) used as feature | ||
extractor to build segmentation model. | ||
encoder_depth (int): number of stages used in decoder, larger depth - more features are generated. | ||
e.g. for depth=3 encoder will generate list of features with following spatial shapes | ||
[(H,W), (H/2, W/2), (H/4, W/4), (H/8, W/8)], so in general the deepest feature tensor will have | ||
spatial resolution (H/(2^depth), W/(2^depth)] | ||
encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet). | ||
decoder_channels: list of numbers of ``Conv2D`` layer filters in decoder blocks | ||
decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers | ||
is used. If 'inplace' InplaceABN will be used, allows to decrease memory consumption. | ||
One of [True, False, 'inplace'] | ||
decoder_im_channels: number of layers for PAB layer. | ||
in_channels: number of input channels for model, default is 3. | ||
classes: a number of classes for output (output shape - ``(batch, classes, h, w)``). | ||
activation: activation function to apply after final convolution; | ||
One of [``sigmoid``, ``softmax``, ``logsoftmax``, ``identity``, callable, None] | ||
aux_params: if specified model will have additional classification auxiliary output | ||
build on top of encoder, supported params: | ||
- classes (int): number of classes | ||
- pooling (str): one of 'max', 'avg'. Default is 'avg'. | ||
- dropout (float): dropout factor in [0, 1) | ||
- activation (str): activation function to apply "sigmoid"/"softmax" (could be None to return logits) | ||
Returns: | ||
``torch.nn.Module``: **MAnet** | ||
.. _MAnet: | ||
https://ieeexplore.ieee.org/abstract/document/9201310 | ||
""" | ||
|
||
def __init__( | ||
self, | ||
encoder_name: str = "resnet34", | ||
encoder_depth: int = 5, | ||
encoder_weights: str = "imagenet", | ||
decoder_use_batchnorm: bool = True, | ||
decoder_channels: List[int] = (256, 128, 64, 32, 16), | ||
decoder_im_channels: int = 64, | ||
qubvel marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
in_channels: int = 3, | ||
classes: int = 1, | ||
activation: Optional[Union[str, callable]] = None, | ||
aux_params: Optional[dict] = None | ||
): | ||
super().__init__() | ||
|
||
self.encoder = get_encoder( | ||
encoder_name, | ||
in_channels=in_channels, | ||
depth=encoder_depth, | ||
weights=encoder_weights, | ||
) | ||
|
||
self.decoder = MAnetDecoder( | ||
encoder_channels=self.encoder.out_channels, | ||
decoder_channels=decoder_channels, | ||
n_blocks=encoder_depth, | ||
use_batchnorm=decoder_use_batchnorm, | ||
im_channels=decoder_im_channels | ||
) | ||
|
||
self.segmentation_head = SegmentationHead( | ||
in_channels=decoder_channels[-1], | ||
out_channels=classes, | ||
activation=activation, | ||
kernel_size=3, | ||
) | ||
|
||
if aux_params is not None: | ||
self.classification_head = ClassificationHead( | ||
in_channels=self.encoder.out_channels[-1], **aux_params | ||
) | ||
else: | ||
self.classification_head = None | ||
|
||
self.name = "manet-{}".format(encoder_name) | ||
self.initialize() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.