Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1#!/usr/bin/env python 

2# vim: set fileencoding=utf-8 : 

3 

4from collections import OrderedDict 

5 

6import torch 

7import torch.nn 

8 

9from torchvision.models.mobilenetv2 import InvertedResidual 

10 

11from .backbones.mobilenetv2 import mobilenet_v2_for_segmentation 

12 

13 

class DecoderBlock(torch.nn.Module):
    """
    Decoder block: upsample and concatenate with feature maps from the encoder part

    The upsampled ``up_in`` (2x spatially) is concatenated with ``x_in`` along
    the channel axis and fused by an :py:class:`InvertedResidual`, which
    outputs ``(up_in_c + x_in_c) // 2`` channels.

    Parameters
    ----------

    up_in_c : int
        number of channels of the tensor to be upsampled (``up_in`` in
        :py:meth:`forward`)

    x_in_c : int
        number of channels of the encoder feature map concatenated after
        upsampling (``x_in`` in :py:meth:`forward`)

    upsamplemode : str, Optional
        interpolation mode for :py:class:`torch.nn.Upsample`.
        ``align_corners=False`` is only passed for the interpolating modes
        that accept it; other modes (e.g. ``"nearest"``) get ``None``.

    expand_ratio : float, Optional
        expansion ratio of the :py:class:`InvertedResidual` fusing layer
    """

    # modes for which torch accepts an explicit ``align_corners`` argument
    _ALIGN_CORNERS_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(
        self, up_in_c, x_in_c, upsamplemode="bilinear", expand_ratio=0.15
    ):
        super().__init__()
        self.upsample = self._make_upsample(upsamplemode)  # H, W -> 2H, 2W
        self.ir1 = InvertedResidual(
            up_in_c + x_in_c,
            (x_in_c + up_in_c) // 2,
            stride=1,
            expand_ratio=expand_ratio,
        )

    @classmethod
    def _make_upsample(cls, mode):
        """Return a 2x :py:class:`torch.nn.Upsample` using ``mode``.

        An explicit ``align_corners`` makes ``torch.nn.functional.interpolate``
        raise for non-interpolating modes such as ``"nearest"``, so it is only
        set (to ``False``) for the modes that accept it.
        """
        align_corners = False if mode in cls._ALIGN_CORNERS_MODES else None
        return torch.nn.Upsample(
            scale_factor=2, mode=mode, align_corners=align_corners
        )

    def forward(self, up_in, x_in):
        """Upsample ``up_in``, concatenate it with ``x_in`` and fuse them.

        ``x_in`` must have twice the spatial size of ``up_in`` so that the
        channel-wise concatenation is valid.
        """
        up_out = self.upsample(up_in)
        cat_x = torch.cat([up_out, x_in], dim=1)
        x = self.ir1(cat_x)
        return x

38 

39 

class LastDecoderBlock(torch.nn.Module):
    """
    Last decoder block: upsample, concatenate and project to a single channel

    The upsampled ``up_in`` (2x spatially) is concatenated with ``x_in`` along
    the channel axis and fused by an :py:class:`InvertedResidual` with a
    single output channel.

    Note that, unlike :py:class:`DecoderBlock`, ``x_in_c`` here is the *total*
    number of channels after concatenation (upsampled ``up_in`` plus
    ``x_in``), not only the channels of ``x_in``.

    Parameters
    ----------

    x_in_c : int
        number of channels of ``torch.cat([upsample(up_in), x_in], dim=1)``

    upsamplemode : str, Optional
        interpolation mode for :py:class:`torch.nn.Upsample`.
        ``align_corners=False`` is only passed for the interpolating modes
        that accept it; other modes (e.g. ``"nearest"``) get ``None``.

    expand_ratio : float, Optional
        expansion ratio of the :py:class:`InvertedResidual` fusing layer
    """

    # modes for which torch accepts an explicit ``align_corners`` argument
    _ALIGN_CORNERS_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, x_in_c, upsamplemode="bilinear", expand_ratio=0.15):
        super().__init__()
        self.upsample = self._make_upsample(upsamplemode)  # H, W -> 2H, 2W
        # project the concatenated tensor to a single output channel
        self.ir1 = InvertedResidual(
            x_in_c, 1, stride=1, expand_ratio=expand_ratio
        )

    @classmethod
    def _make_upsample(cls, mode):
        """Return a 2x :py:class:`torch.nn.Upsample` using ``mode``.

        An explicit ``align_corners`` makes ``torch.nn.functional.interpolate``
        raise for non-interpolating modes such as ``"nearest"``, so it is only
        set (to ``False``) for the modes that accept it.
        """
        align_corners = False if mode in cls._ALIGN_CORNERS_MODES else None
        return torch.nn.Upsample(
            scale_factor=2, mode=mode, align_corners=align_corners
        )

    def forward(self, up_in, x_in):
        """Upsample ``up_in``, concatenate it with ``x_in`` and project.

        ``x_in`` must have twice the spatial size of ``up_in`` so that the
        channel-wise concatenation is valid.
        """
        up_out = self.upsample(up_in)
        cat_x = torch.cat([up_out, x_in], dim=1)
        x = self.ir1(cat_x)
        return x

55 

56 

class M2UNet(torch.nn.Module):
    """
    M2U-Net head module

    Parameters
    ----------

    in_channels_list : list, Optional
        number of channels for each feature map that is returned from the
        backbone and consumed by the decoder, ordered from the shallowest
        (consumed last, ``x[2]`` in :py:meth:`forward`) to the deepest
        (``x[5]``).  Exactly four entries are expected.  If ``None``, defaults
        to ``[16, 24, 32, 96]``.

    upsamplemode : str, Optional
        interpolation mode forwarded to every decoder block

    expand_ratio : float, Optional
        expansion ratio forwarded to every decoder block
    """

    # channels used when ``in_channels_list`` is not given
    _DEFAULT_IN_CHANNELS = (16, 24, 32, 96)
    # channels of ``x[1]``, concatenated in the last decoder step; this
    # reproduces the original hard-coded ``LastDecoderBlock(33, ...)``
    # (33 = 30 from decode2 + 3)
    _LAST_SKIP_CHANNELS = 3

    def __init__(
        self, in_channels_list=None, upsamplemode="bilinear", expand_ratio=0.15
    ):
        super().__init__()

        specs, last_in_c = self._decoder_channels(in_channels_list)
        (up4, skip4), (up3, skip3), (up2, skip2) = specs

        # Decoder
        self.decode4 = DecoderBlock(up4, skip4, upsamplemode, expand_ratio)
        self.decode3 = DecoderBlock(up3, skip3, upsamplemode, expand_ratio)
        self.decode2 = DecoderBlock(up2, skip2, upsamplemode, expand_ratio)
        self.decode1 = LastDecoderBlock(last_in_c, upsamplemode, expand_ratio)

        # initialize weights
        self._initialize_weights()

    @classmethod
    def _decoder_channels(cls, in_channels_list=None):
        """Compute the channel plan of the decoder.

        Parameters
        ----------

        in_channels_list : list, Optional
            four backbone channel counts, shallowest first; ``None`` selects
            the defaults

        Returns
        -------

        specs : list
            ``[(up_in_c, x_in_c), ...]`` for ``decode4``, ``decode3`` and
            ``decode2``, in that order

        last_in_c : int
            total input channels of ``decode1``

        Raises
        ------

        ValueError
            if ``in_channels_list`` does not have exactly four entries
        """
        if in_channels_list is None:
            in_channels_list = cls._DEFAULT_IN_CHANNELS
        if len(in_channels_list) != 4:
            raise ValueError(
                "in_channels_list must have exactly 4 entries, got "
                f"{len(in_channels_list)}"
            )
        c1, c2, c3, c4 = in_channels_list
        specs = []
        up_c = c4
        for skip_c in (c3, c2, c1):
            specs.append((up_c, skip_c))
            # a DecoderBlock outputs (up_in_c + x_in_c) // 2 channels
            up_c = (up_c + skip_c) // 2
        return specs, up_c + cls._LAST_SKIP_CHANNELS

    def _initialize_weights(self):
        """Kaiming-uniform convolutions; BatchNorm reset to identity."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_uniform_(m.weight, a=1)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.BatchNorm2d):
                # affine=False batch norms have no weight/bias to reset
                if m.weight is not None:
                    torch.nn.init.ones_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

    def forward(self, x):
        """
        Parameters
        ----------

        x : list
            list of tensors as returned from the backbone network.
            First element: height and width of input image (not used here).
            ``x[2]`` to ``x[5]`` are feature maps whose channels match
            ``in_channels_list`` (shallowest to deepest).  ``x[1]`` is
            concatenated in the last decoder step and is expected to have 3
            channels (presumably the input image itself -- confirm against the
            backbone).

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            single-channel output of the last decoder block
        """
        # channel annotations below refer to the default configuration
        decode4 = self.decode4(x[5], x[4])  # 96, 32
        decode3 = self.decode3(decode4, x[3])  # 64, 24
        decode2 = self.decode2(decode3, x[2])  # 44, 16
        decode1 = self.decode1(decode2, x[1])  # 30, 3

        return decode1

109 

110 

def m2unet(pretrained_backbone=True, progress=True):
    """Builds M2U-Net for segmentation by adding backbone and head together


    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        If set to ``True``, then loads a pre-trained version of the backbone
        (not the head) for the M2U-Net network, a MobileNetV2 (presumably
        trained for ImageNet classification), and prepends a
        ``TorchVisionNormalizer`` stage in front of the backbone.

    progress : :py:class:`bool`, Optional
        If set to ``True``, and you decided to use a ``pretrained_backbone``,
        then, shows a progress bar of the backbone model downloading if
        download is necessary.


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for M2U-Net (segmentation): a
        :py:class:`torch.nn.Sequential` with stages ``normalizer`` (only when
        ``pretrained_backbone`` is set), ``backbone`` and ``head``, and a
        ``name`` attribute set to ``"m2unet"``.

    """

    # features at these backbone indices carry 16, 24, 32 and 96 channels,
    # matching the head's ``in_channels_list`` below
    backbone = mobilenet_v2_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        return_features=[1, 3, 6, 13],
    )
    head = M2UNet(in_channels_list=[16, 24, 32, 96])

    order = [("backbone", backbone), ("head", head)]
    if pretrained_backbone:
        from .normalizer import TorchVisionNormalizer

        # pre-trained weights expect inputs normalized like at training time
        order = [("normalizer", TorchVisionNormalizer())] + order

    model = torch.nn.Sequential(OrderedDict(order))
    model.name = "m2unet"
    return model