Source code for mmseg.models.backbones.stdc

# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/MichaelFan01/STDC-Seg."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential

from mmseg.ops import resize
from ..builder import BACKBONES, build_backbone
from .bisenetv1 import AttentionRefinementModule


class STDCModule(BaseModule):
    """STDCModule.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels before scaling.
        stride (int): The stride of the first conv layer.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): The activation config for conv layers. Default: None.
        num_convs (int): Number of conv layers. Default: 4.
        fusion_type (str): Type of fusion operation. Default: 'add'.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 norm_cfg=None,
                 act_cfg=None,
                 num_convs=4,
                 fusion_type='add',
                 init_cfg=None):
        super(STDCModule, self).__init__(init_cfg=init_cfg)
        assert num_convs > 1
        assert fusion_type in ['add', 'cat']
        self.stride = stride
        self.with_downsample = self.stride == 2
        self.fusion_type = fusion_type

        self.layers = ModuleList()
        conv_0 = ConvModule(
            in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg)

        if self.with_downsample:
            self.downsample = ConvModule(
                out_channels // 2,
                out_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=out_channels // 2,
                norm_cfg=norm_cfg,
                act_cfg=None)

            if self.fusion_type == 'add':
                self.layers.append(nn.Sequential(conv_0, self.downsample))
                self.skip = Sequential(
                    ConvModule(
                        in_channels,
                        in_channels,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        groups=in_channels,
                        norm_cfg=norm_cfg,
                        act_cfg=None),
                    ConvModule(
                        in_channels,
                        out_channels,
                        1,
                        norm_cfg=norm_cfg,
                        act_cfg=None))
            else:
                self.layers.append(conv_0)
                self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        else:
            self.layers.append(conv_0)

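        # Progressively halve the channels: conv i outputs
        # out_channels // 2**(i + 1), except the last conv, which keeps
        # out_channels // 2**(num_convs - 1) so that concatenating all branch
        # outputs sums back to `out_channels` (e.g. num_convs=4 gives
        # 1/2 + 1/4 + 1/8 + 1/8 = 1).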
        for i in range(1, num_convs):
            out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i
            self.layers.append(
                ConvModule(
                    out_channels // 2**i,
                    out_channels // out_factor,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    def forward(self, inputs):
        if self.fusion_type == 'add':
            out = self.forward_add(inputs)
        else:
            out = self.forward_cat(inputs)
        return out

    def forward_add(self, inputs):
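        # Collect every conv's output and concatenate along channels; the
        # skip path (identity, or a strided depthwise conv plus a 1x1 conv
        # when downsampling) is then added element-wise.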
        layer_outputs = []
        x = inputs.clone()
        for layer in self.layers:
            x = layer(x)
            layer_outputs.append(x)
        if self.with_downsample:
            inputs = self.skip(inputs)

        return torch.cat(layer_outputs, dim=1) + inputs

    def forward_cat(self, inputs):
        x0 = self.layers[0](inputs)
        layer_outputs = [x0]
        for i, layer in enumerate(self.layers[1:]):
            if i == 0:
                if self.with_downsample:
                    x = layer(self.downsample(x0))
                else:
                    x = layer(x0)
            else:
                x = layer(x)
            layer_outputs.append(x)
        if self.with_downsample:
            layer_outputs[0] = self.skip(x0)
        return torch.cat(layer_outputs, dim=1)
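
# A minimal shape-check sketch for `STDCModule` (assuming mmcv is installed;
# the channel counts and BN/ReLU configs below are illustrative, not taken
# from any released config):
# >>> m = STDCModule(32, 64, stride=2, norm_cfg=dict(type='BN'),
# ...                act_cfg=dict(type='ReLU'), fusion_type='cat').eval()
# >>> m(torch.rand(1, 32, 64, 64)).shape  # branches 32+16+8+8 -> 64 channels
# torch.Size([1, 64, 32, 32])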


class FeatureFusionModule(BaseModule):
    """Feature Fusion Module. This module is different from FeatureFusionModule
    in BiSeNetV1. It uses two ConvModules in `self.attention` whose inter
    channel number is calculated by given `scale_factor`, while
    FeatureFusionModule in BiSeNetV1 only uses one ConvModule in
    `self.conv_atten`.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        scale_factor (int): The number of channel scale factor.
            Default: 4.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): The activation config for conv layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 scale_factor=4,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super(FeatureFusionModule, self).__init__(init_cfg=init_cfg)
        channels = out_channels // scale_factor
        self.conv0 = ConvModule(
            in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            ConvModule(
                out_channels,
                channels,
                1,
                norm_cfg=None,
                bias=False,
                act_cfg=act_cfg),
            ConvModule(
                channels,
                out_channels,
                1,
                norm_cfg=None,
                bias=False,
                act_cfg=None), nn.Sigmoid())

    def forward(self, spatial_inputs, context_inputs):
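        # Channel-attention fusion: global average pooling squeezes the
        # spatial dims, two 1x1 convs produce per-channel weights in (0, 1),
        # and the re-weighted features are added back to `x` as a residual.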
        inputs = torch.cat([spatial_inputs, context_inputs], dim=1)
        x = self.conv0(inputs)
        attn = self.attention(x)
        x_attn = x * attn
        return x_attn + x

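# A minimal sketch of `FeatureFusionModule` usage (assuming mmcv is
# installed; the channel counts are illustrative). `in_channels` must equal
# the sum of the two inputs' channels, since they are concatenated first:
# >>> ffm = FeatureFusionModule(in_channels=384, out_channels=256).eval()
# >>> spatial = torch.rand(1, 256, 128, 256)
# >>> context = torch.rand(1, 128, 128, 256)
# >>> ffm(spatial, context).shape
# torch.Size([1, 256, 128, 256])
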

@BACKBONES.register_module()
class STDCNet(BaseModule):
    """This backbone is the implementation of `Rethinking BiSeNet For
    Real-time Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.

    Args:
        stdc_type (str): The type of backbone structure. `STDCNet1` and
            `STDCNet2` denote the two main backbones in the paper, whose
            FLOPs are 813M and 1446M, respectively.
        in_channels (int): The number of input channels.
        channels (tuple[int]): The output channels for each stage.
        bottleneck_type (str): The type of STDC Module, the value must
            be 'add' or 'cat'.
        norm_cfg (dict): Config dict for normalization layer.
        act_cfg (dict): The activation config for conv layers.
        num_convs (int): Number of conv layers in each STDC Module.
            Default: 4.
        with_final_conv (bool): Whether to add a conv layer at the Module
            output. Default: False.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> import torch
        >>> stdc_type = 'STDCNet1'
        >>> in_channels = 3
        >>> channels = (32, 64, 256, 512, 1024)
        >>> bottleneck_type = 'cat'
        >>> inputs = torch.rand(1, 3, 1024, 2048)
        >>> self = STDCNet(stdc_type, in_channels,
        ...                channels, bottleneck_type).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 256, 128, 256])
        outputs[1].shape = torch.Size([1, 512, 64, 128])
        outputs[2].shape = torch.Size([1, 1024, 32, 64])
    """

    arch_settings = {
        'STDCNet1': [(2, 1), (2, 1), (2, 1)],
        'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)]
    }

    def __init__(self,
                 stdc_type,
                 in_channels,
                 channels,
                 bottleneck_type,
                 norm_cfg,
                 act_cfg,
                 num_convs=4,
                 with_final_conv=False,
                 pretrained=None,
                 init_cfg=None):
        super(STDCNet, self).__init__(init_cfg=init_cfg)
        assert stdc_type in self.arch_settings, \
            f'invalid structure {stdc_type} for STDCNet.'
        assert bottleneck_type in ['add', 'cat'], \
            f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}'
        assert len(channels) == 5, \
            f'invalid channels length {len(channels)} for STDCNet.'

        self.in_channels = in_channels
        self.channels = channels
        self.stage_strides = self.arch_settings[stdc_type]
        self.pretrained = pretrained
        self.num_convs = num_convs
        self.with_final_conv = with_final_conv

        self.stages = ModuleList([
            ConvModule(
                self.in_channels,
                self.channels[0],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                self.channels[0],
                self.channels[1],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        ])
        # `self.num_shallow_features` is the number of shallow modules in
        # `STDCNet`, which are noted as `Stage1` and `Stage2` in the original
        # paper. Neither is used by the following modules such as the
        # Attention Refinement Module and the Feature Fusion Module, so they
        # are cut from `outs`. Please refer to Figure 4 of the original paper
        # for more details.
        self.num_shallow_features = len(self.stages)

        for strides in self.stage_strides:
            idx = len(self.stages) - 1
            self.stages.append(
                self._make_stage(self.channels[idx], self.channels[idx + 1],
                                 strides, norm_cfg, act_cfg, bottleneck_type))
        # After appending, `self.stages` is a ModuleList including several
        # shallow modules and STDCModules.
        # (len(self.stages) ==
        # self.num_shallow_features + len(self.stage_strides))
        if self.with_final_conv:
            self.final_conv = ConvModule(
                self.channels[-1],
                max(1024, self.channels[-1]),
                1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def _make_stage(self, in_channels, out_channels, strides, norm_cfg,
                    act_cfg, bottleneck_type):
        layers = []
        for i, stride in enumerate(strides):
            layers.append(
                STDCModule(
                    in_channels if i == 0 else out_channels,
                    out_channels,
                    stride,
                    norm_cfg,
                    act_cfg,
                    num_convs=self.num_convs,
                    fusion_type=bottleneck_type))
        return Sequential(*layers)

    def forward(self, x):
        outs = []
        for stage in self.stages:
            x = stage(x)
            outs.append(x)
        if self.with_final_conv:
            outs[-1] = self.final_conv(outs[-1])
        outs = outs[self.num_shallow_features:]
        return tuple(outs)
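
# Note (derived from the Example above): the two stem convs downsample to
# 1/4 resolution, and each stage tuple in `arch_settings` starts with
# stride 2, so the three returned feature maps sit at 1/8, 1/16 and 1/32 of
# the input resolution, with channels[2], channels[3] and channels[4]
# channels, respectively.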

@BACKBONES.register_module()
class STDCContextPathNet(BaseModule):
    """STDCNet with Context Path. The `outs` below is a list of three feature
    maps from deep to shallow, whose height and width are from small to big,
    respectively. The biggest feature map of `outs` is outputted to
    `STDCHead`, where the Detail Loss is calculated against the detail
    ground truth. The other two feature maps are fed into the two Attention
    Refinement Modules, respectively. Besides, the biggest feature map of
    `outs` and the last output of the Attention Refinement Modules are
    concatenated by the Feature Fusion Module. This fused feature map
    `feat_fuse` is then outputted to `decode_head`. For more details, please
    refer to Figure 4 of the original paper.

    Args:
        backbone_cfg (dict): Config dict for stdc backbone.
        last_in_channels (tuple(int)): The number of channels of the last
            two feature maps from the stdc backbone. Default: (1024, 512).
        out_channels (int): The channels of output feature maps.
            Default: 128.
        ffm_cfg (dict): Config dict for Feature Fusion Module. Default:
            `dict(in_channels=512, out_channels=256, scale_factor=4)`.
        upsample_mode (str): Algorithm used for upsampling:
            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
            ``'trilinear'``. Default: ``'nearest'``.
        align_corners (bool, optional): align_corners argument of
            F.interpolate. It must be `None` if upsample_mode is
            ``'nearest'``. Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Return:
        outputs (tuple): The tuple of list of output feature map for
            auxiliary heads and decoder head.
    """

    def __init__(self,
                 backbone_cfg,
                 last_in_channels=(1024, 512),
                 out_channels=128,
                 ffm_cfg=dict(
                     in_channels=512, out_channels=256, scale_factor=4),
                 upsample_mode='nearest',
                 align_corners=None,
                 norm_cfg=dict(type='BN'),
                 init_cfg=None):
        super(STDCContextPathNet, self).__init__(init_cfg=init_cfg)
        self.backbone = build_backbone(backbone_cfg)
        self.arms = ModuleList()
        self.convs = ModuleList()
        for channels in last_in_channels:
            self.arms.append(AttentionRefinementModule(channels, out_channels))
            self.convs.append(
                ConvModule(
                    out_channels,
                    out_channels,
                    3,
                    padding=1,
                    norm_cfg=norm_cfg))
        self.conv_avg = ConvModule(
            last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg)

        self.ffm = FeatureFusionModule(**ffm_cfg)

        self.upsample_mode = upsample_mode
        self.align_corners = align_corners

    def forward(self, x):
        outs = list(self.backbone(x))
        avg = F.adaptive_avg_pool2d(outs[-1], 1)
        avg_feat = self.conv_avg(avg)

        feature_up = resize(
            avg_feat,
            size=outs[-1].shape[2:],
            mode=self.upsample_mode,
            align_corners=self.align_corners)
        arms_out = []
        for i in range(len(self.arms)):
            x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up
            feature_up = resize(
                x_arm,
                size=outs[len(outs) - 1 - i - 1].shape[2:],
                mode=self.upsample_mode,
                align_corners=self.align_corners)
            feature_up = self.convs[i](feature_up)
            arms_out.append(feature_up)

        feat_fuse = self.ffm(outs[0], arms_out[1])

        # `outputs` contains four feature maps.
        # `outs[0]` is outputted to the `STDCHead` auxiliary head.
        # The two feature maps of `arms_out` are outputted to auxiliary
        # heads. `feat_fuse` is outputted to the decode head.
        outputs = [outs[0]] + list(arms_out) + [feat_fuse]

        return tuple(outputs)
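
# A minimal sketch of building the full context path (assuming mmcv/mmseg
# are installed; the config values mirror the STDCNet1 example above, and
# `ffm_cfg` uses in_channels=384 since outs[0] (256 channels) and the last
# ARM output (128 channels) are concatenated):
# >>> backbone_cfg = dict(
# ...     type='STDCNet',
# ...     stdc_type='STDCNet1',
# ...     in_channels=3,
# ...     channels=(32, 64, 256, 512, 1024),
# ...     bottleneck_type='cat',
# ...     norm_cfg=dict(type='BN'),
# ...     act_cfg=dict(type='ReLU'),
# ...     num_convs=4,
# ...     with_final_conv=False)
# >>> net = STDCContextPathNet(
# ...     backbone_cfg,
# ...     ffm_cfg=dict(in_channels=384, out_channels=256,
# ...                  scale_factor=4)).eval()
# >>> outputs = net(torch.rand(1, 3, 1024, 2048))
# >>> [tuple(o.shape) for o in outputs]
# [(1, 256, 128, 256), (1, 128, 64, 128), (1, 128, 128, 256), (1, 256, 128, 256)]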