
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from collections import OrderedDict

import torch
import torch.nn

from .backbones.vgg import vgg16_for_segmentation
from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform


class ConcatFuseBlock(torch.nn.Module):
    """
    Takes in four feature maps with 16 channels each, concatenates them
    and applies a 1x1 convolution with 1 output channel.
    """

    def __init__(self):
        super().__init__()
        self.conv = conv_with_kaiming_uniform(4 * 16, 1, 1, 1, 0)

    def forward(self, x1, x2, x3, x4):
        x_cat = torch.cat([x1, x2, x3, x4], dim=1)
        x = self.conv(x_cat)
        return x

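
# Shape sketch for ``ConcatFuseBlock`` (illustrative only; the batch size and
# spatial resolution below are arbitrary assumptions, not values required by
# this module):
#
#   block = ConcatFuseBlock()
#   maps = [torch.rand(1, 16, 544, 544) for _ in range(4)]  # four 16-channel maps
#   fused = block(*maps)  # concatenated to 64 channels, then reduced to 1 channel
#   assert fused.shape == (1, 1, 544, 544)
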

class DRIU(torch.nn.Module):
    """
    DRIU head module

    Based on the paper by [MANINIS-2016]_.

    Parameters
    ----------
    in_channels_list : list
        number of channels for each feature map returned by the backbone

    """

    def __init__(self, in_channels_list=None):
        super().__init__()
        (
            in_conv_1_2_16,
            in_upsample_2,
            in_upsample_4,
            in_upsample_8,
        ) = in_channels_list

        self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 16, 3, 1, 1)
        # Upsample layers
        self.upsample2 = UpsampleCropBlock(in_upsample_2, 16, 4, 2, 0)
        self.upsample4 = UpsampleCropBlock(in_upsample_4, 16, 8, 4, 0)
        self.upsample8 = UpsampleCropBlock(in_upsample_8, 16, 16, 8, 0)

        # Concat and Fuse
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """

        Parameters
        ----------

        x : list
            list of tensors as returned from the backbone network.  First
            element: height and width of the input image.  Remaining
            elements: feature maps for each feature level.

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            single-channel output produced by fusing the side feature maps

        """
        hw = x[0]
        conv1_2_16 = self.conv1_2_16(x[1])  # conv1_2_16
        upsample2 = self.upsample2(x[2], hw)  # side-multi2-up
        upsample4 = self.upsample4(x[3], hw)  # side-multi3-up
        upsample8 = self.upsample8(x[4], hw)  # side-multi4-up
        out = self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8)
        return out

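
# Sketch of driving the head directly with dummy feature maps (illustrative
# only).  The channel counts match ``DRIU([64, 128, 256, 512])`` used below;
# the spatial strides (1, 2, 4, 8), the batch/image sizes and passing the
# input height/width as the first list element are assumptions about what the
# backbone provides:
#
#   head = DRIU([64, 128, 256, 512])
#   hw = (544, 544)
#   feats = [
#       torch.rand(1, 64, 544, 544),   # stride-1 level
#       torch.rand(1, 128, 272, 272),  # stride-2 level
#       torch.rand(1, 256, 136, 136),  # stride-4 level
#       torch.rand(1, 512, 68, 68),    # stride-8 level
#   ]
#   out = head([hw] + feats)  # expected shape: (1, 1, 544, 544)
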

def driu(pretrained_backbone=True, progress=True):
    """Builds DRIU for vessel segmentation by adding backbone and head together

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        If set to ``True``, then loads a pre-trained version of the backbone
        (not the head) for the DRIU network, using VGG-16 trained for ImageNet
        classification.

    progress : :py:class:`bool`, Optional
        If set to ``True``, and you decided to use a ``pretrained_backbone``,
        then shows a progress bar while downloading the backbone model, if a
        download is necessary.

    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for DRIU (vessel segmentation)

    """

    backbone = vgg16_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        # four VGG-16 feature levels, with 64, 128, 256 and 512 channels
        return_features=[3, 8, 14, 22],
    )
    head = DRIU([64, 128, 256, 512])

    order = [("backbone", backbone), ("head", head)]
    if pretrained_backbone:
        from .normalizer import TorchVisionNormalizer

        # when the pre-trained (torchvision) backbone is used, prepend an
        # input normalizer before it
        order = [("normalizer", TorchVisionNormalizer())] + order

    model = torch.nn.Sequential(OrderedDict(order))
    model.name = "driu"
    return model
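

# Minimal smoke-test sketch (assumes this module is executed inside its
# package, e.g. with ``python -m``, so that the relative imports above
# resolve).  The input size is an arbitrary example, and
# ``pretrained_backbone=False`` is used so no backbone weights are downloaded.
if __name__ == "__main__":
    model = driu(pretrained_backbone=False, progress=False)
    image = torch.rand(1, 3, 544, 544)  # dummy 3-channel image batch
    with torch.no_grad():
        prediction = model(image)
    # expected: a single-channel map at the input resolution, e.g. (1, 1, 544, 544)
    print(prediction.shape)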