Shortcuts

Source code for mmseg.models.decode_heads.cc_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch

from ..builder import HEADS
from .fcn_head import FCNHead

try:
    from mmcv.ops import CrissCrossAttention
except ImportError:
    # Catch the base ``ImportError`` rather than only ``ModuleNotFoundError``
    # (its subclass, raised when mmcv / mmcv.ops is absent): a present but
    # unloadable compiled extension raises a plain ``ImportError``. Either
    # way, fall back to ``None`` so this module still imports and CCHead
    # reports the missing op when it is actually constructed.
    CrissCrossAttention = None


@HEADS.register_module()
class CCHead(FCNHead):
    """CCNet: Criss-Cross Attention for Semantic Segmentation.

    This head is the implementation of `CCNet
    <https://arxiv.org/abs/1811.11721>`_.

    Structurally it is an :class:`FCNHead` with exactly two convs. A single
    Criss-Cross Attention (CCA) module sits between them and is applied
    ``recurrence`` times, reusing the same module (same weights) on every
    pass.

    Args:
        recurrence (int): Number of times the shared Criss Cross Attention
            module is applied. ``0`` skips attention entirely. Default: 2.
        **kwargs: Forwarded to :class:`FCNHead`. This head always uses two
            convs; ``num_convs`` may still be passed (e.g. when the config
            is merged onto an FCN head config) but must equal 2.

    Raises:
        RuntimeError: If the ``CrissCrossAttention`` op from mmcv-full is
            not available.
        ValueError: If ``num_convs`` is passed with a value other than 2.
    """

    def __init__(self, recurrence=2, **kwargs):
        if CrissCrossAttention is None:
            raise RuntimeError('Please install mmcv-full for '
                               'CrissCrossAttention ops')
        # forward() indexes convs[0] and convs[1] explicitly, so exactly two
        # convs are required. Accept an explicit ``num_convs`` instead of
        # letting it collide with the fixed value below, which would fail
        # with an opaque "multiple values for keyword argument" TypeError.
        num_convs = kwargs.pop('num_convs', 2)
        if num_convs != 2:
            raise ValueError(
                f'CCHead requires num_convs=2, but got num_convs={num_convs}')
        super().__init__(num_convs=2, **kwargs)
        self.recurrence = recurrence
        # A single module reused on every recurrence step (shared weights).
        self.cca = CrissCrossAttention(self.channels)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Backbone features, reduced to this head's input by
                ``self._transform_inputs`` (presumably a list of
                multi-level feature maps -- confirm against BaseDecodeHead).

        Returns:
            torch.Tensor: Result of ``self.cls_seg`` applied to the
            attended (and optionally input-concatenated) features.
        """
        x = self._transform_inputs(inputs)
        output = self.convs[0](x)
        for _ in range(self.recurrence):
            output = self.cca(output)
        output = self.convs[1](output)
        if self.concat_input:
            # Fuse the transformed input with the attended features.
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.