Coverage for src/deepdraw/models/unet.py: 80%
30 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5from collections import OrderedDict
7import torch.nn
9from .backbones.vgg import vgg16_for_segmentation
10from .make_layers import UnetBlock, conv_with_kaiming_uniform
class UNet(torch.nn.Module):
    """UNet head module.

    Decodes the multi-level feature maps returned by a backbone into a
    single output map using four :py:class:`UnetBlock` decoding stages
    followed by a final convolution.

    Parameters
    ----------

    in_channels_list : list
        Number of channels for each of the five feature maps returned by the
        backbone.  The first entry corresponds to ``x[1]`` in :py:meth:`forward`
        and the last one to ``x[5]``.  Required, despite the ``None`` default
        (kept for backward compatibility).

    pixel_shuffle : bool, Optional
        Forwarded unchanged to every :py:class:`UnetBlock`.  Presumably
        selects pixel-shuffle upsampling inside the blocks -- see
        ``make_layers.UnetBlock`` to confirm.

    Raises
    ------

    TypeError
        If ``in_channels_list`` is ``None`` (or not iterable).

    ValueError
        If ``in_channels_list`` does not contain exactly five entries.
    """

    def __init__(self, in_channels_list=None, pixel_shuffle=False):
        super().__init__()

        # Fail early with a readable message instead of the opaque
        # "cannot unpack non-iterable NoneType object"; exception types are
        # kept identical to what the bare tuple-unpacking used to raise.
        if in_channels_list is None:
            raise TypeError(
                "in_channels_list is required: expected 5 channel counts, "
                "one per backbone feature level"
            )
        in_channels_list = list(in_channels_list)
        if len(in_channels_list) != 5:
            raise ValueError(
                f"in_channels_list must have exactly 5 entries, "
                f"got {len(in_channels_list)}"
            )

        # number of channels
        c_decode1, c_decode2, c_decode3, c_decode4, c_decode5 = in_channels_list

        # build layers: decoding starts from the deepest feature level
        # (c_decode5) and works its way back to the shallowest (c_decode1)
        self.decode4 = UnetBlock(
            c_decode5, c_decode4, pixel_shuffle, middle_block=True
        )
        self.decode3 = UnetBlock(c_decode4, c_decode3, pixel_shuffle)
        self.decode2 = UnetBlock(c_decode3, c_decode2, pixel_shuffle)
        self.decode1 = UnetBlock(c_decode2, c_decode1, pixel_shuffle)
        # presumably (in_channels, out_channels=1, kernel_size=1) -- confirm
        # against make_layers.conv_with_kaiming_uniform
        self.final = conv_with_kaiming_uniform(c_decode1, 1, 1)

    def forward(self, x):
        """Decode the backbone features into the output map.

        Parameters
        ----------

        x : list
            list of tensors as returned from the backbone network.
            First element: height and width of input image.
            Remaining elements: feature maps for each feature level
            (``x[1]`` .. ``x[5]``).

        Returns
        -------

        out : :py:class:`torch.Tensor`
            Output of the final convolution applied to the last decoding
            stage.
        """
        # NOTE: x[0]: height and width of input image not needed in U-Net architecture
        decode4 = self.decode4(x[5], x[4])
        decode3 = self.decode3(decode4, x[3])
        decode2 = self.decode2(decode3, x[2])
        decode1 = self.decode1(decode2, x[1])
        out = self.final(decode1)
        return out
def unet(pretrained_backbone=True, progress=True):
    """Builds U-Net segmentation network by adding backbone and head together.

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        If set to ``True``, then loads a pre-trained version of the backbone
        (not the head) for the U-Net network using VGG-16 trained for ImageNet
        classification.  In that case a ``TorchVisionNormalizer`` is also
        prepended to the model (presumably to reproduce the input
        normalization used during the ImageNet pre-training).

    progress : :py:class:`bool`, Optional
        If set to ``True``, and you decided to use a ``pretrained_backbone``,
        then, shows a progress bar of the backbone model downloading if
        download is necessary.

    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for U-Net: a :py:class:`torch.nn.Sequential` whose
        submodules are ``normalizer`` (only when ``pretrained_backbone`` is
        ``True``), ``backbone`` and ``head``, with its ``name`` attribute set
        to ``"unet"``.
    """

    # The feature layers tapped from VGG-16 and the head's channel list must
    # agree entry by entry: each returned feature map feeds the UNet stage
    # configured with the corresponding channel count.
    backbone = vgg16_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        return_features=[3, 8, 14, 22, 29],
    )
    head = UNet([64, 128, 256, 512, 512], pixel_shuffle=False)

    order = [("backbone", backbone), ("head", head)]
    if pretrained_backbone:
        # imported lazily: only needed when a pre-trained backbone is used
        from .normalizer import TorchVisionNormalizer

        order = [("normalizer", TorchVisionNormalizer())] + order

    model = torch.nn.Sequential(OrderedDict(order))
    model.name = "unet"
    return model