Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4from collections import OrderedDict 

5 

6import torch 

7import torch.nn 

8 

9from .backbones.vgg import vgg16_for_segmentation 

10from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform 

11 

12 

class ConcatFuseBlock(torch.nn.Module):
    """
    Fuses five single-channel feature maps into one.

    The five inputs are concatenated along the channel dimension and
    reduced back to a single channel by a 1x1 convolution.
    """

    def __init__(self):
        super().__init__()
        # 5 input channels -> 1 output channel, 1x1 kernel, stride 1, no padding
        self.conv = conv_with_kaiming_uniform(5, 1, 1, 1, 0)

    def forward(self, x1, x2, x3, x4, x5):
        stacked = torch.cat((x1, x2, x3, x4, x5), dim=1)
        return self.conv(stacked)

27 

28 

class HED(torch.nn.Module):
    """
    HED head module

    Builds one side-output branch per backbone feature level (a plain 3x3
    convolution for the finest level, learned-upsampling blocks for the
    coarser ones) and fuses all five single-channel maps with a 1x1
    convolution.

    Parameters
    ----------
    in_channels_list : list
        number of channels for each of the five feature maps that are
        returned from the backbone, ordered fine to coarse

    Raises
    ------
    ValueError
        if ``in_channels_list`` is not provided (the default ``None`` is
        only a placeholder; five channel counts are required)
    """

    def __init__(self, in_channels_list=None):
        super().__init__()
        # The original code unpacked None directly, which crashed with an
        # opaque TypeError; fail early with a clear message instead.
        if in_channels_list is None:
            raise ValueError(
                "in_channels_list is required, e.g. [64, 128, 256, 512, 512]"
            )
        (
            in_conv_1_2_16,
            in_upsample_2,
            in_upsample_4,
            in_upsample_8,
            in_upsample_16,
        ) = in_channels_list

        # Finest level: 3x3 conv, stride 1, padding 1 -> one channel.
        self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 1, 3, 1, 1)
        # Upsample blocks for the coarser levels.  Arguments are presumably
        # (in_channels, out_channels, kernel, stride, padding) with
        # kernel = 2 * stride for factor-2/4/8/16 upsampling — confirm
        # against make_layers.UpsampleCropBlock.
        self.upsample2 = UpsampleCropBlock(in_upsample_2, 1, 4, 2, 0)
        self.upsample4 = UpsampleCropBlock(in_upsample_4, 1, 8, 4, 0)
        self.upsample8 = UpsampleCropBlock(in_upsample_8, 1, 16, 8, 0)
        self.upsample16 = UpsampleCropBlock(in_upsample_16, 1, 32, 16, 0)
        # Concat and fuse the five single-channel maps into one.
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """
        Parameters
        ----------
        x : list
            list of tensors as returned from the backbone network.
            First element: height and width of input image.
            Remaining elements: feature maps for each feature level.

        Returns
        -------
        tuple
            the four upsampled side outputs followed by the fused map,
            each a :py:class:`torch.Tensor`
        """
        hw = x[0]
        conv1_2_16 = self.conv1_2_16(x[1])
        # hw is passed so each upsampled map can be cropped back to the
        # input image's spatial size.
        upsample2 = self.upsample2(x[2], hw)
        upsample4 = self.upsample4(x[3], hw)
        upsample8 = self.upsample8(x[4], hw)
        upsample16 = self.upsample16(x[5], hw)
        concatfuse = self.concatfuse(
            conv1_2_16, upsample2, upsample4, upsample8, upsample16
        )

        return (upsample2, upsample4, upsample8, upsample16, concatfuse)

82 

83 

def hed(pretrained_backbone=True, progress=True):
    """Builds HED by adding backbone and head together

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        If set to ``True``, then loads a pre-trained version of the backbone
        (not the head) for the DRIU network using VGG-16 trained for ImageNet
        classification.

    progress : :py:class:`bool`, Optional
        If set to ``True``, and you decided to use a ``pretrained_backbone``,
        then, shows a progress bar of the backbone model downloading if
        download is necesssary.


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for HED

    """

    backbone = vgg16_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        return_features=[3, 8, 14, 22, 29],
    )
    head = HED([64, 128, 256, 512, 512])

    stages = [("backbone", backbone), ("head", head)]
    if pretrained_backbone:
        from .normalizer import TorchVisionNormalizer

        # Pretrained VGG weights expect torchvision-normalized inputs,
        # so a normalizer stage is prepended in that case only.
        stages.insert(0, ("normalizer", TorchVisionNormalizer()))

    model = torch.nn.Sequential(OrderedDict(stages))
    model.name = "hed"
    return model