Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# vim: set fileencoding=utf-8 :
4from collections import OrderedDict
6import torch
7import torch.nn
9from torchvision.models.mobilenetv2 import InvertedResidual
11from .backbones.mobilenetv2 import mobilenet_v2_for_segmentation
class DecoderBlock(torch.nn.Module):
    """
    Decoder block: upsample and concatenate with features maps from the
    encoder part

    The incoming tensor is upsampled by a factor of two, concatenated with the
    encoder feature map along ``dim=1`` and passed through one
    ``InvertedResidual`` block that is built with
    ``(up_in_c + x_in_c) // 2`` as its second (output-channel) argument.

    Parameters
    ----------
    up_in_c : int
        number of channels of the tensor that gets upsampled
    x_in_c : int
        number of channels of the encoder feature map that is concatenated
    upsamplemode : str
        ``mode`` forwarded to :py:class:`torch.nn.Upsample`
    expand_ratio : float
        ``expand_ratio`` forwarded to ``InvertedResidual``
    """

    # Upsampling modes for which torch accepts an explicit ``align_corners``.
    _INTERPOLATING_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(
        self, up_in_c, x_in_c, upsamplemode="bilinear", expand_ratio=0.15
    ):
        super().__init__()
        # torch rejects any non-None ``align_corners`` for the
        # non-interpolating modes ("nearest", "area", ...) with a ValueError
        # at forward time, so only pass ``False`` where it is allowed.
        if upsamplemode in self._INTERPOLATING_MODES:
            align_corners = False
        else:
            align_corners = None
        self.upsample = torch.nn.Upsample(
            scale_factor=2, mode=upsamplemode, align_corners=align_corners
        )  # H, W -> 2H, 2W
        self.ir1 = InvertedResidual(
            up_in_c + x_in_c,
            (x_in_c + up_in_c) // 2,
            stride=1,
            expand_ratio=expand_ratio,
        )

    def forward(self, up_in, x_in):
        """
        Parameters
        ----------
        up_in : :py:class:`torch.Tensor`
            tensor to upsample (``up_in_c`` channels)
        x_in : :py:class:`torch.Tensor`
            encoder feature map (``x_in_c`` channels); its spatial size must
            match the upsampled ``up_in`` for the concatenation to succeed

        Returns
        -------
        tensor : :py:class:`torch.Tensor`
            output of the inverted residual block
        """
        up_out = self.upsample(up_in)
        # concatenate along dim=1 (assumed to be the channel axis of an
        # N, C, H, W tensor -- TODO confirm against the backbone)
        cat_x = torch.cat([up_out, x_in], dim=1)
        return self.ir1(cat_x)
class LastDecoderBlock(torch.nn.Module):
    """
    Final decoder block: upsample, concatenate and reduce to a single channel

    The incoming tensor is upsampled by a factor of two, concatenated with
    ``x_in`` along ``dim=1`` and passed through one ``InvertedResidual`` block
    that is built with ``1`` as its second (output-channel) argument.

    Parameters
    ----------
    x_in_c : int
        number of channels of the *concatenated* tensor, i.e. channels of the
        upsampled tensor plus channels of ``x_in`` (unlike
        ``DecoderBlock``, the sum is not computed here)
    upsamplemode : str
        ``mode`` forwarded to :py:class:`torch.nn.Upsample`
    expand_ratio : float
        ``expand_ratio`` forwarded to ``InvertedResidual``
    """

    # Upsampling modes for which torch accepts an explicit ``align_corners``.
    _INTERPOLATING_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, x_in_c, upsamplemode="bilinear", expand_ratio=0.15):
        super().__init__()
        # torch rejects any non-None ``align_corners`` for the
        # non-interpolating modes ("nearest", "area", ...) with a ValueError
        # at forward time, so only pass ``False`` where it is allowed.
        if upsamplemode in self._INTERPOLATING_MODES:
            align_corners = False
        else:
            align_corners = None
        self.upsample = torch.nn.Upsample(
            scale_factor=2, mode=upsamplemode, align_corners=align_corners
        )  # H, W -> 2H, 2W
        self.ir1 = InvertedResidual(
            x_in_c, 1, stride=1, expand_ratio=expand_ratio
        )

    def forward(self, up_in, x_in):
        """
        Parameters
        ----------
        up_in : :py:class:`torch.Tensor`
            tensor to upsample
        x_in : :py:class:`torch.Tensor`
            tensor concatenated with the upsampled ``up_in``; its spatial
            size must match the upsampled tensor

        Returns
        -------
        tensor : :py:class:`torch.Tensor`
            output of the inverted residual block
        """
        up_out = self.upsample(up_in)
        # concatenate along dim=1 (assumed to be the channel axis of an
        # N, C, H, W tensor -- TODO confirm against the backbone)
        cat_x = torch.cat([up_out, x_in], dim=1)
        return self.ir1(cat_x)
class M2UNet(torch.nn.Module):
    """
    M2U-Net head module

    Parameters
    ----------
    in_channels_list : list
        number of channels for each feature map that is returned from
        backbone, ordered from the first to the last returned feature level
        (four entries).  ``None`` selects ``[16, 24, 32, 96]``.
    upsamplemode : str
        upsampling mode forwarded to every decoder block
    expand_ratio : float
        expansion ratio forwarded to every decoder block

    Raises
    ------
    ValueError
        if ``in_channels_list`` does not contain exactly four entries
    """

    # Channel counts used when ``in_channels_list`` is not given.
    _DEFAULT_IN_CHANNELS = (16, 24, 32, 96)

    # Channels of ``x[1]``, which is concatenated in the last decoder block
    # (presumably the 3-channel input image -- verify against the backbone).
    _LAST_SKIP_CHANNELS = 3

    def __init__(
        self, in_channels_list=None, upsamplemode="bilinear", expand_ratio=0.15
    ):
        super().__init__()

        if in_channels_list is None:
            in_channels_list = self._DEFAULT_IN_CHANNELS
        if len(in_channels_list) != 4:
            raise ValueError(
                "in_channels_list must contain exactly 4 entries, got %d"
                % len(in_channels_list)
            )
        c1, c2, c3, c4 = in_channels_list

        # Every DecoderBlock outputs (up_in_c + x_in_c) // 2 channels, which
        # becomes the ``up_in_c`` of the next block.  With the default list
        # this yields 96/32 -> 64, 64/24 -> 44, 44/16 -> 30 and 30 + 3 = 33.
        out4 = (c4 + c3) // 2
        out3 = (out4 + c2) // 2
        out2 = (out3 + c1) // 2

        # Decoder
        self.decode4 = DecoderBlock(c4, c3, upsamplemode, expand_ratio)
        self.decode3 = DecoderBlock(out4, c2, upsamplemode, expand_ratio)
        self.decode2 = DecoderBlock(out3, c1, upsamplemode, expand_ratio)
        self.decode1 = LastDecoderBlock(
            out2 + self._LAST_SKIP_CHANNELS, upsamplemode, expand_ratio
        )

        # initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initializes all ``Conv2d`` and ``BatchNorm2d`` sub-modules

        Convolution weights use Kaiming-uniform initialization with ``a=1``
        and zero biases; batch-norm layers get weight 1 and bias 0.
        """
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_uniform_(m.weight, a=1)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        """
        Parameters
        ----------
        x : list
            list of tensors as returned from the backbone network.
            First element: height and width of input image.
            Remaining elements: feature maps for each feature level.
            Elements ``x[1]`` to ``x[5]`` are consumed here.

        Returns
        -------
        tensor : :py:class:`torch.Tensor`
            output of the last decoder block
        """
        # channel counts in the comments refer to the default configuration
        decode4 = self.decode4(x[5], x[4])  # 96, 32
        decode3 = self.decode3(decode4, x[3])  # 64, 24
        decode2 = self.decode2(decode3, x[2])  # 44, 16
        decode1 = self.decode1(decode2, x[1])  # 30, 3

        return decode1
def m2unet(pretrained_backbone=True, progress=True):
    """Builds M2U-Net for segmentation by adding backbone and head together

    The returned model is a :py:class:`torch.nn.Sequential` with the named
    stages ``backbone`` and ``head``.  When ``pretrained_backbone`` is set, a
    ``normalizer`` stage (``TorchVisionNormalizer``) is placed in front of
    them.


    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        If set to ``True``, then loads a pre-trained version of the backbone
        (not the head) and prepends the input normalizer to the model.

    progress : :py:class:`bool`, Optional
        If set to ``True``, and you decided to use a ``pretrained_backbone``,
        then, shows a progress bar of the backbone model downloading if
        download is necessary.


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for M2U-Net (segmentation)

    """

    # Build the two mandatory stages first (same construction order as
    # before, so any random initialization happens in the same sequence).
    feature_extractor = mobilenet_v2_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        return_features=[1, 3, 6, 13],
    )
    decoder_head = M2UNet(in_channels_list=[16, 24, 32, 96])

    stages = OrderedDict()
    if pretrained_backbone:
        # imported lazily: only needed together with pre-trained weights
        from .normalizer import TorchVisionNormalizer

        stages["normalizer"] = TorchVisionNormalizer()
    stages["backbone"] = feature_extractor
    stages["head"] = decoder_head

    model = torch.nn.Sequential(stages)
    model.name = "m2unet"
    return model