Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# coding=utf-8
4"""Little W-Net
6Code was originally developed by Adrian Galdran
7(https://github.com/agaldran/lwnet), loosely inspired by
8https://github.com/jvanvugt/pytorch-unet
10It is based on two simple U-Nets with 3 layers concatenated to each other. The
11first U-Net produces a segmentation map that is used by the second to better
12guide segmentation.
14Reference: [GALDRAN-2020]_
15"""
18import torch
19import torch.nn
22def _conv1x1(in_planes, out_planes, stride=1):
23 return torch.nn.Conv2d(
24 in_planes, out_planes, kernel_size=1, stride=stride, bias=False
25 )
class ConvBlock(torch.nn.Module):
    """Two ``Conv2d -> ReLU -> BatchNorm2d`` stages, optionally preceded by
    2x2 max-pooling and bypassed by a residual shortcut.

    Parameters
    ----------
    in_c : int
        Number of input channels.
    out_c : int
        Number of output channels of both convolutions.
    k_sz : int, optional
        Kernel size of both convolutions. Padding is ``(k_sz - 1) // 2``,
        which preserves the spatial size for odd kernel sizes.
    shortcut : bool, optional
        Only when this is exactly ``True`` (identity check), a
        ``1x1 conv + BatchNorm2d`` projection of the (pooled) input is added
        to the output.
    pool : bool, optional
        When truthy, a ``MaxPool2d(kernel_size=2)`` is applied to the input
        before the convolutions; when falsy, no pooling is performed.
    """

    def __init__(self, in_c, out_c, k_sz=3, shortcut=False, pool=True):
        super(ConvBlock, self).__init__()

        # Residual projection branch; the plain value ``False`` marks
        # "no shortcut" and is tested for truthiness in ``forward``.
        self.shortcut = (
            torch.nn.Sequential(
                _conv1x1(in_c, out_c), torch.nn.BatchNorm2d(out_c)
            )
            if shortcut is True
            else False
        )

        # Optional down-sampling; ``False`` disables it.
        self.pool = torch.nn.MaxPool2d(kernel_size=2) if pool else False

        padding = (k_sz - 1) // 2
        layers = []
        # First stage maps in_c -> out_c, second stage keeps out_c.
        for channels_in in (in_c, out_c):
            layers += [
                torch.nn.Conv2d(
                    channels_in, out_c, kernel_size=k_sz, padding=padding
                ),
                torch.nn.ReLU(),
                torch.nn.BatchNorm2d(out_c),
            ]
        self.block = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Pool (if enabled), convolve, and add the shortcut (if enabled)."""
        if self.pool:
            x = self.pool(x)
        out = self.block(x)
        return (out + self.shortcut(x)) if self.shortcut else out
class UpsampleBlock(torch.nn.Module):
    """Double the spatial resolution of a feature map and set its width.

    Parameters
    ----------
    in_c : int
        Number of input channels.
    out_c : int
        Number of output channels.
    up_mode : str, optional
        ``"transp_conv"``: a ``ConvTranspose2d`` with kernel size 2 and
        stride 2.
        ``"up_conv"``: bilinear ``Upsample`` (``scale_factor=2``,
        ``align_corners=False``) followed by a 1x1 ``Conv2d``.

    Raises
    ------
    ValueError
        If ``up_mode`` is not one of the supported modes. ``ValueError`` is
        a subclass of ``Exception``, so callers that caught the previously
        raised generic ``Exception`` keep working.
    """

    def __init__(self, in_c, out_c, up_mode="transp_conv"):
        super(UpsampleBlock, self).__init__()
        block = []
        if up_mode == "transp_conv":
            block.append(
                torch.nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2)
            )
        elif up_mode == "up_conv":
            block.append(
                torch.nn.Upsample(
                    mode="bilinear", scale_factor=2, align_corners=False
                )
            )
            block.append(torch.nn.Conv2d(in_c, out_c, kernel_size=1))
        else:
            # Specific exception type plus the offending value, so that a
            # misconfiguration is easy to diagnose.
            raise ValueError(
                "Upsampling mode not supported: {!r}".format(up_mode)
            )

        self.block = torch.nn.Sequential(*block)

    def forward(self, x):
        """Apply the configured up-sampling layers to ``x``."""
        out = self.block(x)
        return out
class ConvBridgeBlock(torch.nn.Module):
    """Single ``Conv2d -> ReLU -> BatchNorm2d`` stage that keeps the channel
    count unchanged.

    Parameters
    ----------
    channels : int
        Number of input and output channels.
    k_sz : int, optional
        Convolution kernel size; padding ``(k_sz - 1) // 2`` keeps the
        spatial size for odd kernels.
    """

    def __init__(self, channels, k_sz=3):
        super(ConvBridgeBlock, self).__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(
                channels, channels, kernel_size=k_sz, padding=(k_sz - 1) // 2
            ),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        """Run ``x`` through the conv/ReLU/batch-norm stage."""
        return self.block(x)
class UpConvBlock(torch.nn.Module):
    """Decoder stage: upsample, merge with a skip connection, convolve.

    ``x`` is upsampled to ``out_c`` channels by an :class:`UpsampleBlock`.
    The result is concatenated along the channel dimension with ``skip``,
    optionally refined first by a :class:`ConvBridgeBlock`. A non-pooling
    :class:`ConvBlock` then maps the ``2 * out_c`` merged channels back to
    ``out_c``. For the concatenation and the following convolution to
    succeed, ``skip`` must have ``out_c`` channels and the same spatial size
    as the upsampled ``x``.

    Parameters
    ----------
    in_c : int
        Channels of the incoming (coarser) feature map ``x``.
    out_c : int
        Channels of the skip connection and of the output.
    k_sz : int, optional
        Kernel size used by the convolutional sub-blocks.
    up_mode : str, optional
        Up-sampling mode forwarded to :class:`UpsampleBlock`.
    conv_bridge : bool, optional
        Whether the skip connection goes through a :class:`ConvBridgeBlock`.
    shortcut : bool, optional
        Forwarded to the :class:`ConvBlock` as its residual-shortcut flag.
    """

    def __init__(
        self,
        in_c,
        out_c,
        k_sz=3,
        up_mode="up_conv",
        conv_bridge=False,
        shortcut=False,
    ):
        super(UpConvBlock, self).__init__()
        self.conv_bridge = conv_bridge
        self.up_layer = UpsampleBlock(in_c, out_c, up_mode=up_mode)
        merged_channels = 2 * out_c
        self.conv_layer = ConvBlock(
            merged_channels, out_c, k_sz=k_sz, shortcut=shortcut, pool=False
        )
        if conv_bridge:
            self.conv_bridge_layer = ConvBridgeBlock(out_c, k_sz=k_sz)

    def forward(self, x, skip):
        """Upsample ``x``, concatenate it with ``skip`` and convolve."""
        upsampled = self.up_layer(x)
        bridged = self.conv_bridge_layer(skip) if self.conv_bridge else skip
        return self.conv_layer(torch.cat([upsampled, bridged], dim=1))
class LittleUNet(torch.nn.Module):
    """Little U-Net model

    Parameters
    ----------
    in_c : int
        Number of input channels.
    n_classes : int
        Number of output channels produced by the final 1x1 convolution.
    layers : sequence of int
        Channel width of each resolution level, from the top (full
        resolution) to the bottleneck. Every level after the first is
        reached through a pooling :class:`ConvBlock`.
    k_sz : int, optional
        Kernel size of all convolutional blocks.
    up_mode : str, optional
        Up-sampling mode of the decoder blocks.
    conv_bridge : bool, optional
        Whether skip connections go through a bridge convolution.
    shortcut : bool, optional
        Whether convolutional blocks use residual shortcuts.

    Notes
    -----
    Conv and norm layers of the encoder and decoder are re-initialized with
    Kaiming-normal weights / constant norm parameters. ``self.final`` is
    created *after* that loop and therefore keeps PyTorch's default
    initialization.

    NOTE(review): presumably the input height/width must be divisible by
    ``2 ** (len(layers) - 1)`` so that upsampled maps match the skip
    connections; this is not checked here.
    """

    def __init__(
        self,
        in_c,
        n_classes,
        layers,
        k_sz=3,
        up_mode="transp_conv",
        conv_bridge=True,
        shortcut=True,
    ):
        super(LittleUNet, self).__init__()
        self.n_classes = n_classes
        self.first = ConvBlock(
            in_c=in_c, out_c=layers[0], k_sz=k_sz, shortcut=shortcut, pool=False
        )

        widths = list(layers)

        # Encoder: one pooling block per consecutive pair of widths.
        self.down_path = torch.nn.ModuleList(
            ConvBlock(
                in_c=c_in,
                out_c=c_out,
                k_sz=k_sz,
                shortcut=shortcut,
                pool=True,
            )
            for c_in, c_out in zip(widths[:-1], widths[1:])
        )

        # Decoder: mirror of the encoder, walking the widths backwards.
        widths_rev = widths[::-1]
        self.up_path = torch.nn.ModuleList(
            UpConvBlock(
                in_c=c_in,
                out_c=c_out,
                k_sz=k_sz,
                up_mode=up_mode,
                conv_bridge=conv_bridge,
                shortcut=shortcut,
            )
            for c_in, c_out in zip(widths_rev[:-1], widths_rev[1:])
        )

        # init, shamelessly lifted from torchvision/models/resnet.py
        for module in self.modules():
            if isinstance(module, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(module, (torch.nn.BatchNorm2d, torch.nn.GroupNorm)):
                torch.nn.init.constant_(module.weight, 1)
                torch.nn.init.constant_(module.bias, 0)

        self.final = torch.nn.Conv2d(layers[0], n_classes, kernel_size=1)

    def forward(self, x):
        """Encode ``x``, decode it with skip connections, and return the
        unnormalized per-pixel scores of the final 1x1 convolution."""
        x = self.first(x)
        skips = []
        for down in self.down_path:
            skips.append(x)
            x = down(x)
        # Deepest skip pairs with the first decoder stage.
        for up, skip in zip(self.up_path, reversed(skips)):
            x = up(x, skip)
        return self.final(x)
class LittleWNet(torch.nn.Module):
    """Little W-Net model, concatenating two Little U-Net models

    The first U-Net sees the raw input. The second sees the raw input
    concatenated (channel-wise) with the first U-Net's output, hence its
    ``in_c + n_classes`` input channels.

    Parameters
    ----------
    n_classes : int, optional
        Output channels of each U-Net.
    in_c : int, optional
        Number of input channels.
    layers : sequence of int, optional
        Per-level channel widths shared by both U-Nets.
    conv_bridge : bool, optional
        Whether skip connections use bridge convolutions.
    shortcut : bool, optional
        Whether convolutional blocks use residual shortcuts.
    mode : str, optional
        When ``"train"``, ``forward`` returns both U-Net outputs as a tuple
        ``(first, second)``; for any other value only the second output is
        returned.
    """

    def __init__(
        self,
        n_classes=1,
        in_c=3,
        layers=(8, 16, 32),
        conv_bridge=True,
        shortcut=True,
        mode="train",
    ):
        super(LittleWNet, self).__init__()
        shared = dict(
            n_classes=n_classes,
            layers=layers,
            conv_bridge=conv_bridge,
            shortcut=shortcut,
        )
        self.unet1 = LittleUNet(in_c=in_c, **shared)
        self.unet2 = LittleUNet(in_c=in_c + n_classes, **shared)
        self.n_classes = n_classes
        self.mode = mode

    def forward(self, x):
        """Run both U-Nets in sequence; see ``mode`` for the return shape."""
        first_pass = self.unet1(x)
        second_pass = self.unet2(torch.cat([x, first_pass], dim=1))
        if self.mode == "train":
            return first_pass, second_pass
        return second_pass
def lunet(input_channels=3, output_classes=1):
    """Builds a Little U-Net segmentation network with fresh weights

    No pretrained weights are loaded; the model uses three resolution
    levels of widths 8, 16 and 32, with bridge convolutions and residual
    shortcuts enabled.


    Parameters
    ----------

    input_channels : :py:class:`int`, Optional
        Number of channels of the input images

    output_classes : :py:class:`int`, Optional
        Number of output channels (classes) predicted per pixel


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        A freshly built Little U-Net model

    """
    widths = [8, 16, 32]
    model = LittleUNet(
        in_c=input_channels,
        n_classes=output_classes,
        layers=widths,
        conv_bridge=True,
        shortcut=True,
    )
    return model
def lwnet(input_channels=3, output_classes=1):
    """Builds a Little W-Net segmentation network with fresh weights

    No pretrained weights are loaded; both cascaded U-Nets use three
    resolution levels of widths 8, 16 and 32, with bridge convolutions and
    residual shortcuts enabled. The model is built in its default
    ``mode="train"``.


    Parameters
    ----------

    input_channels : :py:class:`int`, Optional
        Number of channels of the input images

    output_classes : :py:class:`int`, Optional
        Number of output channels (classes) predicted per pixel


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        A freshly built Little W-Net model

    """
    widths = [8, 16, 32]
    model = LittleWNet(
        in_c=input_channels,
        n_classes=output_classes,
        layers=widths,
        conv_bridge=True,
        shortcut=True,
    )
    return model