Coverage for src/deepdraw/models/resunet.py: 77%
35 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5from collections import OrderedDict
7import torch.nn
9from .backbones.resnet import resnet50_for_segmentation
10from .make_layers import (
11 PixelShuffle_ICNR,
12 UnetBlock,
13 conv_with_kaiming_uniform,
14 convtrans_with_kaiming_uniform,
15)
class ResUNet(torch.nn.Module):
    """UNet head module for ResNet backbones.

    Parameters
    ----------

    in_channels_list : :py:class:`list`
        number of channels for each of the five feature maps returned from
        the backbone, given in the same order as the feature maps found in
        ``x[1:]`` of the input to :py:meth:`forward`.  The ``None`` default
        only exists in the signature: a value must be provided.

    pixel_shuffle : :py:class:`bool`, Optional
        if ``True``, the last upsampling step is a ``PixelShuffle_ICNR``
        layer, otherwise it is built with ``convtrans_with_kaiming_uniform``.
        The flag is also forwarded to every ``UnetBlock``.

    Raises
    ------

    TypeError
        if ``in_channels_list`` is ``None`` (or is not iterable)

    ValueError
        if ``in_channels_list`` does not contain exactly five entries
    """

    def __init__(self, in_channels_list=None, pixel_shuffle=False):
        super().__init__()

        # Fail early with an explicit message instead of the opaque
        # "cannot unpack" errors raised by the tuple assignment below.  The
        # exception types are those the plain unpacking would raise.
        if in_channels_list is None:
            raise TypeError(
                "in_channels_list must be a sequence of 5 channel counts, "
                "got None"
            )
        channels = tuple(in_channels_list)
        if len(channels) != 5:
            raise ValueError(
                f"in_channels_list must contain exactly 5 channel counts, "
                f"got {len(channels)}"
            )

        # number of channels
        c_decode1, c_decode2, c_decode3, c_decode4, c_decode5 = channels
        # number of channels for last upsampling operation
        # NOTE(review): presumably this matches the number of channels
        # produced by ``self.decode1`` -- confirm against ``UnetBlock``
        c_decode0 = (c_decode1 + c_decode2 // 2) // 2

        # build layers
        self.decode4 = UnetBlock(c_decode5, c_decode4, pixel_shuffle)
        self.decode3 = UnetBlock(c_decode4, c_decode3, pixel_shuffle)
        self.decode2 = UnetBlock(c_decode3, c_decode2, pixel_shuffle)
        self.decode1 = UnetBlock(c_decode2, c_decode1, pixel_shuffle)
        if pixel_shuffle:
            self.decode0 = PixelShuffle_ICNR(c_decode0, c_decode0)
        else:
            self.decode0 = convtrans_with_kaiming_uniform(
                c_decode0, c_decode0, 2, 2
            )
        self.final = conv_with_kaiming_uniform(c_decode0, 1, 1)

    def forward(self, x):
        """Decode the backbone feature maps into the network output.

        Parameters
        ----------
        x : list
            list of tensors as returned from the backbone network.
            First element: height and width of input image.
            Remaining elements: feature maps for each feature level.

        Returns
        -------
        out
            result of applying the ``final`` layer to the last decoded
            feature map
        """
        # NOTE: x[0]: height and width of input image not needed in U-Net
        # architecture
        # Each decoding step receives the previous decoded map together with
        # the backbone feature map one index lower in ``x``.
        decode4 = self.decode4(x[5], x[4])
        decode3 = self.decode3(decode4, x[3])
        decode2 = self.decode2(decode3, x[2])
        decode1 = self.decode1(decode2, x[1])
        decode0 = self.decode0(decode1)
        out = self.final(decode0)
        return out
def resunet50(pretrained_backbone=True, progress=True):
    """Builds Residual-U-Net-50 by adding backbone and head together.

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        If set to ``True``, then requests a pre-trained version of the
        ResNet-50 backbone (not the head) from ``resnet50_for_segmentation``,
        and places a ``TorchVisionNormalizer`` in front of the backbone.

    progress : :py:class:`bool`, Optional
        If set to ``True``, and you decided to use a ``pretrained_backbone``,
        then, shows a progress bar of the backbone model downloading if
        download is necessary.  The value is passed as-is to the backbone
        builder.


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for Residual U-Net 50
    """

    feature_extractor = resnet50_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        return_features=[2, 4, 5, 6, 7],
    )
    decoder = ResUNet([64, 256, 512, 1024, 2048], pixel_shuffle=False)

    # Sub-modules are registered in execution order; the normalizer, when
    # present, must come before the backbone.
    stages = OrderedDict()
    if pretrained_backbone:
        from .normalizer import TorchVisionNormalizer

        stages["normalizer"] = TorchVisionNormalizer()
    stages["backbone"] = feature_extractor
    stages["head"] = decoder

    network = torch.nn.Sequential(stages)
    network.name = "resunet50"
    return network