# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import torch
import torch.nn

from torch.nn import Conv2d, ConvTranspose2d


def conv_with_kaiming_uniform(
    in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
):
    """Return a ``Conv2d`` with Kaiming-uniform weights and a zero bias."""
    conv = Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    # Caffe2 implementation uses XavierFill, which in fact
    # corresponds to kaiming_uniform_ in PyTorch
    torch.nn.init.kaiming_uniform_(conv.weight, a=1)
    torch.nn.init.constant_(conv.bias, 0)
    return conv
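
# Usage sketch (illustrative addition, not part of the original module): a
# layer built by ``conv_with_kaiming_uniform`` behaves exactly like a plain
# ``Conv2d``; with ``kernel_size=3`` and ``padding=1`` the spatial size is
# preserved.  ``_demo_conv_with_kaiming_uniform`` is a hypothetical helper.
def _demo_conv_with_kaiming_uniform():  # pragma: no cover
    layer = conv_with_kaiming_uniform(3, 16, kernel_size=3, padding=1)
    y = layer(torch.rand(1, 3, 64, 64))
    assert y.shape == (1, 16, 64, 64)  # channels mapped 3 -> 16, 64x64 kept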


def convtrans_with_kaiming_uniform(
    in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
):
    """Return a ``ConvTranspose2d`` with Kaiming-uniform weights and a zero
    bias."""
    conv = ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    # Caffe2 implementation uses XavierFill, which in fact
    # corresponds to kaiming_uniform_ in PyTorch
    torch.nn.init.kaiming_uniform_(conv.weight, a=1)
    torch.nn.init.constant_(conv.bias, 0)
    return conv
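
# Usage sketch (illustrative addition): a transposed convolution with
# ``kernel_size=4, stride=2, padding=1`` doubles the spatial resolution,
# following ``out = (in - 1) * stride - 2 * padding + kernel_size``.
# ``_demo_convtrans_with_kaiming_uniform`` is a hypothetical helper.
def _demo_convtrans_with_kaiming_uniform():  # pragma: no cover
    layer = convtrans_with_kaiming_uniform(16, 16, kernel_size=4, stride=2, padding=1)
    y = layer(torch.rand(1, 16, 32, 32))
    assert y.shape == (1, 16, 64, 64)  # (32 - 1) * 2 - 2 * 1 + 4 = 64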


class UpsampleCropBlock(torch.nn.Module):
    """Combines Conv2d, ConvTranspose2d and cropping. Simulates the caffe2
    crop layer in the forward function.

    Used for DRIU and HED.

    Parameters
    ----------

    in_channels : int
        number of channels of intermediate layer
    out_channels : int
        number of output channels
    up_kernel_size : int
        kernel size for transposed convolution
    up_stride : int
        stride for transposed convolution
    up_padding : int
        padding for transposed convolution
    pixelshuffle : bool
        if ``True``, upsample with a :py:class:`PixelShuffle_ICNR` block
        instead of a transposed convolution (defaults to ``False``)
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        up_kernel_size,
        up_stride,
        up_padding,
        pixelshuffle=False,
    ):
        super().__init__()
        # NOTE: Kaiming init, replace with torch.nn.Conv2d and torch.nn.ConvTranspose2d to get original DRIU impl.
        self.conv = conv_with_kaiming_uniform(
            in_channels, out_channels, 3, 1, 1
        )
        if pixelshuffle:
            self.upconv = PixelShuffle_ICNR(
                out_channels, out_channels, scale=up_stride
            )
        else:
            self.upconv = convtrans_with_kaiming_uniform(
                out_channels,
                out_channels,
                up_kernel_size,
                up_stride,
                up_padding,
            )

    def forward(self, x, input_res):
        """Forward pass of the UpsampleCropBlock.

        Upsampled feature maps are cropped to the resolution of the input
        image.

        Parameters
        ----------

        x : :py:class:`torch.Tensor`
            input feature map

        input_res : tuple
            Resolution of the input image, in the format ``(height, width)``
        """

        img_h = input_res[0]
        img_w = input_res[1]
        x = self.conv(x)
        x = self.upconv(x)
        # determine center crop
        # height
        up_h = x.shape[2]
        h_crop = up_h - img_h
        h_s = h_crop // 2
        h_e = up_h - (h_crop - h_s)
        # width
        up_w = x.shape[3]
        w_crop = up_w - img_w
        w_s = w_crop // 2
        w_e = up_w - (w_crop - w_s)
        # perform crop
        # needs explicit ranges for onnx export
        x = x[:, :, h_s:h_e, w_s:w_e]  # crop to input size

        return x
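
# Usage sketch (illustrative addition): with ``up_kernel_size=4, up_stride=2,
# up_padding=1`` the block doubles the feature-map resolution and then
# center-crops the result to the requested ``(height, width)``, so odd input
# resolutions are handled cleanly.  ``_demo_upsample_crop_block`` is a
# hypothetical helper.
def _demo_upsample_crop_block():  # pragma: no cover
    block = UpsampleCropBlock(64, 1, up_kernel_size=4, up_stride=2, up_padding=1)
    features = torch.rand(1, 64, 50, 50)  # feature map at roughly half resolution
    out = block(features, (97, 99))       # upsampled to 100x100, cropped to 97x99
    assert out.shape == (1, 1, 97, 99)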


def ifnone(a, b):
    """``a`` if ``a`` is not None, otherwise ``b``."""
    return b if a is None else a


def icnr(x, scale=2, init=torch.nn.init.kaiming_normal_):
    """https://docs.fast.ai/layers.html#PixelShuffle_ICNR.

    ICNR init of ``x``, with ``scale`` and ``init`` function.
    """

    ni, nf, h, w = x.shape
    ni2 = int(ni / (scale**2))
    k = init(torch.zeros([ni2, nf, h, w])).transpose(0, 1)
    k = k.contiguous().view(ni2, nf, -1)
    k = k.repeat(1, 1, scale**2)
    k = k.contiguous().view([nf, ni, h, w]).transpose(0, 1)
    x.data.copy_(k)
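
# Usage sketch (illustrative addition): after ``icnr`` init, every group of
# ``scale**2`` consecutive output filters holds identical kernels, which is
# what lets a following ``PixelShuffle`` start out free of checkerboard
# artifacts.  ``_demo_icnr`` is a hypothetical helper.
def _demo_icnr():  # pragma: no cover
    w = torch.zeros(16, 8, 3, 3)  # (out_channels = nf * scale**2, in_channels, h, w)
    icnr(w, scale=2)
    assert torch.equal(w[0], w[1]) and torch.equal(w[0], w[3])  # first group of 4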


class PixelShuffle_ICNR(torch.nn.Module):
    """https://docs.fast.ai/layers.html#PixelShuffle_ICNR.

    Upsample by ``scale`` from ``ni`` filters to ``nf`` (default ``ni``),
    using ``torch.nn.PixelShuffle`` and ``icnr`` init.
    """

    def __init__(self, ni: int, nf: int = None, scale: int = 2):
        super().__init__()
        nf = ifnone(nf, ni)
        self.conv = conv_with_kaiming_uniform(ni, nf * (scale**2), 1)
        # pass the actual scale so groups of scale**2 filters are
        # initialized identically (the icnr default is hardcoded to 2)
        icnr(self.conv.weight, scale=scale)
        self.shuf = torch.nn.PixelShuffle(scale)
        # Blurring over (h*w) kernel
        # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
        # - https://arxiv.org/abs/1806.02658
        self.pad = torch.nn.ReplicationPad2d((1, 0, 1, 0))
        self.blur = torch.nn.AvgPool2d(2, stride=1)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.shuf(self.relu(self.conv(x)))
        x = self.blur(self.pad(x))
        return x
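
# Usage sketch (illustrative addition): the block widens channels by
# ``scale**2`` with a 1x1 convolution, pixel-shuffles them into space, and
# applies a replication-pad + 2x2 average-pool blur that preserves the
# upsampled size.  ``_demo_pixelshuffle_icnr`` is a hypothetical helper.
def _demo_pixelshuffle_icnr():  # pragma: no cover
    up = PixelShuffle_ICNR(32, scale=2)
    y = up(torch.rand(1, 32, 16, 16))
    assert y.shape == (1, 32, 32, 32)  # spatial size doubled, channels kept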


class UnetBlock(torch.nn.Module):
    """U-Net decoder block: upsamples the previous decoder output and fuses
    it with the encoder feature map of the matching resolution."""

    def __init__(
        self, up_in_c, x_in_c, pixel_shuffle=False, middle_block=False
    ):
        super().__init__()

        # middle block for VGG based U-Net
        if middle_block:
            up_out_c = up_in_c
        else:
            up_out_c = up_in_c // 2
        cat_channels = x_in_c + up_out_c
        inner_channels = cat_channels // 2

        if pixel_shuffle:
            self.upsample = PixelShuffle_ICNR(up_in_c, up_out_c)
        else:
            self.upsample = convtrans_with_kaiming_uniform(
                up_in_c, up_out_c, 2, 2
            )
        self.convtrans1 = convtrans_with_kaiming_uniform(
            cat_channels, inner_channels, 3, 1, 1
        )
        self.convtrans2 = convtrans_with_kaiming_uniform(
            inner_channels, inner_channels, 3, 1, 1
        )
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, up_in, x_in):
        up_out = self.upsample(up_in)
        cat_x = torch.cat([up_out, x_in], dim=1)
        x = self.relu(self.convtrans1(cat_x))
        x = self.relu(self.convtrans2(x))
        return x
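
# Usage sketch (illustrative addition): with the default settings the block
# halves the decoder channels while doubling their resolution, concatenates
# the skip connection, and refines with two 3x3 stride-1 transposed
# convolutions that keep the spatial size.  ``_demo_unet_block`` is a
# hypothetical helper.
def _demo_unet_block():  # pragma: no cover
    block = UnetBlock(up_in_c=128, x_in_c=64)
    up_in = torch.rand(1, 128, 16, 16)  # previous decoder output
    x_in = torch.rand(1, 64, 32, 32)    # encoder skip connection
    y = block(up_in, x_in)
    assert y.shape == (1, 64, 32, 32)   # inner_channels = (64 + 64) // 2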