Source code for mmseg.models.decode_heads.enc_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_norm_layer

from mmseg.ops import Encoding, resize
from ..builder import HEADS, build_loss
from .decode_head import BaseDecodeHead


class EncModule(nn.Module):
    """Encoding Module used in EncNet.

    Args:
        in_channels (int): Input channels.
        num_codes (int): Number of code words.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):
        super(EncModule, self).__init__()
        self.encoding_project = ConvModule(
            in_channels,
            in_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # TODO: resolve this hack
        # Encoding outputs a 3d tensor (batch, num_codes, channels), so a
        # 2d norm cfg has to be converted to its 1d counterpart below.
        if norm_cfg is not None:
            encoding_norm_cfg = norm_cfg.copy()
            if encoding_norm_cfg['type'] in ['BN', 'IN']:
                encoding_norm_cfg['type'] += '1d'
            else:
                encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace(
                    '2d', '1d')
        else:
            # fallback to BN1d
            encoding_norm_cfg = dict(type='BN1d')
        self.encoding = nn.Sequential(
            Encoding(channels=in_channels, num_codes=num_codes),
            build_norm_layer(encoding_norm_cfg, num_codes)[1],
            nn.ReLU(inplace=True))
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels), nn.Sigmoid())

    def forward(self, x):
        """Forward function."""
        encoding_projection = self.encoding_project(x)
        # Encoding yields (batch, num_codes, channels); averaging over the
        # codewords gives a (batch, channels) global descriptor.
        encoding_feat = self.encoding(encoding_projection).mean(dim=1)
        batch_size, channels, _, _ = x.size()
        # Predict channel-wise scaling factors in (0, 1) and re-weight the
        # input feature map through a residual connection.
        gamma = self.fc(encoding_feat)
        y = gamma.view(batch_size, channels, 1, 1)
        output = F.relu_(x + x * y)
        return encoding_feat, output
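
A minimal usage sketch of EncModule (illustrative only, not part of the
source); the channel and code counts are arbitrary assumptions, and the
output shapes follow from the code above:

    import torch

    enc = EncModule(
        in_channels=512,
        num_codes=32,
        conv_cfg=None,
        norm_cfg=dict(type='BN'),
        act_cfg=dict(type='ReLU'))
    x = torch.randn(2, 512, 16, 16)  # (N, C, H, W) feature map
    encoding_feat, output = enc(x)
    print(encoding_feat.shape)  # torch.Size([2, 512]), global descriptor
    print(output.shape)         # torch.Size([2, 512, 16, 16]), re-weighted map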


@HEADS.register_module()
class EncHead(BaseDecodeHead):
    """Context Encoding for Semantic Segmentation.

    This head is the implementation of `EncNet
    <https://arxiv.org/abs/1803.08904>`_.

    Args:
        num_codes (int): Number of code words. Default: 32.
        use_se_loss (bool): Whether to use Semantic Encoding Loss (SE-loss)
            to regularize the training. Default: True.
        add_lateral (bool): Whether to use lateral connections to fuse
            features. Default: False.
        loss_se_decode (dict): Config of the SE-loss.
            Default: dict(type='CrossEntropyLoss', use_sigmoid=True,
            loss_weight=0.2).
    """

    def __init__(self,
                 num_codes=32,
                 use_se_loss=True,
                 add_lateral=False,
                 loss_se_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=0.2),
                 **kwargs):
        super(EncHead, self).__init__(
            input_transform='multiple_select', **kwargs)
        self.use_se_loss = use_se_loss
        self.add_lateral = add_lateral
        self.num_codes = num_codes
        self.bottleneck = ConvModule(
            self.in_channels[-1],
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if add_lateral:
            self.lateral_convs = nn.ModuleList()
            for in_channels in self.in_channels[:-1]:  # skip the last one
                self.lateral_convs.append(
                    ConvModule(
                        in_channels,
                        self.channels,
                        1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
            self.fusion = ConvModule(
                len(self.in_channels) * self.channels,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        self.enc_module = EncModule(
            self.channels,
            num_codes=num_codes,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if self.use_se_loss:
            self.loss_se_decode = build_loss(loss_se_decode)
            self.se_layer = nn.Linear(self.channels, self.num_classes)

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)
        feat = self.bottleneck(inputs[-1])
        if self.add_lateral:
            laterals = [
                resize(
                    lateral_conv(inputs[i]),
                    size=feat.shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners)
                for i, lateral_conv in enumerate(self.lateral_convs)
            ]
            feat = self.fusion(torch.cat([feat, *laterals], 1))
        encode_feat, output = self.enc_module(feat)
        output = self.cls_seg(output)
        if self.use_se_loss:
            se_output = self.se_layer(encode_feat)
            return output, se_output
        else:
            return output

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing, ignore se_loss."""
        if self.use_se_loss:
            return self.forward(inputs)[0]
        else:
            return self.forward(inputs)

    @staticmethod
    def _convert_to_onehot_labels(seg_label, num_classes):
        """Convert segmentation label to onehot.

        Args:
            seg_label (Tensor): Segmentation label of shape (N, H, W).
            num_classes (int): Number of classes.

        Returns:
            Tensor: Onehot labels of shape (N, num_classes).
        """
        batch_size = seg_label.size(0)
        onehot_labels = seg_label.new_zeros((batch_size, num_classes))
        for i in range(batch_size):
            hist = seg_label[i].float().histc(
                bins=num_classes, min=0, max=num_classes - 1)
            onehot_labels[i] = hist > 0
        return onehot_labels

    def losses(self, seg_logit, seg_label):
        """Compute segmentation and semantic encoding loss."""
        seg_logit, se_seg_logit = seg_logit
        loss = dict()
        loss.update(super(EncHead, self).losses(seg_logit, seg_label))
        se_loss = self.loss_se_decode(
            se_seg_logit,
            self._convert_to_onehot_labels(seg_label, self.num_classes))
        loss['loss_se'] = se_loss
        return loss
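
A rough end-to-end sketch of EncHead; the config values below are
illustrative assumptions (loosely modeled on typical EncNet settings), not
a verbatim shipped config:

    import torch

    head = EncHead(
        in_channels=[512, 1024, 2048],  # assumed backbone stage channels
        in_index=(1, 2, 3),             # select the last three stages
        channels=512,
        num_classes=19,
        num_codes=32,
        use_se_loss=True,
        norm_cfg=dict(type='BN'))
    # dummy pyramid features from a hypothetical 4-stage backbone
    feats = [
        torch.randn(2, c, 64 // 2**i, 64 // 2**i)
        for i, c in enumerate([256, 512, 1024, 2048])
    ]
    seg_logit, se_logit = head(feats)
    print(seg_logit.shape)  # torch.Size([2, 19, 8, 8]) pixel-wise logits
    print(se_logit.shape)   # torch.Size([2, 19]) image-level class logits

    # losses() expects the (seg, se) logit pair plus a dense label map; the
    # SE term supervises image-level class presence, built internally by
    # _convert_to_onehot_labels.
    seg_label = torch.randint(0, 19, (2, 1, 8, 8))
    loss = head.losses((seg_logit, se_logit), seg_label)
    print('loss_se' in loss)  # True, alongside the usual decode loss terms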