Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1637571489937/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.8/site-packages/bob/med/tb/models/pasa.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

74 statements  

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4import torch 

5import torch.nn as nn 

6import torch.nn.functional as F 

7from collections import OrderedDict 

8from .normalizer import TorchVisionNormalizer 

9 

class PASA(nn.Module):
    """
    PASA module

    Based on paper by [PASA-2019]_.

    The network chains five convolution blocks.  Every block runs two 3x3
    convolutions (each followed by batch normalisation and ReLU) and, in
    parallel, a 1x1 convolution (also followed by batch normalisation and
    ReLU) applied to the block input; the two branches are averaged.  The
    first four blocks are followed by max pooling.  The result is averaged
    over the spatial dimensions and fed to a single-output linear layer.

    """

    def __init__(self):
        super().__init__()

        # NOTE: attribute names and creation order are kept stable on
        # purpose: they define the ``state_dict`` keys and the order in
        # which parameters are initialised.

        # First convolution block
        self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
        self.fc2 = nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
        # Parallel branch: stride 4 matches the two stride-2 convolutions
        self.fc3 = nn.Conv2d(1, 16, (1, 1), (4, 4))

        self.batchNorm2d_4 = nn.BatchNorm2d(4)
        self.batchNorm2d_16 = nn.BatchNorm2d(16)
        self.batchNorm2d_16_2 = nn.BatchNorm2d(16)

        # Second convolution block
        self.fc4 = nn.Conv2d(16, 24, (3, 3), (1, 1), (1, 1))
        self.fc5 = nn.Conv2d(24, 32, (3, 3), (1, 1), (1, 1))
        self.fc6 = nn.Conv2d(16, 32, (1, 1), (1, 1))  # Original stride (2, 2)

        self.batchNorm2d_24 = nn.BatchNorm2d(24)
        self.batchNorm2d_32 = nn.BatchNorm2d(32)
        self.batchNorm2d_32_2 = nn.BatchNorm2d(32)

        # Third convolution block
        self.fc7 = nn.Conv2d(32, 40, (3, 3), (1, 1), (1, 1))
        self.fc8 = nn.Conv2d(40, 48, (3, 3), (1, 1), (1, 1))
        self.fc9 = nn.Conv2d(32, 48, (1, 1), (1, 1))  # Original stride (2, 2)

        self.batchNorm2d_40 = nn.BatchNorm2d(40)
        self.batchNorm2d_48 = nn.BatchNorm2d(48)
        self.batchNorm2d_48_2 = nn.BatchNorm2d(48)

        # Fourth convolution block
        self.fc10 = nn.Conv2d(48, 56, (3, 3), (1, 1), (1, 1))
        self.fc11 = nn.Conv2d(56, 64, (3, 3), (1, 1), (1, 1))
        self.fc12 = nn.Conv2d(48, 64, (1, 1), (1, 1))  # Original stride (2, 2)

        self.batchNorm2d_56 = nn.BatchNorm2d(56)
        self.batchNorm2d_64 = nn.BatchNorm2d(64)
        self.batchNorm2d_64_2 = nn.BatchNorm2d(64)

        # Fifth convolution block
        self.fc13 = nn.Conv2d(64, 72, (3, 3), (1, 1), (1, 1))
        self.fc14 = nn.Conv2d(72, 80, (3, 3), (1, 1), (1, 1))
        self.fc15 = nn.Conv2d(64, 80, (1, 1), (1, 1))  # Original stride (2, 2)

        self.batchNorm2d_72 = nn.BatchNorm2d(72)
        self.batchNorm2d_80 = nn.BatchNorm2d(80)
        self.batchNorm2d_80_2 = nn.BatchNorm2d(80)

        self.pool2d = nn.MaxPool2d((3, 3), (2, 2))  # Pool after conv. block
        self.dense = nn.Linear(80, 1)  # Fully connected layer

    @staticmethod
    def _parallel_block(x, conv_a, norm_a, conv_b, norm_b, conv_p, norm_p):
        """
        Apply one convolution block: two chained convolutions averaged with
        a parallel convolution of the block input.

        Parameters
        ----------

        x : :py:class:`torch.Tensor`
            input of the block

        conv_a, norm_a : :py:class:`torch.nn.Module`
            1st convolution of the main branch and its batch normalisation

        conv_b, norm_b : :py:class:`torch.nn.Module`
            2nd convolution of the main branch and its batch normalisation

        conv_p, norm_p : :py:class:`torch.nn.Module`
            convolution of the parallel branch and its batch normalisation

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            element-wise mean of the main and the parallel branch

        """

        main = F.relu(norm_a(conv_a(x)))  # 1st convolution
        main = F.relu(norm_b(conv_b(main)))  # 2nd convolution
        parallel = F.relu(norm_p(conv_p(x)))  # Parallel
        return (main + parallel) / 2

    def forward(self, x):
        """

        Parameters
        ----------

        x : :py:class:`torch.Tensor`
            batch of single-channel images, shaped
            ``(batch, 1, height, width)``.  The image must be large enough
            to survive the four 3x3/stride-2 pooling steps.

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            raw output of the final linear layer, shaped ``(batch, 1)``; no
            activation is applied to it

        """

        # First convolution block
        x = self._parallel_block(
            x,
            self.fc1, self.batchNorm2d_4,
            self.fc2, self.batchNorm2d_16,
            self.fc3, self.batchNorm2d_16_2,
        )
        x = self.pool2d(x)  # Pooling

        # Second convolution block
        x = self._parallel_block(
            x,
            self.fc4, self.batchNorm2d_24,
            self.fc5, self.batchNorm2d_32,
            self.fc6, self.batchNorm2d_32_2,
        )
        x = self.pool2d(x)  # Pooling

        # Third convolution block
        x = self._parallel_block(
            x,
            self.fc7, self.batchNorm2d_40,
            self.fc8, self.batchNorm2d_48,
            self.fc9, self.batchNorm2d_48_2,
        )
        x = self.pool2d(x)  # Pooling

        # Fourth convolution block
        x = self._parallel_block(
            x,
            self.fc10, self.batchNorm2d_56,
            self.fc11, self.batchNorm2d_64,
            self.fc12, self.batchNorm2d_64_2,
        )
        x = self.pool2d(x)  # Pooling

        # Fifth convolution block
        x = self._parallel_block(
            x,
            self.fc13, self.batchNorm2d_72,
            self.fc14, self.batchNorm2d_80,
            self.fc15, self.batchNorm2d_80_2,
        )
        # no pooling

        # Global average pooling
        x = torch.mean(x.view(x.size(0), x.size(1), -1), dim=2)

        # Dense layer; its output is returned as is (no softmax/sigmoid)
        x = self.dense(x)

        return x

127 

def build_pasa():
    """
    Build pasa CNN

    Chains a single-channel :py:class:`TorchVisionNormalizer` (registered as
    ``"normalizer"``) and a :py:class:`PASA` network (registered as
    ``"model"``) into one sequential module whose ``name`` attribute is set
    to ``"pasa"``.

    Returns
    -------

    module : :py:class:`torch.nn.Module`

    """

    # The network is instantiated before the normalizer, as it always was
    network = PASA()
    normalizer = TorchVisionNormalizer(nb_channels=1)

    stages = OrderedDict()
    stages["normalizer"] = normalizer
    stages["model"] = network

    pipeline = nn.Sequential(stages)
    pipeline.name = "pasa"
    return pipeline