
#!/usr/bin/env python
# coding=utf-8

"""Little W-Net

Code was originally developed by Adrian Galdran
(https://github.com/agaldran/lwnet), loosely inspired by
https://github.com/jvanvugt/pytorch-unet

It is based on two simple 3-layer U-Nets concatenated with each other: the
first U-Net produces a segmentation map that the second one uses to better
guide its own segmentation.

Reference: [GALDRAN-2020]_
"""

import torch
import torch.nn


def _conv1x1(in_planes, out_planes, stride=1):
    return torch.nn.Conv2d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=False
    )


class ConvBlock(torch.nn.Module):
    def __init__(self, in_c, out_c, k_sz=3, shortcut=False, pool=True):
        """
        ``pool`` can be ``False`` (no pooling) or ``True`` (2x2 max-pooling)
        """

        super(ConvBlock, self).__init__()
        if shortcut is True:
            self.shortcut = torch.nn.Sequential(
                _conv1x1(in_c, out_c), torch.nn.BatchNorm2d(out_c)
            )
        else:
            self.shortcut = False
        pad = (k_sz - 1) // 2

        block = []
        if pool:
            self.pool = torch.nn.MaxPool2d(kernel_size=2)
        else:
            self.pool = False

        # two (conv => relu => batch-norm) stages
        block.append(
            torch.nn.Conv2d(in_c, out_c, kernel_size=k_sz, padding=pad)
        )
        block.append(torch.nn.ReLU())
        block.append(torch.nn.BatchNorm2d(out_c))

        block.append(
            torch.nn.Conv2d(out_c, out_c, kernel_size=k_sz, padding=pad)
        )
        block.append(torch.nn.ReLU())
        block.append(torch.nn.BatchNorm2d(out_c))

        self.block = torch.nn.Sequential(*block)

    def forward(self, x):
        if self.pool:
            x = self.pool(x)
        out = self.block(x)
        if self.shortcut:
            return out + self.shortcut(x)
        else:
            return out

class UpsampleBlock(torch.nn.Module):
    def __init__(self, in_c, out_c, up_mode="transp_conv"):
        super(UpsampleBlock, self).__init__()
        block = []
        if up_mode == "transp_conv":
            block.append(
                torch.nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2)
            )
        elif up_mode == "up_conv":
            block.append(
                torch.nn.Upsample(
                    mode="bilinear", scale_factor=2, align_corners=False
                )
            )
            block.append(torch.nn.Conv2d(in_c, out_c, kernel_size=1))
        else:
            raise Exception("Upsampling mode not supported")

        self.block = torch.nn.Sequential(*block)

    def forward(self, x):
        out = self.block(x)
        return out


class ConvBridgeBlock(torch.nn.Module):
    def __init__(self, channels, k_sz=3):
        super(ConvBridgeBlock, self).__init__()
        pad = (k_sz - 1) // 2
        block = []

        block.append(
            torch.nn.Conv2d(channels, channels, kernel_size=k_sz, padding=pad)
        )
        block.append(torch.nn.ReLU())
        block.append(torch.nn.BatchNorm2d(channels))

        self.block = torch.nn.Sequential(*block)

    def forward(self, x):
        out = self.block(x)
        return out


class UpConvBlock(torch.nn.Module):
    def __init__(
        self,
        in_c,
        out_c,
        k_sz=3,
        up_mode="up_conv",
        conv_bridge=False,
        shortcut=False,
    ):
        super(UpConvBlock, self).__init__()
        self.conv_bridge = conv_bridge

        self.up_layer = UpsampleBlock(in_c, out_c, up_mode=up_mode)
        self.conv_layer = ConvBlock(
            2 * out_c, out_c, k_sz=k_sz, shortcut=shortcut, pool=False
        )
        if self.conv_bridge:
            self.conv_bridge_layer = ConvBridgeBlock(out_c, k_sz=k_sz)

    def forward(self, x, skip):
        up = self.up_layer(x)
        if self.conv_bridge:
            out = torch.cat([up, self.conv_bridge_layer(skip)], dim=1)
        else:
            out = torch.cat([up, skip], dim=1)
        out = self.conv_layer(out)
        return out

class LittleUNet(torch.nn.Module):
    """Little U-Net model"""

    def __init__(
        self,
        in_c,
        n_classes,
        layers,
        k_sz=3,
        up_mode="transp_conv",
        conv_bridge=True,
        shortcut=True,
    ):
        super(LittleUNet, self).__init__()
        self.n_classes = n_classes
        self.first = ConvBlock(
            in_c=in_c, out_c=layers[0], k_sz=k_sz, shortcut=shortcut, pool=False
        )

        self.down_path = torch.nn.ModuleList()
        for i in range(len(layers) - 1):
            block = ConvBlock(
                in_c=layers[i],
                out_c=layers[i + 1],
                k_sz=k_sz,
                shortcut=shortcut,
                pool=True,
            )
            self.down_path.append(block)

        self.up_path = torch.nn.ModuleList()
        reversed_layers = list(reversed(layers))
        for i in range(len(layers) - 1):
            block = UpConvBlock(
                in_c=reversed_layers[i],
                out_c=reversed_layers[i + 1],
                k_sz=k_sz,
                up_mode=up_mode,
                conv_bridge=conv_bridge,
                shortcut=shortcut,
            )
            self.up_path.append(block)

        # init, shamelessly lifted from torchvision/models/resnet.py
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(m, (torch.nn.BatchNorm2d, torch.nn.GroupNorm)):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0)

        self.final = torch.nn.Conv2d(layers[0], n_classes, kernel_size=1)

    def forward(self, x):
        x = self.first(x)
        down_activations = []
        for i, down in enumerate(self.down_path):
            down_activations.append(x)
            x = down(x)
        down_activations.reverse()
        for i, up in enumerate(self.up_path):
            x = up(x, down_activations[i])
        return self.final(x)

class LittleWNet(torch.nn.Module):
    """Little W-Net model, concatenating two Little U-Net models"""

    def __init__(
        self,
        n_classes=1,
        in_c=3,
        layers=(8, 16, 32),
        conv_bridge=True,
        shortcut=True,
        mode="train",
    ):
        super(LittleWNet, self).__init__()
        self.unet1 = LittleUNet(
            in_c=in_c,
            n_classes=n_classes,
            layers=layers,
            conv_bridge=conv_bridge,
            shortcut=shortcut,
        )
        self.unet2 = LittleUNet(
            in_c=in_c + n_classes,
            n_classes=n_classes,
            layers=layers,
            conv_bridge=conv_bridge,
            shortcut=shortcut,
        )
        self.n_classes = n_classes
        self.mode = mode

    def forward(self, x):
        x1 = self.unet1(x)
        # the second U-Net sees the input image together with the first map
        x2 = self.unet2(torch.cat([x, x1], dim=1))
        if self.mode != "train":
            return x2
        return x1, x2

def lunet(input_channels=3, output_classes=1):
    """Builds Little U-Net segmentation network (uninitialized)


    Parameters
    ----------

    input_channels : :py:class:`int`, Optional
        Number of input channels the network should operate with

    output_classes : :py:class:`int`, Optional
        Number of output classes


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for Little U-Net

    """

    return LittleUNet(
        in_c=input_channels,
        n_classes=output_classes,
        layers=[8, 16, 32],
        conv_bridge=True,
        shortcut=True,
    )
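
# Minimal usage sketch for the builder above (the input resolution is an
# assumption; any spatial size divisible by 4 works with the default 3-layer
# setup):
#
#   model = lunet()
#   scores = model(torch.rand(1, 3, 544, 544))  # -> shape (1, 1, 544, 544)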

def lwnet(input_channels=3, output_classes=1):
    """Builds Little W-Net segmentation network (uninitialized)


    Parameters
    ----------

    input_channels : :py:class:`int`, Optional
        Number of input channels the network should operate with

    output_classes : :py:class:`int`, Optional
        Number of output classes


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for Little W-Net

    """

    return LittleWNet(
        in_c=input_channels,
        n_classes=output_classes,
        layers=[8, 16, 32],
        conv_bridge=True,
        shortcut=True,
    )