Coverage for src/deepdraw/models/resunet.py: 77%

35 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> 

2# 

3# SPDX-License-Identifier: GPL-3.0-or-later 

4 

5from collections import OrderedDict 

6 

7import torch.nn 

8 

9from .backbones.resnet import resnet50_for_segmentation 

10from .make_layers import ( 

11 PixelShuffle_ICNR, 

12 UnetBlock, 

13 conv_with_kaiming_uniform, 

14 convtrans_with_kaiming_uniform, 

15) 

16 

17 

class ResUNet(torch.nn.Module):
    """UNet head module for ResNet backbones.

    The head chains four :py:class:`UnetBlock` decoders, starting from the
    last feature map returned by the backbone and walking back to the first
    one. It then applies one more upsampling step (``decode0``) and a final
    ``conv_with_kaiming_uniform`` that maps to a single channel.

    Parameters
    ----------

    in_channels_list : :py:class:`list`
        Number of channels of each of the five feature maps returned by the
        backbone, in the same order the backbone returns them (``x[1]`` to
        ``x[5]`` in :py:meth:`forward`). Exactly five values are required.
        The ``None`` default is kept only for signature compatibility; passing
        ``None`` raises :py:class:`TypeError`.

    pixel_shuffle : :py:class:`bool`, Optional
        Forwarded to every :py:class:`UnetBlock`. For the final upsampling
        step, ``True`` selects ``PixelShuffle_ICNR`` and ``False`` selects a
        transposed convolution (``convtrans_with_kaiming_uniform``).

    Raises
    ------

    TypeError
        If ``in_channels_list`` is ``None`` or not iterable.

    ValueError
        If ``in_channels_list`` does not contain exactly five values.
    """

    def __init__(self, in_channels_list=None, pixel_shuffle=False):
        super().__init__()

        # Fail early with an actionable message instead of an opaque
        # tuple-unpacking error (exception types match the original ones).
        if in_channels_list is None:
            raise TypeError(
                "ResUNet requires `in_channels_list` with five channel "
                "counts, got None"
            )
        channels = tuple(in_channels_list)
        if len(channels) != 5:
            raise ValueError(
                "`in_channels_list` must contain exactly 5 channel counts, "
                f"got {len(channels)}"
            )

        # number of channels of each backbone feature map
        c_decode1, c_decode2, c_decode3, c_decode4, c_decode5 = channels
        # number of channels for the last upsampling operation; presumably
        # the output width of ``decode1`` (see ``UnetBlock``) - TODO confirm
        c_decode0 = (c_decode1 + c_decode2 // 2) // 2

        # build layers: each UnetBlock merges a deeper map with a skip map
        self.decode4 = UnetBlock(c_decode5, c_decode4, pixel_shuffle)
        self.decode3 = UnetBlock(c_decode4, c_decode3, pixel_shuffle)
        self.decode2 = UnetBlock(c_decode3, c_decode2, pixel_shuffle)
        self.decode1 = UnetBlock(c_decode2, c_decode1, pixel_shuffle)
        if pixel_shuffle:
            self.decode0 = PixelShuffle_ICNR(c_decode0, c_decode0)
        else:
            self.decode0 = convtrans_with_kaiming_uniform(
                c_decode0, c_decode0, 2, 2
            )
        self.final = conv_with_kaiming_uniform(c_decode0, 1, 1)

    def forward(self, x):
        """Decode backbone feature maps into the single-channel head output.

        Parameters
        ----------
        x : list
            list of tensors as returned from the backbone network, with at
            least six elements.
            First element: height and width of input image.
            Remaining elements: feature maps for each feature level.

        Returns
        -------
        out
            Result of ``self.final`` applied to the last upsampled map.
        """
        # NOTE: x[0]: height and width of input image not needed in U-Net
        # architecture
        decode4 = self.decode4(x[5], x[4])
        decode3 = self.decode3(decode4, x[3])
        decode2 = self.decode2(decode3, x[2])
        decode1 = self.decode1(decode2, x[1])
        decode0 = self.decode0(decode1)
        out = self.final(decode0)
        return out

69 

70 

def resunet50(pretrained_backbone=True, progress=True):
    """Builds Residual-U-Net-50 by adding backbone and head together.

    The model is a :py:class:`torch.nn.Sequential` made of a ResNet-50
    backbone (``resnet50_for_segmentation``) followed by a :py:class:`ResUNet`
    head. When ``pretrained_backbone`` is ``True``, a ``TorchVisionNormalizer``
    is prepended so inputs are normalized before reaching the backbone.

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        If set to ``True``, then loads a pre-trained version of the ResNet-50
        backbone (not the head) via ``resnet50_for_segmentation`` and prepends
        a ``TorchVisionNormalizer`` to the model. Presumably these are
        ImageNet classification weights - confirm in the backbone module.

    progress : :py:class:`bool`, Optional
        If set to ``True``, and you decided to use a ``pretrained_backbone``,
        then, shows a progress bar of the backbone model downloading if
        download is necessary.


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for Residual U-Net 50: a ``torch.nn.Sequential`` with
        submodules ``normalizer`` (only if ``pretrained_backbone``),
        ``backbone`` and ``head``, and its ``name`` attribute set to
        ``"resunet50"``.
    """

    backbone = resnet50_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        return_features=[2, 4, 5, 6, 7],
    )
    # The head expects one channel count per returned feature map, in the
    # same order; these presumably are the standard ResNet-50 widths of the
    # layers selected by ``return_features`` above - keep both in sync.
    head = ResUNet([64, 256, 512, 1024, 2048], pixel_shuffle=False)

    order = [("backbone", backbone), ("head", head)]
    if pretrained_backbone:
        from .normalizer import TorchVisionNormalizer

        # pre-trained weights come with their own input normalization
        order = [("normalizer", TorchVisionNormalizer())] + order

    model = torch.nn.Sequential(OrderedDict(order))
    model.name = "resunet50"
    return model