Coverage for src/deepdraw/models/make_layers.py: 71%
79 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5import torch
6import torch.nn
8from torch.nn import Conv2d, ConvTranspose2d
def conv_with_kaiming_uniform(
    in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
):
    """Build a biased ``Conv2d`` with Kaiming-uniform weights and zero bias.

    Parameters
    ----------

    in_channels : int
        number of input channels

    out_channels : int
        number of output channels

    kernel_size : int or tuple
        size of the convolution kernel

    stride : int or tuple
        stride of the convolution

    padding : int or tuple
        padding applied to the input

    dilation : int or tuple
        spacing between kernel elements


    Returns
    -------

    conv : torch.nn.Conv2d
        the freshly initialised convolution layer
    """

    layer = Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )

    initializers = torch.nn.init
    # Caffe2 implementation uses XavierFill, which in fact
    # corresponds to kaiming_uniform_ in PyTorch
    initializers.kaiming_uniform_(layer.weight, a=1)
    initializers.constant_(layer.bias, 0)
    return layer
def convtrans_with_kaiming_uniform(
    in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
):
    """Build a biased ``ConvTranspose2d`` with Kaiming-uniform weights and
    zero bias.

    Parameters
    ----------

    in_channels : int
        number of input channels

    out_channels : int
        number of output channels

    kernel_size : int or tuple
        size of the transposed-convolution kernel

    stride : int or tuple
        stride of the transposed convolution

    padding : int or tuple
        padding argument forwarded to ``ConvTranspose2d``

    dilation : int or tuple
        spacing between kernel elements


    Returns
    -------

    conv : torch.nn.ConvTranspose2d
        the freshly initialised transposed-convolution layer
    """

    layer = ConvTranspose2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )

    initializers = torch.nn.init
    # Caffe2 implementation uses XavierFill, which in fact
    # corresponds to kaiming_uniform_ in PyTorch
    initializers.kaiming_uniform_(layer.weight, a=1)
    initializers.constant_(layer.bias, 0)
    return layer
class UpsampleCropBlock(torch.nn.Module):
    """Convolve, upsample and center-crop a feature map.

    Chains a 3x3 convolution with an upsampling layer, then crops the result
    to a requested resolution, simulating the caffe2 crop layer in the
    forward function.

    Used for DRIU and HED.

    Parameters
    ----------

    in_channels : int
        number of channels of intermediate layer

    out_channels : int
        number of output channels

    up_kernel_size : int
        kernel size for transposed convolution (not used when
        ``pixelshuffle`` is set)

    up_stride : int
        stride for transposed convolution; forwarded as ``scale`` to
        ``PixelShuffle_ICNR`` when ``pixelshuffle`` is set

    up_padding : int
        padding for transposed convolution (not used when ``pixelshuffle``
        is set)

    pixelshuffle : bool
        if set, upsample with ``PixelShuffle_ICNR`` instead of a transposed
        convolution
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        up_kernel_size,
        up_stride,
        up_padding,
        pixelshuffle=False,
    ):
        super().__init__()

        # NOTE: Kaiming init, replace with torch.nn.Conv2d and
        # torch.nn.ConvTranspose2d to get original DRIU impl.
        self.conv = conv_with_kaiming_uniform(
            in_channels, out_channels, 3, 1, 1
        )

        if not pixelshuffle:
            self.upconv = convtrans_with_kaiming_uniform(
                out_channels,
                out_channels,
                up_kernel_size,
                up_stride,
                up_padding,
            )
        else:
            self.upconv = PixelShuffle_ICNR(
                out_channels, out_channels, scale=up_stride
            )

    def forward(self, x, input_res):
        """Forward pass of UpsampleBlock.

        Upsampled feature maps are center-cropped to the resolution of the
        input image.  No check is made that the upsampled map is at least
        as large as ``input_res``.

        Parameters
        ----------

        x : torch.Tensor
            input feature map; after upsampling, its dimensions 2 and 3 are
            cropped as height and width respectively

        input_res : tuple
            Resolution of the input image format ``(height, width)``


        Returns
        -------

        x : torch.Tensor
            the convolved, upsampled and cropped feature map
        """

        target_h, target_w = input_res[0], input_res[1]
        feats = self.upconv(self.conv(x))

        # determine the center crop along the height ...
        full_h = feats.shape[2]
        excess_h = full_h - target_h
        top = excess_h // 2
        bottom = full_h - (excess_h - top)

        # ... and along the width
        full_w = feats.shape[3]
        excess_w = full_w - target_w
        left = excess_w // 2
        right = full_w - (excess_w - left)

        # perform crop to input size
        # needs explicit ranges for onnx export
        return feats[:, :, top:bottom, left:right]
def ifnone(a, b):
    """Return ``a`` unless it is ``None``, in which case return ``b``."""
    if a is None:
        return b
    return a
def icnr(x, scale=2, init=torch.nn.init.kaiming_normal_):
    """https://docs.fast.ai/layers.html#PixelShuffle_ICNR.

    ICNR init of ``x``, with ``scale`` and ``init`` function.

    A kernel with ``x.shape[0] / scale**2`` entries along the first
    dimension is drawn with ``init`` and replicated so that every run of
    ``scale**2`` consecutive slices along the first dimension of ``x``
    holds identical values.  ``x`` is overwritten in place.

    Parameters
    ----------

    x : torch.Tensor
        4-D tensor to initialise (unpacked as ``ni, nf, h, w``)

    scale : int
        upscaling factor; ``scale**2`` copies of each drawn slice are made

    init : callable
        initialiser applied to the reduced kernel; must return the tensor
        it initialised
    """

    n_out, n_in, k_h, k_w = x.shape
    n_base = int(n_out / (scale**2))

    # draw the reduced kernel that will be replicated
    drawn = init(torch.zeros([n_base, n_in, k_h, k_w]))

    kernel = drawn.transpose(0, 1).contiguous().view(n_base, n_in, -1)
    # tile each flattened spatial block scale**2 times
    kernel = kernel.repeat(1, 1, scale**2).contiguous()
    kernel = kernel.view([n_in, n_out, k_h, k_w]).transpose(0, 1)

    x.data.copy_(kernel)
class PixelShuffle_ICNR(torch.nn.Module):
    """https://docs.fast.ai/layers.html#PixelShuffle_ICNR.

    Upsample by ``scale`` from ``ni`` filters to ``nf`` (default ``ni``),
    using a 1x1 convolution, ReLU, ``torch.nn.PixelShuffle`` and ``icnr``
    init, followed by a replication pad and a 2x2 average-pool blur.

    Parameters
    ----------

    ni : int
        number of input channels

    nf : int
        number of output channels; falls back to ``ni`` when ``None``

    scale : int
        upscaling factor, used for the convolution width, the ICNR
        initialisation and the pixel shuffle
    """

    def __init__(self, ni: int, nf: int = None, scale: int = 2):
        super().__init__()
        nf = ifnone(nf, ni)
        self.conv = conv_with_kaiming_uniform(ni, nf * (scale**2), 1)
        # The ICNR replication factor must match the pixel-shuffle factor:
        # the convolution emits ``nf * scale**2`` channels, so the weights
        # have to be replicated in groups of ``scale**2`` (not the default
        # of 4) for the initialisation to hold when ``scale != 2``.
        icnr(self.conv.weight, scale=scale)
        self.shuf = torch.nn.PixelShuffle(scale)
        # Blurring over (h*w) kernel
        # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
        # - https://arxiv.org/abs/1806.02658
        self.pad = torch.nn.ReplicationPad2d((1, 0, 1, 0))
        self.blur = torch.nn.AvgPool2d(2, stride=1)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        """Upsample ``x``.

        Applies, in order: convolution, ReLU, pixel shuffle, replication
        padding and the average-pool blur.

        Parameters
        ----------

        x : torch.Tensor
            input feature map


        Returns
        -------

        x : torch.Tensor
            the upsampled and blurred feature map
        """

        x = self.shuf(self.relu(self.conv(x)))
        x = self.blur(self.pad(x))
        return x
class UnetBlock(torch.nn.Module):
    """Decoder block: upsample, concatenate with a second input, refine.

    Parameters
    ----------

    up_in_c : int
        number of channels of the tensor that gets upsampled

    x_in_c : int
        number of channels of the tensor concatenated with the upsampled
        one

    pixel_shuffle : bool
        if set, upsample with ``PixelShuffle_ICNR`` instead of a transposed
        convolution with kernel size 2 and stride 2

    middle_block : bool
        if set, the upsampling keeps ``up_in_c`` channels instead of
        halving them (middle block for VGG based U-Net)
    """

    def __init__(
        self, up_in_c, x_in_c, pixel_shuffle=False, middle_block=False
    ):
        super().__init__()

        # middle block for VGG based U-Net keeps the channel count
        up_out_c = up_in_c if middle_block else up_in_c // 2
        cat_channels = x_in_c + up_out_c
        inner_channels = cat_channels // 2

        if not pixel_shuffle:
            self.upsample = convtrans_with_kaiming_uniform(
                up_in_c, up_out_c, 2, 2
            )
        else:
            self.upsample = PixelShuffle_ICNR(up_in_c, up_out_c)

        self.convtrans1 = convtrans_with_kaiming_uniform(
            cat_channels, inner_channels, 3, 1, 1
        )
        self.convtrans2 = convtrans_with_kaiming_uniform(
            inner_channels, inner_channels, 3, 1, 1
        )
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, up_in, x_in):
        """Upsample ``up_in``, concatenate it with ``x_in`` along dimension
        1 and apply the two transposed convolutions, each followed by ReLU.

        Parameters
        ----------

        up_in : torch.Tensor
            tensor to be upsampled

        x_in : torch.Tensor
            tensor concatenated after the upsampled one


        Returns
        -------

        x : torch.Tensor
            the refined feature map
        """

        merged = torch.cat([self.upsample(up_in), x_in], dim=1)
        for refine in (self.convtrans1, self.convtrans2):
            merged = self.relu(refine(merged))
        return merged