Coverage for src/deepdraw/models/hed.py: 74%

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

from collections import OrderedDict

import torch
import torch.nn

from .backbones.vgg import vgg16_for_segmentation
from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform


class ConcatFuseBlock(torch.nn.Module):
    """Takes in five feature maps with one channel each, concatenates them
    and applies a 1x1 convolution with 1 output channel."""

    def __init__(self):
        super().__init__()
        # 1x1 convolution fusing the five single-channel maps into one
        self.conv = conv_with_kaiming_uniform(5, 1, 1, 1, 0)

    def forward(self, x1, x2, x3, x4, x5):
        x_cat = torch.cat([x1, x2, x3, x4, x5], dim=1)
        x = self.conv(x_cat)
        return x
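
# A minimal usage sketch for ``ConcatFuseBlock`` (illustrative, not part of
# the original module).  It assumes five single-channel maps that already
# share the same spatial size; the block then returns one fused map of that
# size:
#
#     >>> import torch
#     >>> fuse = ConcatFuseBlock()
#     >>> maps = [torch.rand(1, 1, 64, 64) for _ in range(5)]
#     >>> fuse(*maps).shape
#     torch.Size([1, 1, 64, 64])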

class HED(torch.nn.Module):
    """HED head module.

    Parameters
    ----------
    in_channels_list : list
        number of channels for each feature map that is returned from the
        backbone
    """

    def __init__(self, in_channels_list=None):
        super().__init__()
        (
            in_conv_1_2_16,
            in_upsample2,
            in_upsample_4,
            in_upsample_8,
            in_upsample_16,
        ) = in_channels_list

        self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 1, 3, 1, 1)
        # Upsample
        self.upsample2 = UpsampleCropBlock(in_upsample2, 1, 4, 2, 0)
        self.upsample4 = UpsampleCropBlock(in_upsample_4, 1, 8, 4, 0)
        self.upsample8 = UpsampleCropBlock(in_upsample_8, 1, 16, 8, 0)
        self.upsample16 = UpsampleCropBlock(in_upsample_16, 1, 32, 16, 0)
        # Concat and Fuse
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """
        Parameters
        ----------
        x : list
            list of tensors as returned from the backbone network.
            First element: height and width of the input image.
            Remaining elements: feature maps for each feature level.

        Returns
        -------
        tuple : tuple of :py:class:`torch.Tensor`
            the four upsampled side outputs followed by the fused output
        """
        hw = x[0]
        # single-channel side output from the full-resolution feature map
        conv1_2_16 = self.conv1_2_16(x[1])
        # remaining side outputs, upsampled and cropped to the input size
        upsample2 = self.upsample2(x[2], hw)
        upsample4 = self.upsample4(x[3], hw)
        upsample8 = self.upsample8(x[4], hw)
        upsample16 = self.upsample16(x[5], hw)
        # fuse all five side outputs into a single prediction map
        concatfuse = self.concatfuse(
            conv1_2_16, upsample2, upsample4, upsample8, upsample16
        )

        return (upsample2, upsample4, upsample8, upsample16, concatfuse)
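
# A minimal sketch of how this head is fed (illustrative, not part of the
# original module).  The shapes assume a 64x64 input and VGG-16 feature maps
# at strides 1, 2, 4, 8 and 16 with 64, 128, 256, 512 and 512 channels,
# matching the channel list used by :py:func:`hed` below; the first list
# element carries the input's height and width:
#
#     >>> import torch
#     >>> head = HED([64, 128, 256, 512, 512])
#     >>> feats = [
#     ...     torch.rand(1, 64, 64, 64),
#     ...     torch.rand(1, 128, 32, 32),
#     ...     torch.rand(1, 256, 16, 16),
#     ...     torch.rand(1, 512, 8, 8),
#     ...     torch.rand(1, 512, 4, 4),
#     ... ]
#     >>> outputs = head([(64, 64)] + feats)
#     >>> len(outputs)
#     5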

def hed(pretrained_backbone=True, progress=True):
    """Builds HED by adding backbone and head together.

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        If set to ``True``, then loads a pre-trained version of the backbone
        (not the head) for the HED network using VGG-16 trained for ImageNet
        classification.

    progress : :py:class:`bool`, Optional
        If set to ``True``, and you decided to use a ``pretrained_backbone``,
        then shows a progress bar for the backbone model download, if a
        download is necessary.


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for HED
    """

    backbone = vgg16_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        return_features=[3, 8, 14, 22, 29],
    )
    head = HED([64, 128, 256, 512, 512])

    order = [("backbone", backbone), ("head", head)]
    if pretrained_backbone:
        from .normalizer import TorchVisionNormalizer

        order = [("normalizer", TorchVisionNormalizer())] + order

    model = torch.nn.Sequential(OrderedDict(order))
    model.name = "hed"
    return model
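

# A small smoke test (illustrative, not part of the original module): it
# builds the model without the pre-trained backbone, so nothing is
# downloaded, and pushes a random image through it.  The 128x128 input size
# is an assumption made only for this example.  Because of the relative
# imports above, run it as ``python -m deepdraw.models.hed``.
if __name__ == "__main__":
    model = hed(pretrained_backbone=False, progress=False)
    image = torch.rand(1, 3, 128, 128)  # (batch, channels, height, width)
    with torch.no_grad():
        outputs = model(image)
    # four upsampled side outputs followed by the fused prediction
    names = ["upsample2", "upsample4", "upsample8", "upsample16", "fused"]
    for name, out in zip(names, outputs):
        print(name, tuple(out.shape))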