Shortcuts

Source code for mmseg.models.decode_heads.fpn_head

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.ops import Upsample, resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead


@HEADS.register_module()
class FPNHead(BaseDecodeHead):
    """Panoptic Feature Pyramid Networks.

    This head is the implementation of `Semantic FPN
    <https://arxiv.org/abs/1901.02446>`_.

    Each selected input level ``i`` gets its own "scale head": a stack of
    ``head_length = max(1, int(log2(feature_strides[i]) -
    log2(feature_strides[0])))`` 3x3 ``ConvModule`` layers. For every level
    whose stride differs from the first one, each conv is followed by a 2x
    bilinear ``Upsample``. In ``forward`` the per-level outputs are resized
    to the spatial size of the first level, summed, and passed through
    ``cls_seg``.

    Args:
        feature_strides (tuple[int]): The strides of the input feature maps,
            one per entry of ``in_channels``. All strides are expected to be
            powers of 2, and the first one must be the smallest (i.e. the
            first feature map has the largest resolution).
        **kwargs: Forwarded to ``BaseDecodeHead``. ``input_transform`` is
            always set to ``'multiple_select'`` by this head, so it must not
            be passed here (doing so raises ``TypeError``).
    """

    def __init__(self, feature_strides, **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        assert len(feature_strides) == len(self.in_channels), (
            'len(feature_strides)={} must equal len(in_channels)={}'.format(
                len(feature_strides), len(self.in_channels)))
        assert min(feature_strides) == feature_strides[0], (
            'the first feature stride must be the smallest, got {}'.format(
                feature_strides))
        self.feature_strides = feature_strides

        base_stride = feature_strides[0]
        base_log2 = np.log2(base_stride)
        self.scale_heads = nn.ModuleList()
        # zip is safe here: the assert above guarantees equal lengths.
        for stride, level_in_channels in zip(feature_strides,
                                             self.in_channels):
            # One conv (+ 2x upsample) per octave between this level and the
            # first level; the first level still gets a single conv.
            head_length = max(1, int(np.log2(stride) - base_log2))
            needs_upsample = stride != base_stride
            scale_head = []
            for k in range(head_length):
                scale_head.append(
                    ConvModule(
                        # Only the first conv sees the backbone's channels;
                        # later convs operate on ``self.channels``.
                        level_in_channels if k == 0 else self.channels,
                        self.channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
                if needs_upsample:
                    scale_head.append(
                        Upsample(
                            scale_factor=2,
                            mode='bilinear',
                            align_corners=self.align_corners))
            self.scale_heads.append(nn.Sequential(*scale_head))

    def forward(self, inputs):
        """Fuse the multi-level features and predict segmentation logits.

        Args:
            inputs: Multi-level input features, as accepted by
                ``self._transform_inputs`` (presumably a list of NCHW
                tensors -- TODO confirm against ``BaseDecodeHead``).

        Returns:
            The output of ``self.cls_seg`` applied to the fused features.
        """
        x = self._transform_inputs(inputs)
        output = self.scale_heads[0](x[0])
        # Adding the resized levels never changes ``output``'s spatial size,
        # so the resize target can be computed once.
        target_size = output.shape[2:]
        for i in range(1, len(self.feature_strides)):
            # Non-inplace addition (kept deliberately; presumably to stay
            # safe for autograd -- not verifiable from this file). The
            # resize guards against levels whose upsampling stages do not
            # land exactly on the first level's size.
            output = output + resize(
                self.scale_heads[i](x[i]),
                size=target_size,
                mode='bilinear',
                align_corners=self.align_corners)
        output = self.cls_seg(output)
        return output
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.