Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4from collections import OrderedDict
6import torch
7import torch.nn
9from .backbones.vgg import vgg16_for_segmentation
10from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
class ConcatFuseBlock(torch.nn.Module):
    """
    Fuse five single-channel feature maps into a single-channel output.

    The five inputs are stacked along dimension 1 (the channel axis) and a
    1x1 convolution, built with ``conv_with_kaiming_uniform``, reduces the
    resulting 5-channel tensor to one output channel.
    """

    def __init__(self):
        super(ConcatFuseBlock, self).__init__()
        # 5 input channels (one per stacked map) -> 1 output channel.
        self.conv = conv_with_kaiming_uniform(5, 1, 1, 1, 0)

    def forward(self, x1, x2, x3, x4, x5):
        """
        Stack the five maps channel-wise and fuse them.

        Parameters
        ----------
        x1, x2, x3, x4, x5 : :py:class:`torch.Tensor`
            Single-channel maps; they must agree in every dimension other
            than dim 1 so that ``torch.cat`` can join them.

        Returns
        -------
        tensor : :py:class:`torch.Tensor`
            Output of the 1x1 fusing convolution.
        """
        stacked = torch.cat((x1, x2, x3, x4, x5), dim=1)
        return self.conv(stacked)
class HED(torch.nn.Module):
    """
    HED head module

    Turns the five feature maps produced by the backbone into four side
    outputs resized to the input image size, plus one fused output that is
    built from all five branches.

    Parameters
    ----------

    in_channels_list : :py:class:`list`, Optional
        Number of channels of each of the five feature maps returned from
        the backbone, in backbone order.  When ``None`` (the default), the
        VGG-16 channel counts ``[64, 128, 256, 512, 512]`` are used; these
        are the same values the ``hed()`` factory passes explicitly.

    Raises
    ------

    ValueError
        If ``in_channels_list`` does not hold exactly five entries.
    """

    # Channel counts of the five VGG-16 feature maps consumed by this head;
    # used when no explicit ``in_channels_list`` is given.
    _DEFAULT_IN_CHANNELS = (64, 128, 256, 512, 512)

    def __init__(self, in_channels_list=None):
        super().__init__()
        # Without this fallback the advertised default ``None`` would fail
        # with an opaque TypeError during the tuple-unpacking below.
        if in_channels_list is None:
            in_channels_list = self._DEFAULT_IN_CHANNELS
        in_channels_list = list(in_channels_list)
        if len(in_channels_list) != 5:
            raise ValueError(
                "HED expects 5 input channel counts (one per backbone "
                "feature map), got {}".format(len(in_channels_list))
            )
        (
            in_conv_1_2_16,
            in_upsample2,
            in_upsample_4,
            in_upsample_8,
            in_upsample_16,
        ) = in_channels_list

        # The first branch is a plain 3x3 conv (padding 1) with no resizing.
        # forward() passes it no target size, so its input is presumably
        # already at input resolution. TODO confirm against the backbone.
        self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 1, 3, 1, 1)
        # Upsample: each block maps its feature map to one channel and
        # receives the input ``hw`` in forward(), presumably to upsample and
        # crop back to that size (see UpsampleCropBlock).
        self.upsample2 = UpsampleCropBlock(in_upsample2, 1, 4, 2, 0)
        self.upsample4 = UpsampleCropBlock(in_upsample_4, 1, 8, 4, 0)
        self.upsample8 = UpsampleCropBlock(in_upsample_8, 1, 16, 8, 0)
        self.upsample16 = UpsampleCropBlock(in_upsample_16, 1, 32, 16, 0)
        # Concat and Fuse
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """
        Parameters
        ----------
        x : list
            list of tensors as returned from the backbone network.
            First element: height and width of input image.
            Remaining elements: feature maps for each feature level
            (only ``x[1]`` .. ``x[5]`` are used).

        Returns
        -------
        tensor : :py:class:`tuple` of :py:class:`torch.Tensor`
            Five tensors: the side outputs of the 2x, 4x, 8x and 16x
            upsampling branches, followed by the fused output (last).
        """
        hw = x[0]
        conv1_2_16 = self.conv1_2_16(x[1])
        upsample2 = self.upsample2(x[2], hw)
        upsample4 = self.upsample4(x[3], hw)
        upsample8 = self.upsample8(x[4], hw)
        upsample16 = self.upsample16(x[5], hw)
        concatfuse = self.concatfuse(
            conv1_2_16, upsample2, upsample4, upsample8, upsample16
        )

        # NOTE(review): conv1_2_16 feeds the fusion but is not returned as a
        # side output of its own. Callers/losses rely on this 5-tuple layout,
        # so it is kept as-is; confirm the omission is intentional.
        return (upsample2, upsample4, upsample8, upsample16, concatfuse)
def hed(pretrained_backbone=True, progress=True):
    """Builds HED by adding backbone and head together

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        If set to ``True``, then loads a pre-trained version of the backbone
        (not the head) for the HED network using VGG-16 trained for ImageNet
        classification.  In that case a ``TorchVisionNormalizer`` stage is
        also placed in front of the backbone.

    progress : :py:class:`bool`, Optional
        If set to ``True``, and you decided to use a ``pretrained_backbone``,
        then, shows a progress bar of the backbone model downloading if
        download is necessary.


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for HED: a :py:class:`torch.nn.Sequential` whose named
        stages are ``normalizer`` (only when ``pretrained_backbone`` is
        ``True``), ``backbone`` and ``head``, run in that order.  Its
        ``name`` attribute is set to ``"hed"``.

    """

    # return_features presumably selects the VGG-16 layers whose outputs are
    # handed to the head; the head's channel list must match those layers.
    # TODO confirm indices against vgg16_for_segmentation.
    backbone = vgg16_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        return_features=[3, 8, 14, 22, 29],
    )
    head = HED([64, 128, 256, 512, 512])

    order = [("backbone", backbone), ("head", head)]
    if pretrained_backbone:
        # Presumably the ImageNet-pretrained weights expect torchvision-style
        # input normalization, hence the extra stage only in this case.
        from .normalizer import TorchVisionNormalizer

        order = [("normalizer", TorchVisionNormalizer())] + order

    model = torch.nn.Sequential(OrderedDict(order))
    model.name = "hed"
    return model