Coverage for src/deepdraw/models/m2unet.py: 76% (58 statements)

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

from collections import OrderedDict

import torch
import torch.nn

from torchvision.models.mobilenetv2 import InvertedResidual

from .backbones.mobilenetv2 import mobilenet_v2_for_segmentation


class DecoderBlock(torch.nn.Module):
    """Decoder block: upsample and concatenate with feature maps from the
    encoder part."""

    def __init__(
        self, up_in_c, x_in_c, upsamplemode="bilinear", expand_ratio=0.15
    ):
        super().__init__()
        self.upsample = torch.nn.Upsample(
            scale_factor=2, mode=upsamplemode, align_corners=False
        )  # H, W -> 2H, 2W
        self.ir1 = InvertedResidual(
            up_in_c + x_in_c,
            (x_in_c + up_in_c) // 2,
            stride=1,
            expand_ratio=expand_ratio,
        )

    def forward(self, up_in, x_in):
        up_out = self.upsample(up_in)
        cat_x = torch.cat([up_out, x_in], dim=1)
        x = self.ir1(cat_x)
        return x
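
# Example (illustrative, using DecoderBlock(96, 32) as instantiated for
# ``decode4`` below): ``up_in`` of shape (N, 96, H, W) is upsampled to
# (N, 96, 2H, 2W), concatenated with ``x_in`` of shape (N, 32, 2H, 2W) into
# (N, 128, 2H, 2W), and reduced by the inverted residual to
# (N, (96 + 32) // 2, 2H, 2W) = (N, 64, 2H, 2W).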


class LastDecoderBlock(torch.nn.Module):
    def __init__(self, x_in_c, upsamplemode="bilinear", expand_ratio=0.15):
        super().__init__()
        self.upsample = torch.nn.Upsample(
            scale_factor=2, mode=upsamplemode, align_corners=False
        )  # H, W -> 2H, 2W
        self.ir1 = InvertedResidual(
            x_in_c, 1, stride=1, expand_ratio=expand_ratio
        )

    def forward(self, up_in, x_in):
        up_out = self.upsample(up_in)
        cat_x = torch.cat([up_out, x_in], dim=1)
        x = self.ir1(cat_x)
        return x
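
# Note: LastDecoderBlock is the final decoder stage.  As instantiated below
# (LastDecoderBlock(33)), it upsamples the 30-channel output of ``decode2``,
# concatenates it with a 3-channel tensor (cf. the ``# 30, 3`` comment in
# ``M2UNet.forward``) and reduces the 33 channels to a single output channel.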


class M2UNet(torch.nn.Module):
    """M2U-Net head module.

    Parameters
    ----------
    in_channels_list : list
        number of channels for each feature map that is returned from the
        backbone
    """

    def __init__(
        self, in_channels_list=None, upsamplemode="bilinear", expand_ratio=0.15
    ):
        super().__init__()

        # Decoder
        self.decode4 = DecoderBlock(96, 32, upsamplemode, expand_ratio)
        self.decode3 = DecoderBlock(64, 24, upsamplemode, expand_ratio)
        self.decode2 = DecoderBlock(44, 16, upsamplemode, expand_ratio)
        self.decode1 = LastDecoderBlock(33, upsamplemode, expand_ratio)
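        # Channel bookkeeping: each DecoderBlock outputs
        # (up_in_c + x_in_c) // 2 channels, so decode4 yields 64, decode3
        # (64 + 24) // 2 = 44 and decode2 (44 + 16) // 2 = 30; decode1 then
        # receives 30 + 3 = 33 channels (cf. the ``# 30, 3`` comment in
        # ``forward``) and emits the single-channel output map.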

        # initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_uniform_(m.weight, a=1)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        """
        Parameters
        ----------
        x : list
            List of tensors as returned from the backbone network.
            First element: height and width of input image.
            Remaining elements: feature maps for each feature level.

        Returns
        -------
        tensor : :py:class:`torch.Tensor`
        """
        decode4 = self.decode4(x[5], x[4])  # 96, 32
        decode3 = self.decode3(decode4, x[3])  # 64, 24
        decode2 = self.decode2(decode3, x[2])  # 44, 16
        decode1 = self.decode1(decode2, x[1])  # 30, 3

        return decode1


def m2unet(pretrained_backbone=True, progress=True):
    """Builds M2U-Net for segmentation by adding backbone and head together.

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        If set to ``True``, then loads a pre-trained version of the backbone
        (not the head) for the M2U-Net network, using a MobileNetV2 trained
        for ImageNet classification.

    progress : :py:class:`bool`, Optional
        If set to ``True``, and you decided to use a ``pretrained_backbone``,
        then shows a progress bar for the backbone model download, if a
        download is necessary.


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for M2U-Net (segmentation)
    """

    backbone = mobilenet_v2_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        return_features=[1, 3, 6, 13],
    )
    head = M2UNet(in_channels_list=[16, 24, 32, 96])

    order = [("backbone", backbone), ("head", head)]
    if pretrained_backbone:
        from .normalizer import TorchVisionNormalizer

        order = [("normalizer", TorchVisionNormalizer())] + order

    model = torch.nn.Sequential(OrderedDict(order))
    model.name = "m2unet"
    return model
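

# Minimal usage sketch (illustrative only; run it as
# ``python -m deepdraw.models.m2unet``): builds the model without downloading
# pre-trained backbone weights and pushes a random batch through it.  The
# spatial size is assumed to be divisible by 16 so each decoder stage lines up
# with its encoder feature map; the expected output is a single-channel map at
# the input resolution.
if __name__ == "__main__":
    model = m2unet(pretrained_backbone=False)
    model.eval()
    sample = torch.rand(1, 3, 256, 256)  # (batch, channels, height, width)
    with torch.no_grad():
        output = model(sample)
    print(output.shape)  # expected: torch.Size([1, 1, 256, 256])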