Shortcuts

Source code for mmseg.models.decode_heads.sep_aspp_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule

from mmseg.ops import resize
from ..builder import HEADS
from .aspp_head import ASPPHead, ASPPModule


class DepthwiseSeparableASPPModule(ASPPModule):
    """ASPP module whose dilated branches use depthwise separable convs.

    The parent ``ASPPModule`` is constructed first with the same keyword
    arguments. Afterwards, every branch whose dilation rate is greater than 1
    is replaced, at the same index, by a 3x3 ``DepthwiseSeparableConvModule``
    using that dilation (and a matching padding). Branches with a dilation
    rate of 1 or less are left exactly as the parent built them.

    Args:
        **kwargs: Forwarded unchanged to ``ASPPModule.__init__``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        atrous_branches = ((idx, rate)
                           for idx, rate in enumerate(self.dilations)
                           if rate > 1)
        for idx, rate in atrous_branches:
            # padding == dilation keeps the 3x3 atrous conv's output the
            # same spatial size as its input (stride defaults to 1).
            self[idx] = DepthwiseSeparableConvModule(
                self.in_channels,
                self.channels,
                3,
                dilation=rate,
                padding=rate,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)


@HEADS.register_module()
class DepthwiseSeparableASPPHead(ASPPHead):
    """Encoder-Decoder with Atrous Separable Convolution for Semantic Image
    Segmentation.

    This head is the implementation of `DeepLabV3+
    <https://arxiv.org/abs/1802.02611>`_.

    Args:
        c1_in_channels (int): The input channels of the c1 decoder. If it is
            0, no decoder will be used.
        c1_channels (int): The intermediate channels of the c1 decoder. Only
            used when ``c1_in_channels`` is greater than 0.
    """

    def __init__(self, c1_in_channels, c1_channels, **kwargs):
        super(DepthwiseSeparableASPPHead, self).__init__(**kwargs)
        assert c1_in_channels >= 0, (
            'c1_in_channels must be non-negative, got {}'.format(
                c1_in_channels))
        # Replace the parent's ASPP branches with the depthwise separable
        # variant, built from the same configuration.
        self.aspp_modules = DepthwiseSeparableASPPModule(
            dilations=self.dilations,
            in_channels=self.in_channels,
            channels=self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if c1_in_channels > 0:
            self.c1_bottleneck = ConvModule(
                c1_in_channels,
                c1_channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            # forward() concatenates the c1 features onto the ASPP output,
            # so sep_bottleneck receives both channel groups.
            sep_in_channels = self.channels + c1_channels
        else:
            self.c1_bottleneck = None
            # Without a c1 decoder forward() concatenates nothing, so the
            # first separable conv must accept only the ASPP output. Sizing
            # it with c1_channels here would cause a channel mismatch.
            sep_in_channels = self.channels
        self.sep_bottleneck = nn.Sequential(
            DepthwiseSeparableConvModule(
                sep_in_channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            DepthwiseSeparableConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Indexable collection of feature maps. The ASPP input is
                chosen by ``self._transform_inputs`` (defined in the base
                class). ``inputs[0]`` is used as the c1 feature when the c1
                decoder is enabled. NOTE(review): presumably these are
                multi-level backbone features in (N, C, H, W) layout;
                confirm against the base class.

        Returns:
            The result of ``self.cls_seg`` applied to the decoded features.
        """
        x = self._transform_inputs(inputs)
        # Image-level pooling branch, resized back to x's spatial size
        # (dims 2 onward) so it can be concatenated with the ASPP branches.
        aspp_outs = [
            resize(
                self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]
        aspp_outs.extend(self.aspp_modules(x))
        aspp_outs = torch.cat(aspp_outs, dim=1)
        output = self.bottleneck(aspp_outs)
        if self.c1_bottleneck is not None:
            c1_output = self.c1_bottleneck(inputs[0])
            # Upsample the ASPP output to the c1 feature's spatial size
            # before fusing the two along dim 1.
            output = resize(
                input=output,
                size=c1_output.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            output = torch.cat([output, c1_output], dim=1)
        output = self.sep_bottleneck(output)
        output = self.cls_seg(output)
        return output
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.