#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import torch.nn

from torch.nn import Conv2d, ConvTranspose2d

def conv_with_kaiming_uniform(
    in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
):
    conv = Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    # Caffe2 implementation uses XavierFill, which in fact
    # corresponds to kaiming_uniform_ in PyTorch
    torch.nn.init.kaiming_uniform_(conv.weight, a=1)
    torch.nn.init.constant_(conv.bias, 0)
    return conv


def convtrans_with_kaiming_uniform(
    in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
):
    conv = ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    # Caffe2 implementation uses XavierFill, which in fact
    # corresponds to kaiming_uniform_ in PyTorch
    torch.nn.init.kaiming_uniform_(conv.weight, a=1)
    torch.nn.init.constant_(conv.bias, 0)
    return conv
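

# Minimal usage sketch (hypothetical helper, illustrative channel counts and
# sizes): both factories return ready-initialized PyTorch layers that can be
# chained like any other module.
def _example_factories():
    conv = conv_with_kaiming_uniform(64, 32, kernel_size=3, padding=1)
    up = convtrans_with_kaiming_uniform(32, 32, kernel_size=4, stride=2, padding=1)
    # the 3x3/pad-1 convolution keeps 16x16; the transposed convolution doubles it
    return up(conv(torch.rand(1, 64, 16, 16)))  # shape: (1, 32, 32, 32)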

class UpsampleCropBlock(torch.nn.Module):
    """
    Combines Conv2d, ConvTranspose2d and cropping. Simulates the Caffe2 crop
    layer in the forward function.

    Used for DRIU and HED.

    Parameters
    ----------

    in_channels : int
        number of channels of the intermediate layer
    out_channels : int
        number of output channels
    up_kernel_size : int
        kernel size for the transposed convolution
    up_stride : int
        stride for the transposed convolution
    up_padding : int
        padding for the transposed convolution
    pixelshuffle : bool
        if ``True``, upsample with PixelShuffle_ICNR instead of a transposed
        convolution (default: ``False``)

    """

    def __init__(
        self,
        in_channels,
        out_channels,
        up_kernel_size,
        up_stride,
        up_padding,
        pixelshuffle=False,
    ):
        super().__init__()
        # NOTE: Kaiming init; replace with torch.nn.Conv2d and
        # torch.nn.ConvTranspose2d to get the original DRIU implementation.
        self.conv = conv_with_kaiming_uniform(
            in_channels, out_channels, 3, 1, 1
        )
        if pixelshuffle:
            self.upconv = PixelShuffle_ICNR(
                out_channels, out_channels, scale=up_stride
            )
        else:
            self.upconv = convtrans_with_kaiming_uniform(
                out_channels,
                out_channels,
                up_kernel_size,
                up_stride,
                up_padding,
            )

    def forward(self, x, input_res):
        """Forward pass of UpsampleCropBlock.

        Upsampled feature maps are cropped to the resolution of the input
        image.

        Parameters
        ----------

        x : torch.Tensor
            input feature map

        input_res : tuple
            resolution of the input image, as ``(height, width)``

        """

        img_h = input_res[0]
        img_w = input_res[1]
        x = self.conv(x)
        x = self.upconv(x)
        # determine center crop
        # height
        up_h = x.shape[2]
        h_crop = up_h - img_h
        h_s = h_crop // 2
        h_e = up_h - (h_crop - h_s)
        # width
        up_w = x.shape[3]
        w_crop = up_w - img_w
        w_s = w_crop // 2
        w_e = up_w - (w_crop - w_s)
        # perform center crop
        # (needs explicit ranges for ONNX export)
        x = x[:, :, h_s:h_e, w_s:w_e]  # crop to input size

        return x
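

# Minimal usage sketch (hypothetical helper, illustrative sizes): the block
# upsamples by ``up_stride`` and center-crops the result back to the
# resolution passed in ``input_res``, as done for DRIU/HED side outputs.
def _example_upsample_crop():
    block = UpsampleCropBlock(64, 1, up_kernel_size=4, up_stride=2, up_padding=0)
    x = torch.rand(1, 64, 50, 50)  # intermediate feature map
    return block(x, (100, 100))  # shape: (1, 1, 100, 100)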

def ifnone(a, b):
    """``a`` if ``a`` is not None, otherwise ``b``."""
    return b if a is None else a

def icnr(x, scale=2, init=torch.nn.init.kaiming_normal_):
    """https://docs.fast.ai/layers.html#PixelShuffle_ICNR

    ICNR init of ``x``, with ``scale`` and ``init`` function. The kernel is
    initialized so that a following PixelShuffle starts out as a
    nearest-neighbour upsample, avoiding checkerboard artifacts.
    """

    ni, nf, h, w = x.shape
    ni2 = int(ni / (scale ** 2))
    # initialize a reduced kernel and replicate it over the scale**2
    # sub-pixel positions
    k = init(torch.zeros([ni2, nf, h, w])).transpose(0, 1)
    k = k.contiguous().view(ni2, nf, -1)
    k = k.repeat(1, 1, scale ** 2)
    k = k.contiguous().view([nf, ni, h, w]).transpose(0, 1)
    x.data.copy_(k)
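

# Property sketch (assumed 1x1-conv weight shape): after ``icnr``, every group
# of ``scale**2`` consecutive output channels holds identical weights, so a
# following torch.nn.PixelShuffle initially acts as nearest-neighbour upsampling.
def _example_icnr():
    w = torch.zeros(16 * 4, 8, 1, 1)  # (nf * scale**2, ni, kh, kw)
    icnr(w, scale=2)
    assert torch.equal(w[0], w[1]) and torch.equal(w[0], w[3])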

class PixelShuffle_ICNR(torch.nn.Module):
    """https://docs.fast.ai/layers.html#PixelShuffle_ICNR

    Upsample by ``scale`` from ``ni`` filters to ``nf`` (default ``ni``), using
    ``torch.nn.PixelShuffle`` and ``icnr`` init.
    """

    def __init__(self, ni: int, nf: int = None, scale: int = 2):
        super().__init__()
        nf = ifnone(nf, ni)
        self.conv = conv_with_kaiming_uniform(ni, nf * (scale ** 2), 1)
        icnr(self.conv.weight)
        self.shuf = torch.nn.PixelShuffle(scale)
        # blur with a 2x2 average pool (after replication padding) to suppress
        # remaining checkerboard artifacts, see "Super-Resolution using
        # Convolutional Neural Networks without Any Checkerboard Artifacts"
        # - https://arxiv.org/abs/1806.02658
        self.pad = torch.nn.ReplicationPad2d((1, 0, 1, 0))
        self.blur = torch.nn.AvgPool2d(2, stride=1)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.shuf(self.relu(self.conv(x)))
        x = self.blur(self.pad(x))
        return x
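

# Minimal usage sketch (hypothetical helper, illustrative sizes): doubles the
# spatial resolution while keeping the channel count; the replication pad plus
# 2x2 average pool blurs without changing the upsampled size.
def _example_pixelshuffle():
    up = PixelShuffle_ICNR(64, scale=2)
    return up(torch.rand(1, 64, 16, 16))  # shape: (1, 64, 32, 32)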

class UnetBlock(torch.nn.Module):
    """Decoder block fusing an upsampled feature map with a skip connection.

    Parameters
    ----------

    up_in_c : int
        number of channels of the upsampled input
    x_in_c : int
        number of channels of the encoder skip connection
    pixel_shuffle : bool
        if ``True``, upsample with PixelShuffle_ICNR instead of a transposed
        convolution (default: ``False``)
    middle_block : bool
        if ``True``, keep the channel count while upsampling (used for the
        middle block of the VGG-based U-Net); otherwise halve it
        (default: ``False``)
    """

    def __init__(
        self, up_in_c, x_in_c, pixel_shuffle=False, middle_block=False
    ):
        super().__init__()

        # middle block for VGG based U-Net
        if middle_block:
            up_out_c = up_in_c
        else:
            up_out_c = up_in_c // 2
        cat_channels = x_in_c + up_out_c
        inner_channels = cat_channels // 2

        if pixel_shuffle:
            self.upsample = PixelShuffle_ICNR(up_in_c, up_out_c)
        else:
            self.upsample = convtrans_with_kaiming_uniform(
                up_in_c, up_out_c, 2, 2
            )
        self.convtrans1 = convtrans_with_kaiming_uniform(
            cat_channels, inner_channels, 3, 1, 1
        )
        self.convtrans2 = convtrans_with_kaiming_uniform(
            inner_channels, inner_channels, 3, 1, 1
        )
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, up_in, x_in):
        up_out = self.upsample(up_in)
        cat_x = torch.cat([up_out, x_in], dim=1)
        x = self.relu(self.convtrans1(cat_x))
        x = self.relu(self.convtrans2(x))
        return x
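

# Minimal usage sketch (hypothetical helper, illustrative sizes): one decoder
# step that doubles the resolution of ``up_in`` and fuses it with the encoder
# skip connection ``x_in``.
def _example_unet_block():
    block = UnetBlock(up_in_c=128, x_in_c=64)
    up_in = torch.rand(1, 128, 16, 16)  # deeper decoder feature map
    x_in = torch.rand(1, 64, 32, 32)  # encoder skip connection
    return block(up_in, x_in)  # shape: (1, 64, 32, 32)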