Source code for bob.ip.binseg.models.m2unet

#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

from collections import OrderedDict

import torch
import torch.nn

from torchvision.models.mobilenetv2 import InvertedResidual

from .backbones.mobilenetv2 import mobilenet_v2_for_segmentation


class DecoderBlock(torch.nn.Module):
    """
    Decoder block: upsample and concatenate with feature maps from the
    encoder part
    """

    def __init__(
        self, up_in_c, x_in_c, upsamplemode="bilinear", expand_ratio=0.15
    ):
        super().__init__()
        self.upsample = torch.nn.Upsample(
            scale_factor=2, mode=upsamplemode, align_corners=False
        )  # H, W -> 2H, 2W
        self.ir1 = InvertedResidual(
            up_in_c + x_in_c,
            (x_in_c + up_in_c) // 2,
            stride=1,
            expand_ratio=expand_ratio,
        )

    def forward(self, up_in, x_in):
        up_out = self.upsample(up_in)
        cat_x = torch.cat([up_out, x_in], dim=1)
        x = self.ir1(cat_x)
        return x
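
For intuition, here is a minimal sketch of the block's shape contract (the
shapes below are hypothetical, chosen only for illustration): the deeper
feature map is doubled in spatial resolution, concatenated with the encoder
skip, and the inverted-residual block reduces the result to the average of
the two channel counts::

    import torch

    block = DecoderBlock(up_in_c=96, x_in_c=32)
    up_in = torch.rand(1, 96, 8, 8)   # deeper features, half resolution
    x_in = torch.rand(1, 32, 16, 16)  # skip connection from the encoder
    out = block(up_in, x_in)
    print(out.shape)  # torch.Size([1, 64, 16, 16]): (96 + 32) // 2 channels
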
class LastDecoderBlock(torch.nn.Module):
    def __init__(self, x_in_c, upsamplemode="bilinear", expand_ratio=0.15):
        super().__init__()
        self.upsample = torch.nn.Upsample(
            scale_factor=2, mode=upsamplemode, align_corners=False
        )  # H, W -> 2H, 2W
        self.ir1 = InvertedResidual(
            x_in_c, 1, stride=1, expand_ratio=expand_ratio
        )

    def forward(self, up_in, x_in):
        up_out = self.upsample(up_in)
        cat_x = torch.cat([up_out, x_in], dim=1)
        x = self.ir1(cat_x)
        return x
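
The last block follows the same pattern but collapses everything to a single
output channel; note that ``x_in_c`` here counts the channels *after*
concatenation. A short sketch with hypothetical shapes (30 decoder channels
plus the 3-channel input image gives the 33 channels used by ``M2UNet``
below)::

    import torch

    last = LastDecoderBlock(x_in_c=33)
    up_in = torch.rand(1, 30, 32, 32)  # output of the previous decoder block
    x_in = torch.rand(1, 3, 64, 64)    # full-resolution input image
    out = last(up_in, x_in)
    print(out.shape)  # torch.Size([1, 1, 64, 64]): single-channel logits
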
class M2UNet(torch.nn.Module):
    """
    M2U-Net head module

    Parameters
    ----------

    in_channels_list : list
        number of channels for each feature map that is returned from backbone
    """

    def __init__(
        self, in_channels_list=None, upsamplemode="bilinear", expand_ratio=0.15
    ):
        super(M2UNet, self).__init__()

        # Decoder
        self.decode4 = DecoderBlock(96, 32, upsamplemode, expand_ratio)
        self.decode3 = DecoderBlock(64, 24, upsamplemode, expand_ratio)
        self.decode2 = DecoderBlock(44, 16, upsamplemode, expand_ratio)
        self.decode1 = LastDecoderBlock(33, upsamplemode, expand_ratio)

        # initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_uniform_(m.weight, a=1)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        """
        Parameters
        ----------

        x : list
            list of tensors as returned from the backbone network.
            First element: height and width of input image.
            Remaining elements: feature maps for each feature level.

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            decoded single-channel feature map, at the input resolution
        """
        decode4 = self.decode4(x[5], x[4])  # 96, 32
        decode3 = self.decode3(decode4, x[3])  # 64, 24
        decode2 = self.decode2(decode3, x[2])  # 44, 16
        decode1 = self.decode1(decode2, x[1])  # 30, 3
        return decode1
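
A minimal sketch of driving the head directly, assuming a feature list shaped
like the one the MobileNetV2 backbone produces for a 3x64x64 input (all
tensors below are hypothetical stand-ins; channel counts follow the comments
in ``forward``)::

    import torch

    x = [
        (64, 64),                   # x[0]: input height and width (unused)
        torch.rand(1, 3, 64, 64),   # x[1]: input image
        torch.rand(1, 16, 32, 32),  # x[2]: features at 1/2 resolution
        torch.rand(1, 24, 16, 16),  # x[3]: features at 1/4 resolution
        torch.rand(1, 32, 8, 8),    # x[4]: features at 1/8 resolution
        torch.rand(1, 96, 4, 4),    # x[5]: features at 1/16 resolution
    ]
    head = M2UNet()
    out = head(x)
    print(out.shape)  # torch.Size([1, 1, 64, 64])
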
def m2unet(pretrained_backbone=True, progress=True):
    """Builds M2U-Net for segmentation by adding backbone and head together

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        If set to ``True``, then loads a pre-trained version of the backbone
        (not the head) for the M2U-Net network using MobileNetV2 trained for
        ImageNet classification.

    progress : :py:class:`bool`, Optional
        If set to ``True``, and you decided to use a ``pretrained_backbone``,
        then, shows a progress bar of the backbone model downloading if
        download is necessary.


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for M2U-Net (segmentation)

    """
    backbone = mobilenet_v2_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        return_features=[1, 3, 6, 13],
    )
    head = M2UNet(in_channels_list=[16, 24, 32, 96])

    order = [("backbone", backbone), ("head", head)]
    if pretrained_backbone:
        from .normalizer import TorchVisionNormalizer

        order = [("normalizer", TorchVisionNormalizer())] + order

    model = torch.nn.Sequential(OrderedDict(order))
    model.name = "m2unet"
    return model
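
A usage sketch for the assembled model (``pretrained_backbone=False`` avoids
any download; with a pre-trained backbone, the normalizer is prepended
automatically, as in the code above)::

    import torch

    from bob.ip.binseg.models.m2unet import m2unet

    model = m2unet(pretrained_backbone=False)
    model.eval()

    with torch.no_grad():
        batch = torch.rand(1, 3, 544, 544)  # any size divisible by 16
        logits = model(batch)               # shape: (1, 1, 544, 544)
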