Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4from collections import OrderedDict 

5 

6import torch 

7import torch.nn 

8 

9from .backbones.vgg import vgg16_for_segmentation 

10from .driu import ConcatFuseBlock 

11from .make_layers import UpsampleCropBlock 

12 

13 

class DRIUOD(torch.nn.Module):
    """
    DRIU for optic disc segmentation head module

    Parameters
    ----------
    in_channels_list : list
        number of channels for each feature map that is returned from
        backbone.  Exactly four entries are required, one per upsampling
        branch, in the order consumed by ``upsample2``, ``upsample4``,
        ``upsample8`` and ``upsample16``.  The ``None`` default is kept for
        signature compatibility only and is rejected.

    Raises
    ------
    TypeError
        if ``in_channels_list`` is ``None`` or is not iterable
    ValueError
        if ``in_channels_list`` does not contain exactly four entries
    """

    def __init__(self, in_channels_list=None):
        super(DRIUOD, self).__init__()

        # Fail early with a message that names the offending parameter,
        # instead of the generic tuple-unpacking error.
        if in_channels_list is None:
            raise TypeError(
                "in_channels_list is required: expected an iterable with 4 "
                "channel counts, got None"
            )
        channels = tuple(in_channels_list)
        if len(channels) != 4:
            raise ValueError(
                "in_channels_list must contain exactly 4 channel counts, "
                "got %d" % len(channels)
            )
        in_upsample_2, in_upsample_4, in_upsample_8, in_upsample_16 = channels

        # Upsample layers (one per backbone feature level)
        self.upsample2 = UpsampleCropBlock(in_upsample_2, 16, 4, 2, 0)
        self.upsample4 = UpsampleCropBlock(in_upsample_4, 16, 8, 4, 0)
        self.upsample8 = UpsampleCropBlock(in_upsample_8, 16, 16, 8, 0)
        self.upsample16 = UpsampleCropBlock(in_upsample_16, 16, 32, 16, 0)

        # Concat and Fuse
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """
        Parameters
        ----------
        x : list
            list of tensors as returned from the backbone network.
            First element: height and width of input image.
            Remaining elements: feature maps for each feature level.

        Returns
        -------
        :py:class:`torch.Tensor`
            result of fusing the four upsampled feature maps
        """
        hw = x[0]
        upsample2 = self.upsample2(x[1], hw)  # side-multi2-up
        upsample4 = self.upsample4(x[2], hw)  # side-multi3-up
        upsample8 = self.upsample8(x[3], hw)  # side-multi4-up
        upsample16 = self.upsample16(x[4], hw)  # side-multi5-up
        out = self.concatfuse(upsample2, upsample4, upsample8, upsample16)
        return out

62 

63 

def driu_od(pretrained_backbone=True, progress=True):
    """Builds DRIU for optic disc segmentation by chaining backbone and head

    The returned model is a :py:class:`torch.nn.Sequential` whose stages are
    named ``backbone`` and ``head``.  When ``pretrained_backbone`` is set, a
    ``normalizer`` stage is placed in front of them.

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        If set to ``True``, then loads a pre-trained version of the backbone
        (not the head) for the DRIU network using VGG-16 trained for ImageNet
        classification.

    progress : :py:class:`bool`, Optional
        If set to ``True``, and you decided to use a ``pretrained_backbone``,
        then, shows a progress bar of the backbone model downloading if
        download is necessary.


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for DRIU (optic disc segmentation)

    """

    # Build the two mandatory stages first, backbone before head.
    feature_extractor = vgg16_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        return_features=[8, 14, 22, 29],
    )
    segmentation_head = DRIUOD([128, 256, 512, 512])

    # Insertion order of the OrderedDict defines the execution order.
    stages = OrderedDict()
    if pretrained_backbone:
        from .normalizer import TorchVisionNormalizer

        stages["normalizer"] = TorchVisionNormalizer()
    stages["backbone"] = feature_extractor
    stages["head"] = segmentation_head

    model = torch.nn.Sequential(stages)
    model.name = "driu-od"
    return model