Source code for mmseg.models.decode_heads.segmenter_mask_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
                                        trunc_normal_init)
from mmcv.runner import ModuleList

from mmseg.models.backbones.vit import TransformerEncoderLayer
from ..builder import HEADS
from .decode_head import BaseDecodeHead


@HEADS.register_module()
class SegmenterMaskTransformerHead(BaseDecodeHead):
    """Segmenter: Transformer for Semantic Segmentation.

    This head is the implementation of `Segmenter: Transformer for Semantic
    Segmentation <https://arxiv.org/abs/2105.05633>`_.

    Args:
        in_channels (int): The number of channels of the input feature map.
        num_layers (int): The depth of the transformer.
        num_heads (int): The number of attention heads.
        embed_dims (int): The embedding dimension.
        mlp_ratio (int): Ratio of MLP hidden dim to embedding dim.
            Default: 4.
        drop_path_rate (float): Stochastic depth rate. Default: 0.1.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.0.
        attn_drop_rate (float): The dropout rate for the attention layer.
            Default: 0.0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_std (float): The value of std in weight initialization.
            Default: 0.02.
    """

    def __init__(
            self,
            in_channels,
            num_layers,
            num_heads,
            embed_dims,
            mlp_ratio=4,
            drop_path_rate=0.1,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            num_fcs=2,
            qkv_bias=True,
            act_cfg=dict(type='GELU'),
            norm_cfg=dict(type='LN'),
            init_std=0.02,
            **kwargs,
    ):
        super(SegmenterMaskTransformerHead, self).__init__(
            in_channels=in_channels, **kwargs)

        # Stochastic depth: drop-path rate increases linearly with depth.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
        self.layers = ModuleList()
        for i in range(num_layers):
            self.layers.append(
                TransformerEncoderLayer(
                    embed_dims=embed_dims,
                    num_heads=num_heads,
                    feedforward_channels=mlp_ratio * embed_dims,
                    attn_drop_rate=attn_drop_rate,
                    drop_rate=drop_rate,
                    drop_path_rate=dpr[i],
                    num_fcs=num_fcs,
                    qkv_bias=qkv_bias,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    batch_first=True,
                ))

        self.dec_proj = nn.Linear(in_channels, embed_dims)

        # Learnable class embeddings, one token per class.
        self.cls_emb = nn.Parameter(
            torch.randn(1, self.num_classes, embed_dims))
        self.patch_proj = nn.Linear(embed_dims, embed_dims, bias=False)
        self.classes_proj = nn.Linear(embed_dims, embed_dims, bias=False)

        self.decoder_norm = build_norm_layer(
            norm_cfg, embed_dims, postfix=1)[1]
        self.mask_norm = build_norm_layer(
            norm_cfg, self.num_classes, postfix=2)[1]

        self.init_std = init_std

        # The logits come from patch/class-token similarities, so the default
        # conv_seg classifier of BaseDecodeHead is not needed.
        delattr(self, 'conv_seg')
    def init_weights(self):
        trunc_normal_(self.cls_emb, std=self.init_std)
        trunc_normal_init(self.patch_proj, std=self.init_std)
        trunc_normal_init(self.classes_proj, std=self.init_std)
        for n, m in self.named_modules():
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=self.init_std, bias=0)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, val=1.0, bias=0.0)
    def forward(self, inputs):
        x = self._transform_inputs(inputs)
        b, c, h, w = x.shape
        # Flatten the (B, C, H, W) feature map into a token sequence (B, HW, C).
        x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c)

        x = self.dec_proj(x)
        # Append the learnable class tokens and decode them jointly with the
        # patch tokens.
        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
        x = torch.cat((x, cls_emb), 1)
        for layer in self.layers:
            x = layer(x)
        x = self.decoder_norm(x)

        patches = self.patch_proj(x[:, :-self.num_classes])
        cls_seg_feat = self.classes_proj(x[:, -self.num_classes:])

        # L2-normalize both embeddings so the masks are cosine similarities
        # between patch and class embeddings.
        patches = F.normalize(patches, dim=2, p=2)
        cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2)

        masks = patches @ cls_seg_feat.transpose(1, 2)
        masks = self.mask_norm(masks)
        # Reshape (B, HW, num_classes) back to (B, num_classes, H, W).
        masks = masks.permute(0, 2, 1).contiguous().view(b, -1, h, w)

        return masks
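
In a full mmsegmentation config, this head is plugged in as the decode_head of an EncoderDecoder on top of a ViT backbone. The snippet below is an illustrative sketch only: the hyperparameter values are examples in the spirit of the ViT-T Segmenter configs, not copied from a shipped config file, and channels, num_classes and loss_decode are consumed by BaseDecodeHead rather than by this class directly.

# Illustrative decode_head config (example values; embed_dims and
# in_channels must match the ViT backbone actually used).
decode_head = dict(
    type='SegmenterMaskTransformerHead',
    in_channels=192,
    channels=192,          # required by BaseDecodeHead
    num_classes=150,       # e.g. ADE20K
    num_layers=2,
    num_heads=3,
    embed_dims=192,
    dropout_ratio=0.0,
    loss_decode=dict(
        type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))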
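
For a quick shape check, the head can also be instantiated directly and run on a dummy feature map. This is a minimal sketch under assumed, hypothetical values (192-dim features, 150 classes, a 32x32 token grid) and the mmseg 0.x module layout shown above; it only verifies that forward() returns per-class mask logits at the feature resolution.

import torch

from mmseg.models.decode_heads.segmenter_mask_head import \
    SegmenterMaskTransformerHead

head = SegmenterMaskTransformerHead(
    in_channels=192,
    channels=192,          # required by BaseDecodeHead
    num_classes=150,
    num_layers=2,
    num_heads=3,
    embed_dims=192)
head.init_weights()

# One backbone feature map of shape (batch, in_channels, H/patch, W/patch);
# BaseDecodeHead._transform_inputs picks inputs[in_index] (default -1).
feat = torch.randn(2, 192, 32, 32)
masks = head([feat])
assert masks.shape == (2, 150, 32, 32)  # (batch, num_classes, H/patch, W/patch)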