Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4from collections import OrderedDict 

5 

6import torch 

7import torch.nn 

8 

9from .backbones.vgg import vgg16_for_segmentation 

10from .driu import ConcatFuseBlock 

11from .make_layers import UpsampleCropBlock 

12 

13 

class DRIUPIX(torch.nn.Module):
    """
    DRIUPIX head module. DRIU with pixelshuffle instead of ConvTrans2D

    Parameters
    ----------
    in_channels_list : list
        Number of channels for each of the four feature maps returned by the
        backbone, in order: the map fed to the ``conv1_2_16`` convolution,
        then the maps consumed by ``upsample2``, ``upsample4`` and
        ``upsample8``. The signature defaults to ``None`` only for backward
        compatibility; a value is required.

    Raises
    ------
    TypeError
        If ``in_channels_list`` is ``None`` (or not iterable).
    ValueError
        If ``in_channels_list`` does not hold exactly four entries.
    """

    def __init__(self, in_channels_list=None):
        super().__init__()
        if in_channels_list is None:
            # Without channel counts there is nothing to size the layers with;
            # fail with a message that names the missing argument.
            raise TypeError(
                "DRIUPIX requires in_channels_list: one channel count per "
                "backbone feature map (4 values), got None"
            )
        channels = list(in_channels_list)
        if len(channels) != 4:
            raise ValueError(
                "in_channels_list must contain exactly 4 channel counts, "
                "got {}".format(len(channels))
            )
        in_conv_1_2_16, in_upsample_2, in_upsample_4, in_upsample_8 = channels

        # 3x3 convolution (stride 1, padding 1) down to 16 channels.
        self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 16, 3, 1, 1)
        # Upsample layers. The trailing positional arguments are passed
        # through to UpsampleCropBlock unchanged (presumably kernel size,
        # upsampling factor and padding -- confirm against make_layers).
        self.upsample2 = UpsampleCropBlock(
            in_upsample_2, 16, 4, 2, 0, pixelshuffle=True
        )
        self.upsample4 = UpsampleCropBlock(
            in_upsample_4, 16, 8, 4, 0, pixelshuffle=True
        )
        self.upsample8 = UpsampleCropBlock(
            in_upsample_8, 16, 16, 8, 0, pixelshuffle=True
        )

        # Concat and Fuse
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """
        Fuse the backbone side outputs into the head output.

        Parameters
        ----------
        x : list
            List of tensors as returned from the backbone network.
            First element: height and width of input image.
            Remaining elements: feature maps for each feature level
            (four are consumed here).

        Returns
        -------
        :py:class:`torch.Tensor`
            Result of ``self.concatfuse`` applied to the four side outputs.
        """
        hw = x[0]
        conv1_2_16 = self.conv1_2_16(x[1])  # conv1_2_16
        # hw is handed to each upsampling block, presumably so it can crop
        # its output back to the input size -- confirm in UpsampleCropBlock.
        upsample2 = self.upsample2(x[2], hw)  # side-multi2-up
        upsample4 = self.upsample4(x[3], hw)  # side-multi3-up
        upsample8 = self.upsample8(x[4], hw)  # side-multi4-up
        return self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8)

68 

69 

def driu_pix(pretrained_backbone=True, progress=True):
    """Assemble the DRIU network with pixel-shuffle upsampling.

    The returned model chains a VGG-16 backbone with a :py:class:`DRIUPIX`
    head, optionally preceded by an input normalizer.

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        When ``True``, the backbone (not the head) is loaded as a VGG-16
        pre-trained for ImageNet classification, and a
        ``TorchVisionNormalizer`` stage is placed in front of it.

    progress : :py:class:`bool`, Optional
        When ``True`` and ``pretrained_backbone`` is set, show a progress bar
        while the backbone model is downloaded, if a download is necessary.


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for DRIU (vessel segmentation) with pixelshuffle. Its
        ``name`` attribute is set to ``"driu-pix"``.

    """
    stages = OrderedDict()
    # The head's channel list presumably matches the backbone taps requested
    # here -- keep the two lists in sync if either changes.
    stages["backbone"] = vgg16_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        return_features=[3, 8, 14, 22],
    )
    stages["head"] = DRIUPIX([64, 128, 256, 512])

    if pretrained_backbone:
        # Imported lazily: only needed when a pre-trained backbone is used.
        from .normalizer import TorchVisionNormalizer

        # Constructed last (after backbone and head) but it must run first,
        # so it is moved to the front of the pipeline.
        stages["normalizer"] = TorchVisionNormalizer()
        stages.move_to_end("normalizer", last=False)

    model = torch.nn.Sequential(stages)
    model.name = "driu-pix"
    return model