Coverage for src/deepdraw/models/driu.py: 75%
40 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5from collections import OrderedDict
7import torch
8import torch.nn
10from .backbones.vgg import vgg16_for_segmentation
11from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
class ConcatFuseBlock(torch.nn.Module):
    """Fuse four 16-channel feature maps into a single-channel map.

    The four inputs are concatenated along the channel axis (``dim=1``),
    giving ``4 * 16`` channels, and then reduced to 1 output channel by a
    1x1 convolution.
    """

    def __init__(self):
        super().__init__()
        # 4 maps x 16 channels in, 1 channel out.
        # NOTE(review): the remaining positional arguments are presumably
        # (kernel_size=1, stride=1, padding=0) -- confirm against
        # ``make_layers.conv_with_kaiming_uniform``.
        self.conv = conv_with_kaiming_uniform(4 * 16, 1, 1, 1, 0)

    def forward(self, x1, x2, x3, x4):
        """Concatenate the four inputs channel-wise and apply the fuse
        convolution.

        Parameters
        ----------

        x1, x2, x3, x4 : :py:class:`torch.Tensor`
            Feature maps to fuse.  They are joined along ``dim=1``, so all
            other dimensions must agree.

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            Output of ``self.conv`` applied to the concatenated maps.
        """
        stacked = torch.cat((x1, x2, x3, x4), dim=1)
        return self.conv(stacked)
class DRIU(torch.nn.Module):
    """DRIU head module.

    Based on paper by [MANINIS-2016]_.

    Parameters
    ----------
    in_channels_list : list, Optional
        Number of channels for each of the four feature maps returned from
        the backbone, in the order they are consumed by :py:meth:`forward`
        (full-resolution map first, then the maps upsampled by 2, 4 and 8).
        If ``None`` (the default), ``[64, 128, 256, 512]`` is used, which
        is the value the :py:func:`driu` factory passes for its VGG-16
        backbone.
    """

    def __init__(self, in_channels_list=None):
        super().__init__()

        # The signature advertises ``None`` as the default, but unpacking
        # ``None`` raises an opaque ``TypeError``.  Fall back to the channel
        # counts used by the ``driu()`` factory instead.
        if in_channels_list is None:
            in_channels_list = [64, 128, 256, 512]

        # Exactly four entries are required; any other length raises
        # ``ValueError`` from the unpacking below.
        (
            in_conv_1_2_16,
            in_upsample2,
            in_upsample_4,
            in_upsample_8,
        ) = in_channels_list

        # 3x3 convolution (stride 1, padding 1) bringing the first feature
        # map down to 16 channels without changing its spatial size.
        self.conv1_2_16 = torch.nn.Conv2d(
            in_conv_1_2_16, 16, kernel_size=3, stride=1, padding=1
        )

        # Upsample layers
        # NOTE(review): positional arguments are presumably (in_channels,
        # out_channels, kernel_size, stride, padding) -- confirm against
        # ``make_layers.UpsampleCropBlock``.
        self.upsample2 = UpsampleCropBlock(in_upsample2, 16, 4, 2, 0)
        self.upsample4 = UpsampleCropBlock(in_upsample_4, 16, 8, 4, 0)
        self.upsample8 = UpsampleCropBlock(in_upsample_8, 16, 16, 8, 0)

        # Concat and Fuse
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """Run the head on the backbone outputs.

        Parameters
        ----------

        x : list
            list of tensors as returned from the backbone network.  First
            element: height and width of input image.  Remaining elements:
            feature maps for each feature level.

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            Result of the concat-and-fuse block applied to the four
            16-channel side outputs.
        """
        hw = x[0]
        conv1_2_16 = self.conv1_2_16(x[1])  # conv1_2_16
        upsample2 = self.upsample2(x[2], hw)  # side-multi2-up
        upsample4 = self.upsample4(x[3], hw)  # side-multi3-up
        upsample8 = self.upsample8(x[4], hw)  # side-multi4-up
        return self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8)
def driu(pretrained_backbone=True, progress=True):
    """Assemble the DRIU vessel-segmentation network (backbone + head).

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        If set to ``True``, then loads a pre-trained version of the backbone
        (not the head) for the DRIU network using VGG-16 trained for ImageNet
        classification.  In that case a ``TorchVisionNormalizer`` stage is
        also placed in front of the backbone.

    progress : :py:class:`bool`, Optional
        If set to ``True``, and you decided to use a ``pretrained_backbone``,
        then, shows a progress bar of the backbone model downloading if
        download is necessary.


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for DRIU (vessel segmentation)
    """
    # Build backbone and head first so module construction order is the
    # same whether or not the normalizer is added.
    backbone = vgg16_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        return_features=[3, 8, 14, 22],
    )
    head = DRIU([64, 128, 256, 512])

    stages = OrderedDict()
    if pretrained_backbone:
        # Imported lazily: only needed when a pre-trained backbone is used.
        from .normalizer import TorchVisionNormalizer

        stages["normalizer"] = TorchVisionNormalizer()
    stages["backbone"] = backbone
    stages["head"] = head

    model = torch.nn.Sequential(stages)
    model.name = "driu"
    return model