Shortcuts

Source code for mmseg.models.decode_heads.aspp_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead


class ASPPModule(nn.ModuleList):
    """Atrous Spatial Pyramid Pooling (ASPP) Module.

    Holds one ``ConvModule`` branch per dilation rate. A rate of 1 builds a
    1x1 conv with no padding; every other rate builds a 3x3 conv whose
    padding equals its dilation.

    Args:
        dilations (tuple[int]): Dilation rate of each layer.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg):
        super().__init__()
        self.dilations = dilations
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        for rate in dilations:
            # Rate 1 degenerates to a pointwise conv; other rates use a
            # dilated 3x3 conv padded by the same amount as the dilation.
            if rate == 1:
                kernel_size, padding = 1, 0
            else:
                kernel_size, padding = 3, rate
            branch = ConvModule(
                in_channels,
                channels,
                kernel_size,
                dilation=rate,
                padding=padding,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            self.append(branch)

    def forward(self, x):
        """Apply every branch to ``x`` and return the outputs as a list."""
        return [branch(x) for branch in self]


@HEADS.register_module()
class ASPPHead(BaseDecodeHead):
    """Rethinking Atrous Convolution for Semantic Image Segmentation.

    This head is the implementation of `DeepLabV3
    <https://arxiv.org/abs/1706.05587>`_.

    Args:
        dilations (tuple[int]): Dilation rates for ASPP module.
            Default: (1, 6, 12, 18).
    """

    def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
        super(ASPPHead, self).__init__(**kwargs)
        assert isinstance(dilations, (list, tuple))
        self.dilations = dilations
        # Image-level branch: pool to 1x1, then a 1x1 conv to
        # ``self.channels``.
        self.image_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvModule(
                self.in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))
        # One parallel conv branch per dilation rate.
        self.aspp_modules = ASPPModule(
            dilations,
            self.in_channels,
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # Fuses the concatenation of the image-pool branch and all ASPP
        # branches, hence ``len(dilations) + 1`` groups of input channels.
        self.bottleneck = ConvModule(
            (len(dilations) + 1) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel
        with ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """
        x = self._transform_inputs(inputs)
        # Resize the pooled image-level feature back to the spatial size of
        # ``x`` so it can be concatenated with the ASPP branch outputs.
        aspp_outs = [
            resize(
                self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]
        aspp_outs.extend(self.aspp_modules(x))
        # Concatenate all branches along the channel dimension.
        aspp_outs = torch.cat(aspp_outs, dim=1)
        feats = self.bottleneck(aspp_outs)
        return feats

    def forward(self, inputs):
        """Forward function."""
        output = self._forward_feature(inputs)
        output = self.cls_seg(output)
        return output
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.