Coverage for src/deepdraw/models/m2unet.py: 76%
58 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5from collections import OrderedDict
7import torch
8import torch.nn
10from torchvision.models.mobilenetv2 import InvertedResidual
12from .backbones.mobilenetv2 import mobilenet_v2_for_segmentation
class DecoderBlock(torch.nn.Module):
    """Decoder stage: upsample, merge with an encoder feature map, refine.

    The coarser input is upsampled by a factor of two, concatenated along
    the channel axis with the skip feature map coming from the encoder, and
    the result is passed through one inverted-residual block that halves
    the number of channels.

    Parameters
    ----------
    up_in_c : int
        Number of channels of the tensor that gets upsampled.
    x_in_c : int
        Number of channels of the encoder (skip) feature map.
    upsamplemode : str
        Interpolation mode handed to :py:class:`torch.nn.Upsample`.
    expand_ratio : float
        Expansion ratio handed to the inverted-residual block.
    """

    def __init__(
        self, up_in_c, x_in_c, upsamplemode="bilinear", expand_ratio=0.15
    ):
        super().__init__()
        # channel count after concatenating both inputs
        merged_c = up_in_c + x_in_c

        # H, W -> 2H, 2W
        self.upsample = torch.nn.Upsample(
            scale_factor=2, mode=upsamplemode, align_corners=False
        )
        self.ir1 = InvertedResidual(
            merged_c, merged_c // 2, stride=1, expand_ratio=expand_ratio
        )

    def forward(self, up_in, x_in):
        """Upsample ``up_in``, concatenate it with ``x_in`` and refine.

        Parameters
        ----------
        up_in : :py:class:`torch.Tensor`
            Coarser tensor; its spatial size is doubled before merging.
        x_in : :py:class:`torch.Tensor`
            Skip tensor; must match the upsampled tensor in every
            dimension except the channel one (``dim=1``).

        Returns
        -------
        tensor : :py:class:`torch.Tensor`
            Output of the inverted-residual block.
        """
        merged = torch.cat((self.upsample(up_in), x_in), dim=1)
        return self.ir1(merged)
class LastDecoderBlock(torch.nn.Module):
    """Final decoder stage, reducing the merged features to one channel.

    Works like the other decoder stages (upsample by two, concatenate with
    a skip tensor, apply one inverted-residual block), except that the
    inverted-residual block always outputs a single channel.

    Parameters
    ----------
    x_in_c : int
        Number of channels of the *concatenated* tensor, i.e. channels of
        the upsampled input plus channels of the skip tensor.
    upsamplemode : str
        Interpolation mode handed to :py:class:`torch.nn.Upsample`.
    expand_ratio : float
        Expansion ratio handed to the inverted-residual block.
    """

    def __init__(self, x_in_c, upsamplemode="bilinear", expand_ratio=0.15):
        super().__init__()

        # H, W -> 2H, 2W
        self.upsample = torch.nn.Upsample(
            scale_factor=2, mode=upsamplemode, align_corners=False
        )
        # x_in_c channels in, a single channel out
        self.ir1 = InvertedResidual(
            x_in_c, 1, stride=1, expand_ratio=expand_ratio
        )

    def forward(self, up_in, x_in):
        """Upsample ``up_in``, concatenate it with ``x_in`` and refine.

        Parameters
        ----------
        up_in : :py:class:`torch.Tensor`
            Coarser tensor; its spatial size is doubled before merging.
        x_in : :py:class:`torch.Tensor`
            Skip tensor; must match the upsampled tensor in every
            dimension except the channel one (``dim=1``).

        Returns
        -------
        tensor : :py:class:`torch.Tensor`
            Output of the inverted-residual block.
        """
        merged = torch.cat((self.upsample(up_in), x_in), dim=1)
        return self.ir1(merged)
class M2UNet(torch.nn.Module):
    """M2U-Net head module.

    Parameters
    ----------
    in_channels_list : list, Optional
        Number of channels of the four feature maps returned by the
        backbone, ordered from the shallowest to the deepest level.  When
        ``None``, ``[16, 24, 32, 96]`` is used.
    upsamplemode : str
        Interpolation mode forwarded to every decoder block.
    expand_ratio : float
        Expansion ratio forwarded to every decoder block.

    Raises
    ------
    ValueError
        If ``in_channels_list`` does not contain exactly four entries.
    """

    def __init__(
        self, in_channels_list=None, upsamplemode="bilinear", expand_ratio=0.15
    ):
        super().__init__()

        if in_channels_list is None:
            in_channels_list = [16, 24, 32, 96]
        channels = list(in_channels_list)
        if len(channels) != 4:
            raise ValueError(
                "in_channels_list must hold exactly 4 channel counts, "
                f"got {len(channels)}"
            )
        c1, c2, c3, c4 = channels

        # A DecoderBlock outputs half of its concatenated input channels,
        # so the width fed to each following stage is derived here.  With
        # the default list this gives 64, 44 and 30 channels.
        d4 = (c4 + c3) // 2
        d3 = (d4 + c2) // 2
        d2 = (d3 + c1) // 2

        # The last stage concatenates with ``x[1]`` in ``forward()``, which
        # is expected to carry 3 channels.
        image_c = 3

        # Decoder
        self.decode4 = DecoderBlock(c4, c3, upsamplemode, expand_ratio)
        self.decode3 = DecoderBlock(d4, c2, upsamplemode, expand_ratio)
        self.decode2 = DecoderBlock(d3, c1, upsamplemode, expand_ratio)
        self.decode1 = LastDecoderBlock(
            d2 + image_c, upsamplemode, expand_ratio
        )

        # initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Reset all convolution and batch-normalization parameters.

        Convolution weights are drawn with Kaiming-uniform initialization
        (``a=1``) and their biases, if any, set to zero.  Batch-norm
        weights are set to one and biases to zero.
        """
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_uniform_(m.weight, a=1)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        """Decode the backbone outputs into a single-channel map.

        Parameters
        ----------
        x : list
            List as returned from the backbone network.
            ``x[0]``: height and width of the input image (not used here).
            ``x[1]``: tensor merged in the last decoder stage; expected to
            carry 3 channels (presumably the input image itself -- confirm
            against the backbone).
            ``x[2]`` to ``x[5]``: feature maps, shallowest to deepest.

        Returns
        -------
        tensor : :py:class:`torch.Tensor`
            Output of the last decoder stage.
        """
        # Channel counts in the trailing comments assume the default
        # ``in_channels_list``: (upsampled input, skip input).
        decode4 = self.decode4(x[5], x[4])  # 96, 32
        decode3 = self.decode3(decode4, x[3])  # 64, 24
        decode2 = self.decode2(decode3, x[2])  # 44, 16
        decode1 = self.decode1(decode2, x[1])  # 30, 3

        return decode1
def m2unet(pretrained_backbone=True, progress=True):
    """Builds M2U-Net for segmentation by chaining backbone and head.

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        Forwarded as ``pretrained`` to the MobileNetV2 backbone factory, so
        that a pre-trained version of the backbone (not the head) is used.
        When ``True``, a ``TorchVisionNormalizer`` is additionally placed in
        front of the backbone.

    progress : :py:class:`bool`, Optional
        Forwarded to the backbone factory.  If set to ``True``, and you
        decided to use a ``pretrained_backbone``, then a progress bar is
        shown while the backbone model downloads, if a download is
        necessary.


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for M2U-Net (segmentation), with its ``name``
        attribute set to ``"m2unet"``.
    """

    backbone = mobilenet_v2_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        return_features=[1, 3, 6, 13],
    )
    head = M2UNet(in_channels_list=[16, 24, 32, 96])

    # Assemble the stages in execution order; the normalizer only makes
    # sense in front of a backbone that was requested as pre-trained.
    stages = OrderedDict()
    if pretrained_backbone:
        from .normalizer import TorchVisionNormalizer

        stages["normalizer"] = TorchVisionNormalizer()
    stages["backbone"] = backbone
    stages["head"] = head

    model = torch.nn.Sequential(stages)
    model.name = "m2unet"
    return model