Coverage for src/deepdraw/models/driu_bn.py: 75%

40 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> 

2# 

3# SPDX-License-Identifier: GPL-3.0-or-later 

4 

5from collections import OrderedDict 

6 

7import torch 

8import torch.nn 

9 

10from .backbones.vgg import vgg16_bn_for_segmentation 

11from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform 

12 

13 

class ConcatFuseBlock(torch.nn.Module):
    """Fuse four 16-channel feature maps into a single-channel map.

    The four inputs are stacked along the channel axis (64 channels in
    total) and passed through a 1x1 convolution producing one output
    channel, followed by a single-channel batch normalization.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): the positional arguments are presumably
        # (in_channels, out_channels, kernel_size, stride, padding); confirm
        # against make_layers.conv_with_kaiming_uniform.
        fuse_conv = conv_with_kaiming_uniform(4 * 16, 1, 1, 1, 0)
        # Keep the attribute name ``conv`` and the Sequential layout so that
        # state_dict keys ("conv.0.*", "conv.1.*") stay checkpoint-compatible.
        self.conv = torch.nn.Sequential(fuse_conv, torch.nn.BatchNorm2d(1))

    def forward(self, x1, x2, x3, x4):
        """Concatenate the four maps on dim 1 and apply the fusion stack."""
        stacked = torch.cat((x1, x2, x3, x4), dim=1)
        return self.conv(stacked)

29 

30 

class DRIUBN(torch.nn.Module):
    """DRIU with Batch-Normalization head module.

    Based on paper by [MANINIS-2016]_.

    The head brings four backbone feature levels to 16 channels each and
    fuses them into a single-channel map with :py:class:`ConcatFuseBlock`:

    * the first level goes through a 3x3 convolution (``conv1_2_16``);
    * the other three go through ``UpsampleCropBlock`` instances named
      ``upsample2``, ``upsample4`` and ``upsample8``.

    Parameters
    ----------
    in_channels_list : list
        Number of channels for each of the four feature maps returned from
        the backbone, in the order they appear in the backbone output.
        This argument is required; the ``None`` default exists only for
        signature compatibility.

    Raises
    ------
    TypeError
        If ``in_channels_list`` is ``None``.
    ValueError
        If ``in_channels_list`` does not contain exactly four entries.
    """

    def __init__(self, in_channels_list=None):
        super().__init__()
        if in_channels_list is None:
            raise TypeError(
                "DRIUBN requires `in_channels_list`: a sequence of four "
                "channel counts, one per backbone feature map"
            )
        channels = tuple(in_channels_list)
        if len(channels) != 4:
            raise ValueError(
                "DRIUBN expects exactly 4 entries in `in_channels_list`, "
                f"got {len(channels)}"
            )
        (
            in_conv_1_2_16,
            in_upsample_2,
            in_upsample_4,
            in_upsample_8,
        ) = channels

        # Attribute names and construction order are kept unchanged so
        # state_dict keys and weight-initialisation RNG usage stay the same.
        self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 16, 3, 1, 1)
        # Upsample layers
        # NOTE(review): trailing args presumably (kernel_size, stride,
        # padding) -- confirm against make_layers.UpsampleCropBlock.
        self.upsample2 = UpsampleCropBlock(in_upsample_2, 16, 4, 2, 0)
        self.upsample4 = UpsampleCropBlock(in_upsample_4, 16, 8, 4, 0)
        self.upsample8 = UpsampleCropBlock(in_upsample_8, 16, 16, 8, 0)

        # Concat and Fuse
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """
        Parameters
        ----------
        x : list
            list of tensors as returned from the backbone network.
            First element: height and width of input image (forwarded to each
            upsample block). Remaining four elements: feature maps for each
            feature level, consumed in order by ``conv1_2_16``,
            ``upsample2``, ``upsample4`` and ``upsample8``.

        Returns
        -------
        :py:class:`torch.Tensor`
            Output of :py:class:`ConcatFuseBlock` applied to the four
            16-channel maps.
        """
        hw = x[0]
        conv1_2_16 = self.conv1_2_16(x[1])  # conv1_2_16
        upsample2 = self.upsample2(x[2], hw)  # side-multi2-up
        upsample4 = self.upsample4(x[3], hw)  # side-multi3-up
        upsample8 = self.upsample8(x[4], hw)  # side-multi4-up
        return self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8)

80 

81 

def driu_bn(pretrained_backbone=True, progress=True):
    """Builds DRIU with batch-normalization by adding backbone and head
    together.

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        If set to ``True``, then loads a pre-trained version of the backbone
        (not the head) for the DRIU network using VGG-16 trained for ImageNet
        classification, and prepends a ``TorchVisionNormalizer`` module to
        the network.

    progress : :py:class:`bool`, Optional
        If set to ``True``, and you decided to use a ``pretrained_backbone``,
        then, shows a progress bar of the backbone model downloading if
        download is necessary.


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for DRIU (vessel segmentation) using batch
        normalization: a :py:class:`torch.nn.Sequential` whose children are
        ``normalizer`` (only when ``pretrained_backbone`` is ``True``),
        ``backbone`` and ``head``, with its ``name`` attribute set to
        ``"driu-bn"``.
    """

    # Forward both flags so the documented pre-trained weights are actually
    # loaded and ``progress`` is honoured (they were previously ignored by a
    # hard-coded ``pretrained=False``).
    backbone = vgg16_bn_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        return_features=[5, 12, 19, 29],
    )
    # One channel count per feature level selected by ``return_features``.
    head = DRIUBN([64, 128, 256, 512])

    order = [("backbone", backbone), ("head", head)]
    if pretrained_backbone:
        from .normalizer import TorchVisionNormalizer

        order = [("normalizer", TorchVisionNormalizer())] + order

    model = torch.nn.Sequential(OrderedDict(order))
    model.name = "driu-bn"
    return model