Coverage for src/deepdraw/models/hed.py: 74%
42 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5from collections import OrderedDict
7import torch
8import torch.nn
10from .backbones.vgg import vgg16_for_segmentation
11from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
class ConcatFuseBlock(torch.nn.Module):
    """Fuse five single-channel side outputs into one map.

    Takes in five feature maps with one channel each, concatenates them
    along dimension 1 and applies a 1x1 convolution with one output channel.
    """

    def __init__(self):
        super().__init__()
        # 5 input channels (one per side output) -> 1 output channel.
        # NOTE(review): argument order assumed to be (in, out, kernel, stride,
        # padding), i.e. a 1x1 conv, stride 1, no padding -- confirm against
        # make_layers.conv_with_kaiming_uniform.
        self.conv = conv_with_kaiming_uniform(5, 1, 1, 1, 0)

    def forward(self, x1, x2, x3, x4, x5):
        """Concatenate the five inputs on dim 1 and fuse them.

        Parameters
        ----------
        x1, x2, x3, x4, x5 : :py:class:`torch.Tensor`
            Single-channel feature maps; they must agree on every dimension
            except dim 1 for :py:func:`torch.cat` to succeed (assumed to be
            NCHW tensors of equal spatial size -- TODO confirm with caller).

        Returns
        -------
        tensor : :py:class:`torch.Tensor`
            Output of the fusing convolution applied to the concatenation.
        """
        x_cat = torch.cat([x1, x2, x3, x4, x5], dim=1)
        return self.conv(x_cat)
class HED(torch.nn.Module):
    """HED head module.

    Maps the five feature maps of a backbone to four upsampled
    single-channel side outputs plus one fused output.

    Parameters
    ----------
    in_channels_list : list, Optional
        Number of channels for each of the five feature maps returned from
        the backbone, in backbone output order.  If ``None`` (default), the
        channel counts that :py:func:`hed` uses for its VGG-16 backbone,
        ``[64, 128, 256, 512, 512]``, are assumed.

    Raises
    ------
    ValueError
        If ``in_channels_list`` does not contain exactly five entries.
    """

    # Channel counts of the backbone feature maps used by ``hed()``.
    _VGG16_IN_CHANNELS = (64, 128, 256, 512, 512)

    def __init__(self, in_channels_list=None):
        super().__init__()
        (
            in_conv_1_2_16,
            in_upsample_2,
            in_upsample_4,
            in_upsample_8,
            in_upsample_16,
        ) = self._resolve_in_channels(in_channels_list)

        # 3x3 conv, stride 1, padding 1: keeps the spatial size of the
        # first (full-resolution) feature map.
        self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 1, 3, 1, 1)
        # Upsample: kernel = 2 * stride for each level; presumably each block
        # upsamples by its stride factor and crops to the input size ``hw``
        # (see make_layers.UpsampleCropBlock -- TODO confirm).
        self.upsample2 = UpsampleCropBlock(in_upsample_2, 1, 4, 2, 0)
        self.upsample4 = UpsampleCropBlock(in_upsample_4, 1, 8, 4, 0)
        self.upsample8 = UpsampleCropBlock(in_upsample_8, 1, 16, 8, 0)
        self.upsample16 = UpsampleCropBlock(in_upsample_16, 1, 32, 16, 0)
        # Concat and Fuse
        self.concatfuse = ConcatFuseBlock()

    @staticmethod
    def _resolve_in_channels(in_channels_list):
        """Return the five backbone channel counts, applying the default."""
        if in_channels_list is None:
            return HED._VGG16_IN_CHANNELS
        channels = tuple(in_channels_list)
        if len(channels) != 5:
            raise ValueError(
                "in_channels_list must contain exactly 5 channel counts, "
                f"got {len(channels)}: {list(channels)}"
            )
        return channels

    def forward(self, x):
        """
        Parameters
        ----------
        x : list
            list of tensors as returned from the backbone network.
            First element: height and width of input image.
            Remaining elements: feature maps for each feature level.

        Returns
        -------
        tuple
            ``(upsample2, upsample4, upsample8, upsample16, concatfuse)``:
            the four upsampled side outputs followed by the fused output.
            The ``conv1_2_16`` side output only feeds the fusion and is not
            returned on its own.
        """
        hw = x[0]
        conv1_2_16 = self.conv1_2_16(x[1])
        upsample2 = self.upsample2(x[2], hw)
        upsample4 = self.upsample4(x[3], hw)
        upsample8 = self.upsample8(x[4], hw)
        upsample16 = self.upsample16(x[5], hw)
        concatfuse = self.concatfuse(
            conv1_2_16, upsample2, upsample4, upsample8, upsample16
        )

        return (upsample2, upsample4, upsample8, upsample16, concatfuse)
def hed(pretrained_backbone=True, progress=True):
    """Builds HED by adding backbone and head together.

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        If set to ``True``, then loads a pre-trained version of the backbone
        (not the head) for the HED network using VGG-16 trained for ImageNet
        classification.  In that case a ``TorchVisionNormalizer`` stage is
        also prepended to the model.

    progress : :py:class:`bool`, Optional
        If set to ``True``, and you decided to use a ``pretrained_backbone``,
        then, shows a progress bar of the backbone model downloading if
        download is necessary.


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for HED: a :py:class:`torch.nn.Sequential` with the
        stages ``normalizer`` (only when ``pretrained_backbone`` is set),
        ``backbone`` and ``head``, whose ``name`` attribute is ``"hed"``.
    """

    backbone = vgg16_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        # NOTE(review): presumably indices into the VGG-16 feature stack; in
        # torchvision's layout 14 is conv3_3 while the others are ReLUs --
        # confirm this is intended.
        return_features=[3, 8, 14, 22, 29],
    )
    # Channel counts of the five feature maps selected above.
    head = HED([64, 128, 256, 512, 512])

    order = [("backbone", backbone), ("head", head)]
    if pretrained_backbone:
        # Imported lazily: only needed when pretrained weights are used.
        from .normalizer import TorchVisionNormalizer

        order.insert(0, ("normalizer", TorchVisionNormalizer()))

    model = torch.nn.Sequential(OrderedDict(order))
    model.name = "hed"
    return model