Coverage for src/deepdraw/models/lwnet.py: 92%
119 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5"""Little W-Net.
7Code was originally developed by Adrian Galdran
8(https://github.com/agaldran/lwnet), loosely inspired by
9https://github.com/jvanvugt/pytorch-unet
11It is based on two simple U-Nets with 3 layers concatenated to each other. The
12first U-Net produces a segmentation map that is used by the second to better
13guide segmentation.
15Reference: [GALDRAN-2020]_
16"""
19import torch
20import torch.nn
23def _conv1x1(in_planes, out_planes, stride=1):
24 return torch.nn.Conv2d(
25 in_planes, out_planes, kernel_size=1, stride=stride, bias=False
26 )
class ConvBlock(torch.nn.Module):
    """Two stacked (convolution, ReLU, batch-norm) stages.

    The stack may be preceded by a 2x2 max-pooling and may be summed with a
    shortcut branch (1x1 convolution followed by batch-norm) that is computed
    from the same, possibly pooled, input.

    Parameters
    ----------

    in_c : :py:class:`int`
        Number of input channels

    out_c : :py:class:`int`
        Number of output channels (used by both convolutions)

    k_sz : :py:class:`int`, Optional
        Kernel size of both convolutions; the padding is ``(k_sz - 1) // 2``

    shortcut : :py:class:`bool`, Optional
        Only the literal ``True`` enables the shortcut branch

    pool : :py:class:`bool`, Optional
        If truthy, a ``MaxPool2d(kernel_size=2)`` is applied before the
        convolutions
    """

    def __init__(self, in_c, out_c, k_sz=3, shortcut=False, pool=True):
        super().__init__()

        # Disabled branches are stored as a plain ``False`` so that
        # ``forward`` can test the attribute directly.  The identity check
        # (``is True``) is deliberate: other truthy values do not enable it.
        self.shortcut = (
            torch.nn.Sequential(
                _conv1x1(in_c, out_c), torch.nn.BatchNorm2d(out_c)
            )
            if shortcut is True
            else False
        )
        self.pool = torch.nn.MaxPool2d(kernel_size=2) if pool else False

        padding = (k_sz - 1) // 2
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(in_c, out_c, kernel_size=k_sz, padding=padding),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_c),
            torch.nn.Conv2d(out_c, out_c, kernel_size=k_sz, padding=padding),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_c),
        )

    def forward(self, x):
        """Apply optional pooling, the convolution stack and the shortcut."""

        if self.pool:
            x = self.pool(x)
        features = self.block(x)
        if not self.shortcut:
            return features
        # the shortcut branch sees the pooled input, not the original one
        return features + self.shortcut(x)
class UpsampleBlock(torch.nn.Module):
    """Upsampling stage that doubles the spatial resolution.

    Parameters
    ----------

    in_c : :py:class:`int`
        Number of input channels

    out_c : :py:class:`int`
        Number of output channels

    up_mode : :py:class:`str`, Optional
        ``"transp_conv"`` uses a transposed convolution (kernel size 2,
        stride 2); ``"up_conv"`` uses bilinear upsampling (scale factor 2)
        followed by a 1x1 convolution.


    Raises
    ------

    ValueError
        If ``up_mode`` is not one of the two supported values
    """

    def __init__(self, in_c, out_c, up_mode="transp_conv"):
        super().__init__()
        if up_mode == "transp_conv":
            stages = [
                torch.nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2)
            ]
        elif up_mode == "up_conv":
            stages = [
                torch.nn.Upsample(
                    mode="bilinear", scale_factor=2, align_corners=False
                ),
                torch.nn.Conv2d(in_c, out_c, kernel_size=1),
            ]
        else:
            # ValueError is a subclass of Exception, so callers that caught
            # the former generic Exception keep working.
            raise ValueError(
                f"Upsampling mode not supported: {up_mode!r} "
                f"(expected 'transp_conv' or 'up_conv')"
            )

        self.block = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Return the upsampled version of ``x``."""

        return self.block(x)
class ConvBridgeBlock(torch.nn.Module):
    """Single (convolution, ReLU, batch-norm) stage keeping the channel count.

    Parameters
    ----------

    channels : :py:class:`int`
        Number of input and output channels

    k_sz : :py:class:`int`, Optional
        Kernel size of the convolution; the padding is ``(k_sz - 1) // 2``
    """

    def __init__(self, channels, k_sz=3):
        super().__init__()
        padding = (k_sz - 1) // 2
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(
                channels, channels, kernel_size=k_sz, padding=padding
            ),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        """Run ``x`` through the convolution, ReLU and batch-norm."""

        return self.block(x)
class UpConvBlock(torch.nn.Module):
    """Decoder stage: upsample, concatenate with a skip input, convolve.

    Parameters
    ----------

    in_c : :py:class:`int`
        Number of channels of the tensor to be upsampled

    out_c : :py:class:`int`
        Number of channels produced by the upsampling and by the block; the
        convolution block receives ``2 * out_c`` channels after concatenation

    k_sz : :py:class:`int`, Optional
        Kernel size forwarded to the convolution (and bridge) blocks

    up_mode : :py:class:`str`, Optional
        Upsampling mode forwarded to :py:class:`UpsampleBlock`

    conv_bridge : :py:class:`bool`, Optional
        If truthy, the skip input goes through a
        :py:class:`ConvBridgeBlock` before being concatenated

    shortcut : :py:class:`bool`, Optional
        Forwarded to the :py:class:`ConvBlock` applied after concatenation
    """

    def __init__(
        self,
        in_c,
        out_c,
        k_sz=3,
        up_mode="up_conv",
        conv_bridge=False,
        shortcut=False,
    ):
        super().__init__()
        self.conv_bridge = conv_bridge

        self.up_layer = UpsampleBlock(in_c, out_c, up_mode=up_mode)
        # no pooling here: this block works at the upsampled resolution
        self.conv_layer = ConvBlock(
            2 * out_c, out_c, k_sz=k_sz, shortcut=shortcut, pool=False
        )
        # the bridge layer only exists when it was requested
        if conv_bridge:
            self.conv_bridge_layer = ConvBridgeBlock(out_c, k_sz=k_sz)

    def forward(self, x, skip):
        """Upsample ``x``, join it with ``skip`` on dim 1 and convolve."""

        upsampled = self.up_layer(x)
        if self.conv_bridge:
            skip = self.conv_bridge_layer(skip)
        merged = torch.cat([upsampled, skip], dim=1)
        return self.conv_layer(merged)
class LittleUNet(torch.nn.Module):
    """Little U-Net model.

    Parameters
    ----------

    in_c : :py:class:`int`
        Number of input channels

    n_classes : :py:class:`int`
        Number of output channels of the final 1x1 convolution

    layers : sequence of :py:class:`int`
        Channel widths, from the first (full-resolution) block to the
        deepest one; each consecutive pair defines one pooled down block
        and one up block

    k_sz : :py:class:`int`, Optional
        Kernel size forwarded to every convolution block

    up_mode : :py:class:`str`, Optional
        Upsampling mode forwarded to every :py:class:`UpConvBlock`

    conv_bridge : :py:class:`bool`, Optional
        Forwarded to every :py:class:`UpConvBlock`

    shortcut : :py:class:`bool`, Optional
        Forwarded to every convolution block
    """

    def __init__(
        self,
        in_c,
        n_classes,
        layers,
        k_sz=3,
        up_mode="transp_conv",
        conv_bridge=True,
        shortcut=True,
    ):
        super().__init__()
        self.n_classes = n_classes
        self.first = ConvBlock(
            in_c=in_c, out_c=layers[0], k_sz=k_sz, shortcut=shortcut, pool=False
        )

        # encoder: one pooled block per consecutive pair of widths
        self.down_path = torch.nn.ModuleList()
        for narrow, wide in zip(layers[:-1], layers[1:]):
            self.down_path.append(
                ConvBlock(
                    in_c=narrow,
                    out_c=wide,
                    k_sz=k_sz,
                    shortcut=shortcut,
                    pool=True,
                )
            )

        # decoder: the same pairs, walked from the deepest width upwards
        self.up_path = torch.nn.ModuleList()
        widths_deep_first = list(reversed(layers))
        for wide, narrow in zip(widths_deep_first[:-1], widths_deep_first[1:]):
            self.up_path.append(
                UpConvBlock(
                    in_c=wide,
                    out_c=narrow,
                    k_sz=k_sz,
                    up_mode=up_mode,
                    conv_bridge=conv_bridge,
                    shortcut=shortcut,
                )
            )

        # init, shamelessly lifted from torchvision/models/resnet.py
        norm_types = (torch.nn.BatchNorm2d, torch.nn.GroupNorm)
        for module in self.modules():
            if isinstance(module, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
                continue
            if isinstance(module, norm_types):
                torch.nn.init.constant_(module.weight, 1)
                torch.nn.init.constant_(module.bias, 0)

        # created after the loop above, so it is not touched by that
        # initialisation pass
        self.final = torch.nn.Conv2d(layers[0], n_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder, then the decoder with skips, then ``final``."""

        x = self.first(x)

        # keep the input of every down block to feed the matching up block
        skips = []
        for down in self.down_path:
            skips.append(x)
            x = down(x)

        # the deepest skip is consumed first
        for up, skip in zip(self.up_path, reversed(skips)):
            x = up(x, skip)

        return self.final(x)
class LittleWNet(torch.nn.Module):
    """Little W-Net model, concatenating two Little U-Net models.

    The second U-Net receives the input concatenated (on dim 1) with the
    output of the first one, hence its ``in_c + n_classes`` input channels.

    Parameters
    ----------

    n_classes : :py:class:`int`, Optional
        Number of output channels of each U-Net

    in_c : :py:class:`int`, Optional
        Number of input channels

    layers : sequence of :py:class:`int`, Optional
        Channel widths forwarded to both U-Nets

    conv_bridge : :py:class:`bool`, Optional
        Forwarded to both U-Nets

    shortcut : :py:class:`bool`, Optional
        Forwarded to both U-Nets

    mode : :py:class:`str`, Optional
        When equal to ``"train"``, ``forward`` returns the outputs of both
        U-Nets; for any other value only the second output is returned
    """

    def __init__(
        self,
        n_classes=1,
        in_c=3,
        layers=(8, 16, 32),
        conv_bridge=True,
        shortcut=True,
        mode="train",
    ):
        super().__init__()
        shared = dict(
            n_classes=n_classes,
            layers=layers,
            conv_bridge=conv_bridge,
            shortcut=shortcut,
        )
        self.unet1 = LittleUNet(in_c=in_c, **shared)
        self.unet2 = LittleUNet(in_c=in_c + n_classes, **shared)
        self.n_classes = n_classes
        self.mode = mode

    def forward(self, x):
        """Return ``(first, second)`` outputs in train mode, else ``second``."""

        first_pass = self.unet1(x)
        second_pass = self.unet2(torch.cat([x, first_pass], dim=1))
        if self.mode == "train":
            return first_pass, second_pass
        return second_pass
def lunet(input_channels=3, output_classes=1):
    """Builds a Little U-Net segmentation network.

    The network uses the fixed layer widths ``[8, 16, 32]`` with both the
    convolution bridge and the shortcut enabled.  No weights are loaded
    here.

    Parameters
    ----------

    input_channels : :py:class:`int`, Optional
        Number of input channels the network should operate with

    output_classes : :py:class:`int`, Optional
        Number of output classes


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for Little U-Net
    """

    model = LittleUNet(
        in_c=input_channels,
        n_classes=output_classes,
        layers=[8, 16, 32],
        conv_bridge=True,
        shortcut=True,
    )
    return model
def lwnet(input_channels=3, output_classes=1):
    """Builds a Little W-Net segmentation network.

    The network uses the fixed layer widths ``[8, 16, 32]`` with both the
    convolution bridge and the shortcut enabled.  No weights are loaded
    here.

    Parameters
    ----------

    input_channels : :py:class:`int`, Optional
        Number of input channels the network should operate with

    output_classes : :py:class:`int`, Optional
        Number of output classes


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for Little W-Net
    """

    model = LittleWNet(
        in_c=input_channels,
        n_classes=output_classes,
        layers=[8, 16, 32],
        conv_bridge=True,
        shortcut=True,
    )
    return model