Coverage for src/deepdraw/models/driu_bn.py: 75%
40 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5from collections import OrderedDict
7import torch
8import torch.nn
10from .backbones.vgg import vgg16_bn_for_segmentation
11from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
class ConcatFuseBlock(torch.nn.Module):
    """Fuses four 16-channel feature maps into a single-channel map.

    The four inputs are concatenated along ``dim=1`` (the channel axis for
    :py:class:`torch.nn.Conv2d` inputs), giving ``4 * 16 = 64`` channels.
    The result goes through a 1x1 convolution with 1 output channel, built
    by ``conv_with_kaiming_uniform``, and then through a
    :py:class:`torch.nn.BatchNorm2d` over that single channel.
    """

    def __init__(self):
        super().__init__()
        # The convolution is created before the BatchNorm layer, matching
        # the order in which the Sequential below applies them.
        fuse_conv = conv_with_kaiming_uniform(4 * 16, 1, 1, 1, 0)
        self.conv = torch.nn.Sequential(fuse_conv, torch.nn.BatchNorm2d(1))

    def forward(self, x1, x2, x3, x4):
        """Concatenates ``x1``..``x4`` on ``dim=1`` and applies the fuse layers.

        All four inputs are expected to share their spatial size, because
        ``torch.cat`` requires every non-concatenated dimension to match.
        """
        stacked = torch.cat((x1, x2, x3, x4), dim=1)
        return self.conv(stacked)
class DRIUBN(torch.nn.Module):
    """DRIU with Batch-Normalization head module.

    Based on paper by [MANINIS-2016]_.

    Parameters
    ----------

    in_channels_list : :py:class:`list`, Optional
        Number of channels of each of the four feature maps returned by the
        backbone, in the order they are consumed by :py:meth:`forward`: the
        map fed to ``conv1_2_16``, then the maps upsampled by factors 2, 4
        and 8.  If ``None`` (the default), ``(64, 128, 256, 512)`` is used.
        These are the values that :py:func:`driu_bn` passes for its VGG-16
        backbone.  Previously a ``None`` value crashed while being unpacked.
    """

    def __init__(self, in_channels_list=None):
        super().__init__()
        if in_channels_list is None:
            # Same channel layout that driu_bn() passes explicitly.
            in_channels_list = (64, 128, 256, 512)
        (
            in_conv_1_2_16,
            in_upsample2,
            in_upsample_4,
            in_upsample_8,
        ) = in_channels_list

        # Every branch is projected to 16 channels so that ConcatFuseBlock
        # receives 4 * 16 channels.
        self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 16, 3, 1, 1)
        # Upsample layers
        self.upsample2 = UpsampleCropBlock(in_upsample2, 16, 4, 2, 0)
        self.upsample4 = UpsampleCropBlock(in_upsample_4, 16, 8, 4, 0)
        self.upsample8 = UpsampleCropBlock(in_upsample_8, 16, 16, 8, 0)

        # Concat and Fuse
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """Runs the head on the backbone outputs.

        Parameters
        ----------

        x : list
            List of tensors as returned from the backbone network.
            First element: height and width of the input image, which is
            handed to each ``UpsampleCropBlock`` (presumably as the crop
            target size; confirm against ``make_layers``).
            Remaining four elements: feature maps for each feature level,
            in the order described by ``in_channels_list``.

        Returns
        -------

        :py:class:`torch.Tensor`
            Single-channel output of :py:class:`ConcatFuseBlock`.
        """
        hw = x[0]
        conv1_2_16 = self.conv1_2_16(x[1])  # conv1_2_16
        upsample2 = self.upsample2(x[2], hw)  # side-multi2-up
        upsample4 = self.upsample4(x[3], hw)  # side-multi3-up
        upsample8 = self.upsample8(x[4], hw)  # side-multi4-up
        return self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8)
def driu_bn(pretrained_backbone=True, progress=True):
    """Builds DRIU with batch-normalization by adding backbone and head
    together.

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        If set to ``True``, then loads a pre-trained version of the backbone
        (not the head) for the DRIU network using VGG-16 trained for ImageNet
        classification.  A :py:class:`TorchVisionNormalizer` is also
        prepended to the model in that case.

    progress : :py:class:`bool`, Optional
        If set to ``True``, and you decided to use a ``pretrained_backbone``,
        then, shows a progress bar of the backbone model downloading if
        download is necessary.


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for DRIU (vessel segmentation) using batch normalization
    """
    # Both flags were previously ignored: the backbone was always built with
    # pretrained=False, so the default call prepended an ImageNet normalizer
    # to randomly initialised weights.  They are now forwarded.
    backbone = vgg16_bn_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        # Layer indices whose outputs feed the head; presumably the last
        # ReLU of each of the first four VGG-16-BN stages, whose channel
        # counts match the head configuration below.
        return_features=[5, 12, 19, 29],
    )
    head = DRIUBN([64, 128, 256, 512])

    order = [("backbone", backbone), ("head", head)]
    if pretrained_backbone:
        from .normalizer import TorchVisionNormalizer

        order = [("normalizer", TorchVisionNormalizer())] + order

    model = torch.nn.Sequential(OrderedDict(order))
    model.name = "driu-bn"
    return model