Coverage for src/deepdraw/models/unet.py: 80% (30 statements)


# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

from collections import OrderedDict

import torch.nn

from .backbones.vgg import vgg16_for_segmentation
from .make_layers import UnetBlock, conv_with_kaiming_uniform


class UNet(torch.nn.Module):
    """UNet head module.

    Parameters
    ----------
    in_channels_list : list
        Number of channels for each feature map returned by the backbone.
    """

    def __init__(self, in_channels_list=None, pixel_shuffle=False):
        super().__init__()
        # number of channels per decoder level, finest (1) to coarsest (5)
        c_decode1, c_decode2, c_decode3, c_decode4, c_decode5 = in_channels_list

        # build decoder layers, from the coarsest level down to the finest
        self.decode4 = UnetBlock(
            c_decode5, c_decode4, pixel_shuffle, middle_block=True
        )
        self.decode3 = UnetBlock(c_decode4, c_decode3, pixel_shuffle)
        self.decode2 = UnetBlock(c_decode3, c_decode2, pixel_shuffle)
        self.decode1 = UnetBlock(c_decode2, c_decode1, pixel_shuffle)
        self.final = conv_with_kaiming_uniform(c_decode1, 1, 1)

    def forward(self, x):
        """
        Parameters
        ----------
        x : list
            List of tensors as returned from the backbone network.
            First element: height and width of the input image.
            Remaining elements: feature maps for each feature level.
        """
        # NOTE: x[0] (height and width of the input image) is not needed in
        # the U-Net architecture
        decode4 = self.decode4(x[5], x[4])
        decode3 = self.decode3(decode4, x[3])
        decode2 = self.decode2(decode3, x[2])
        decode1 = self.decode1(decode2, x[1])
        out = self.final(decode1)
        return out
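

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the head consumes
# the list produced by the segmentation backbone, where x[0] carries the
# input image's height and width (unused by this head) and x[1:] are feature
# maps ordered from finest to coarsest resolution.  The shapes below are
# assumptions for illustration and presume each UnetBlock upsamples its
# coarse input by a factor of two before fusing it with the skip connection.
#
#   head = UNet([64, 128, 256, 512, 512])
#   x = [
#       (64, 64),                       # input image height and width
#       torch.rand(1, 64, 64, 64),      # finest feature map (full resolution)
#       torch.rand(1, 128, 32, 32),
#       torch.rand(1, 256, 16, 16),
#       torch.rand(1, 512, 8, 8),
#       torch.rand(1, 512, 4, 4),       # coarsest feature map
#   ]
#   out = head(x)                       # single-channel map, finest resolution
# ---------------------------------------------------------------------------
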

def unet(pretrained_backbone=True, progress=True):
    """Builds the U-Net segmentation network by adding backbone and head together.

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        If set to ``True``, then loads a pre-trained version of the backbone
        (not the head) for the U-Net network, using VGG-16 trained for
        ImageNet classification.

    progress : :py:class:`bool`, Optional
        If set to ``True`` and a ``pretrained_backbone`` is used, shows a
        progress bar while downloading the backbone model, if a download is
        necessary.


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for U-Net
    """

    backbone = vgg16_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        return_features=[3, 8, 14, 22, 29],
    )
    head = UNet([64, 128, 256, 512, 512], pixel_shuffle=False)

    order = [("backbone", backbone), ("head", head)]
    if pretrained_backbone:
        from .normalizer import TorchVisionNormalizer

        order = [("normalizer", TorchVisionNormalizer())] + order

    model = torch.nn.Sequential(OrderedDict(order))
    model.name = "unet"
    return model
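

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module).  The input size is
# illustrative only and assumed to be divisible by 16, so that decoder and
# skip-connection feature maps line up after upsampling.
#
#   import torch
#   from deepdraw.models.unet import unet
#
#   model = unet(pretrained_backbone=False)  # avoids the ImageNet download
#   images = torch.rand(1, 3, 544, 544)      # NCHW batch of RGB images
#   logits = model(images)                   # single-channel segmentation logits
# ---------------------------------------------------------------------------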