Coverage for src/deepdraw/models/lwnet.py: 92%

119 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> 

2# 

3# SPDX-License-Identifier: GPL-3.0-or-later 

4 

5"""Little W-Net. 

6 

7Code was originally developed by Adrian Galdran 

8(https://github.com/agaldran/lwnet), loosely inspired by 

9https://github.com/jvanvugt/pytorch-unet 

10 

11It is based on two simple U-Nets with 3 layers concatenated to each other. The 

12first U-Net produces a segmentation map that is used by the second to better 

13guide segmentation. 

14 

15Reference: [GALDRAN-2020]_ 

16""" 

17 

18 

19import torch 

20import torch.nn 

21 

22 

23def _conv1x1(in_planes, out_planes, stride=1): 

24 return torch.nn.Conv2d( 

25 in_planes, out_planes, kernel_size=1, stride=stride, bias=False 

26 ) 

27 

28 

class ConvBlock(torch.nn.Module):
    """Two stacked ``Conv2d -> ReLU -> BatchNorm2d`` stages.

    The block optionally max-pools its input first and optionally adds a
    1x1-convolution projection shortcut (residual connection).

    Parameters
    ----------
    in_c : int
        Number of input channels.
    out_c : int
        Number of output channels of both convolutions.
    k_sz : int, optional
        Kernel size of the two convolutions. Padding is ``(k_sz - 1) // 2``,
        which preserves the spatial size for odd ``k_sz``.
    shortcut : bool, optional
        If true, add ``BatchNorm2d(conv1x1(x))`` to the block output, where
        ``x`` is the (possibly pooled) input.
    pool : bool, optional
        If true, apply a 2x2 max-pooling to the input before the
        convolutions (downsampling by 2). If false, no pooling is done.
    """

    def __init__(self, in_c, out_c, k_sz=3, shortcut=False, pool=True):
        super().__init__()

        # Build order (shortcut, pool, convs) is kept so that seeded random
        # initialisation stays reproducible.
        if shortcut:
            self.shortcut = torch.nn.Sequential(
                _conv1x1(in_c, out_c), torch.nn.BatchNorm2d(out_c)
            )
        else:
            # ``False`` sentinel: no residual path (tested in forward()).
            self.shortcut = False

        pad = (k_sz - 1) // 2

        if pool:
            self.pool = torch.nn.MaxPool2d(kernel_size=2)
        else:
            self.pool = False

        # Layer order inside the Sequential defines state_dict keys
        # (block.0, block.2, ...); do not reorder.
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(in_c, out_c, kernel_size=k_sz, padding=pad),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_c),
            torch.nn.Conv2d(out_c, out_c, kernel_size=k_sz, padding=pad),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_c),
        )

    def forward(self, x):
        """Apply optional pooling, the conv stack and optional shortcut."""
        if self.pool:
            x = self.pool(x)
        out = self.block(x)
        # A non-empty Sequential is truthy; ``False`` means no shortcut.
        if self.shortcut:
            return out + self.shortcut(x)
        return out

70 

71 

class UpsampleBlock(torch.nn.Module):
    """Upsample feature maps by a factor of 2 and map ``in_c -> out_c``.

    Parameters
    ----------
    in_c : int
        Number of input channels.
    out_c : int
        Number of output channels.
    up_mode : str, optional
        ``"transp_conv"`` (default) uses a 2x2 stride-2 transposed
        convolution. ``"up_conv"`` uses bilinear upsampling (scale 2)
        followed by a 1x1 convolution.

    Raises
    ------
    ValueError
        If ``up_mode`` is not one of the supported modes. ``ValueError``
        subclasses ``Exception``, so callers catching ``Exception`` still
        work.
    """

    def __init__(self, in_c, out_c, up_mode="transp_conv"):
        super().__init__()
        if up_mode == "transp_conv":
            block = [
                torch.nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2)
            ]
        elif up_mode == "up_conv":
            block = [
                torch.nn.Upsample(
                    mode="bilinear", scale_factor=2, align_corners=False
                ),
                torch.nn.Conv2d(in_c, out_c, kernel_size=1),
            ]
        else:
            raise ValueError(
                f"Upsampling mode not supported: {up_mode!r} "
                '(expected "transp_conv" or "up_conv")'
            )

        self.block = torch.nn.Sequential(*block)

    def forward(self, x):
        """Return ``x`` upsampled by 2 with ``out_c`` channels."""
        return self.block(x)

95 

96 

class ConvBridgeBlock(torch.nn.Module):
    """Single ``Conv2d -> ReLU -> BatchNorm2d`` stage that keeps channel count.

    Used to transform skip-connection features before they are concatenated
    in the decoder.

    Parameters
    ----------
    channels : int
        Number of input and output channels.
    k_sz : int, optional
        Convolution kernel size. Padding is ``(k_sz - 1) // 2``.
    """

    def __init__(self, channels, k_sz=3):
        super().__init__()
        padding = (k_sz - 1) // 2
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(
                channels, channels, kernel_size=k_sz, padding=padding
            ),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        """Apply the conv/ReLU/BatchNorm stage to ``x``."""
        return self.block(x)

114 

115 

class UpConvBlock(torch.nn.Module):
    """Decoder stage: upsample, merge with a skip connection, then convolve.

    Parameters
    ----------
    in_c : int
        Channels of the incoming (coarser) feature map.
    out_c : int
        Channels after upsampling. This is also the expected channel count
        of ``skip`` and the output channel count.
    k_sz : int, optional
        Kernel size for the convolutions.
    up_mode : str, optional
        Upsampling mode forwarded to ``UpsampleBlock``.
    conv_bridge : bool, optional
        If true, pass ``skip`` through a ``ConvBridgeBlock`` before merging.
    shortcut : bool, optional
        Forwarded to the post-merge ``ConvBlock``.
    """

    def __init__(
        self,
        in_c,
        out_c,
        k_sz=3,
        up_mode="up_conv",
        conv_bridge=False,
        shortcut=False,
    ):
        super().__init__()
        self.conv_bridge = conv_bridge

        self.up_layer = UpsampleBlock(in_c, out_c, up_mode=up_mode)
        # The concatenation of upsampled and skip features doubles the
        # channel count, hence ``2 * out_c`` inputs.
        self.conv_layer = ConvBlock(
            2 * out_c, out_c, k_sz=k_sz, shortcut=shortcut, pool=False
        )
        if conv_bridge:
            self.conv_bridge_layer = ConvBridgeBlock(out_c, k_sz=k_sz)

    def forward(self, x, skip):
        """Upsample ``x``, concatenate with ``skip`` on dim 1, convolve."""
        upsampled = self.up_layer(x)
        bridged = self.conv_bridge_layer(skip) if self.conv_bridge else skip
        merged = torch.cat([upsampled, bridged], dim=1)
        return self.conv_layer(merged)

144 

145 

class LittleUNet(torch.nn.Module):
    """Little U-Net model.

    Parameters
    ----------
    in_c : int
        Number of input channels.
    n_classes : int
        Number of output channels of the final 1x1 convolution.
    layers : sequence of int
        Channel widths per resolution level, from shallowest to deepest.
        ``len(layers) - 1`` down-sampling and up-sampling stages are built.
    k_sz : int, optional
        Kernel size of the convolutions.
    up_mode : str, optional
        Upsampling mode forwarded to the decoder blocks
        (``"transp_conv"`` or ``"up_conv"``).
    conv_bridge : bool, optional
        If true, skip connections go through a convolutional bridge.
    shortcut : bool, optional
        If true, convolution blocks use a 1x1 residual shortcut.
    """

    def __init__(
        self,
        in_c,
        n_classes,
        layers,
        k_sz=3,
        up_mode="transp_conv",
        conv_bridge=True,
        shortcut=True,
    ):
        super().__init__()
        self.n_classes = n_classes
        widths = list(layers)

        # NOTE: module construction order (first, down_path, up_path, init
        # loop, final) is kept so seeded random initialisation is
        # reproducible.
        self.first = ConvBlock(
            in_c=in_c, out_c=widths[0], k_sz=k_sz, shortcut=shortcut, pool=False
        )

        # Encoder: each stage max-pools, then maps widths[i] -> widths[i + 1].
        self.down_path = torch.nn.ModuleList(
            [
                ConvBlock(
                    in_c=c_in,
                    out_c=c_out,
                    k_sz=k_sz,
                    shortcut=shortcut,
                    pool=True,
                )
                for c_in, c_out in zip(widths[:-1], widths[1:])
            ]
        )

        # Decoder: mirror of the encoder, deepest level first.
        rev = widths[::-1]
        self.up_path = torch.nn.ModuleList(
            [
                UpConvBlock(
                    in_c=c_in,
                    out_c=c_out,
                    k_sz=k_sz,
                    up_mode=up_mode,
                    conv_bridge=conv_bridge,
                    shortcut=shortcut,
                )
                for c_in, c_out in zip(rev[:-1], rev[1:])
            ]
        )

        # init, shamelessly lifted from torchvision/models/resnet.py
        # NOTE(review): the isinstance check targets Conv2d only; transposed
        # convolutions (up_mode="transp_conv") presumably keep their default
        # init - verify if that matters.
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(m, (torch.nn.BatchNorm2d, torch.nn.GroupNorm)):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0)

        # Created after the loop above, so it is not touched by it and keeps
        # torch's default Conv2d initialisation.
        self.final = torch.nn.Conv2d(widths[0], n_classes, kernel_size=1)

    def forward(self, x):
        """Run encoder, decoder (with skip connections) and the 1x1 head."""
        x = self.first(x)
        skips = []
        for down in self.down_path:
            skips.append(x)
            x = down(x)
        # The deepest skip pairs with the first decoder stage.
        for up, skip in zip(self.up_path, reversed(skips)):
            x = up(x, skip)
        return self.final(x)

211 

212 

class LittleWNet(torch.nn.Module):
    """Little W-Net model: two Little U-Nets applied in sequence.

    The second U-Net receives the original input concatenated (along dim 1)
    with the first U-Net's output.

    Parameters
    ----------
    n_classes : int, optional
        Output channels of each U-Net.
    in_c : int, optional
        Number of input channels.
    layers : sequence of int, optional
        Channel widths shared by both U-Nets.
    conv_bridge : bool, optional
        Forwarded to both U-Nets.
    shortcut : bool, optional
        Forwarded to both U-Nets.
    mode : str, optional
        When ``"train"``, ``forward`` returns ``(x1, x2)`` (both U-Net
        outputs). Otherwise it returns only the second U-Net output.
    """

    def __init__(
        self,
        n_classes=1,
        in_c=3,
        layers=(8, 16, 32),
        conv_bridge=True,
        shortcut=True,
        mode="train",
    ):
        super().__init__()
        common = {
            "n_classes": n_classes,
            "layers": layers,
            "conv_bridge": conv_bridge,
            "shortcut": shortcut,
        }
        # unet1 must be built before unet2 (random-init ordering).
        self.unet1 = LittleUNet(in_c=in_c, **common)
        # The second U-Net also sees the first U-Net's n_classes channels.
        self.unet2 = LittleUNet(in_c=in_c + n_classes, **common)
        self.n_classes = n_classes
        self.mode = mode

    def forward(self, x):
        """Run both U-Nets; output depends on ``self.mode`` (see class doc)."""
        first_pass = self.unet1(x)
        second_pass = self.unet2(torch.cat([x, first_pass], dim=1))
        if self.mode == "train":
            return first_pass, second_pass
        return second_pass

249 

250 

def lunet(input_channels=3, output_classes=1):
    """Build a freshly constructed Little U-Net (no pretrained weights).

    The network uses channel widths ``[8, 16, 32]`` with convolutional
    bridges and residual shortcuts enabled.

    Parameters
    ----------

    input_channels : :py:class:`int`, Optional
        Number of input channels the network should operate with

    output_classes : :py:class:`int`, Optional
        Number of output classes


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for Little U-Net
    """
    widths = [8, 16, 32]
    model = LittleUNet(
        in_c=input_channels,
        n_classes=output_classes,
        layers=widths,
        conv_bridge=True,
        shortcut=True,
    )
    return model

278 

279 

def lwnet(input_channels=3, output_classes=1):
    """Build a freshly constructed Little W-Net (no pretrained weights).

    The network uses channel widths ``[8, 16, 32]`` with convolutional
    bridges and residual shortcuts enabled. It is created with
    ``LittleWNet``'s default ``mode="train"``, so its forward pass returns
    the outputs of both U-Nets.

    Parameters
    ----------

    input_channels : :py:class:`int`, Optional
        Number of input channels the network should operate with

    output_classes : :py:class:`int`, Optional
        Number of output classes


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for Little W-Net
    """
    widths = [8, 16, 32]
    model = LittleWNet(
        in_c=input_channels,
        n_classes=output_classes,
        layers=widths,
        conv_bridge=True,
        shortcut=True,
    )
    return model