# Hot-keys on this page
# r m x p toggle line displays
# j k next/prev highlighted chunk
# 0 (zero) top of page
# 1 (one) first highlighted chunk
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
from typing import Optional

import torch
import torch.nn
from torch.nn import Conv2d, ConvTranspose2d
def conv_with_kaiming_uniform(
    in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
) -> Conv2d:
    """Build a biased ``Conv2d`` with Kaiming-uniform weights and zero bias.

    Parameters
    ----------

    in_channels : int
        number of input channels

    out_channels : int
        number of output channels

    kernel_size : int
        kernel size of the convolution

    stride : int
        stride of the convolution

    padding : int
        padding of the convolution

    dilation : int
        dilation of the convolution

    Returns
    -------

    conv : torch.nn.Conv2d
        the freshly initialised layer

    """

    conv = Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    # Caffe2 implementation uses XavierFill, which in fact
    # corresponds to kaiming_uniform_ in PyTorch.
    # ``a=1`` makes the leaky-relu gain sqrt(2 / (1 + a**2)) equal to 1.
    torch.nn.init.kaiming_uniform_(conv.weight, a=1)
    torch.nn.init.constant_(conv.bias, 0)
    return conv
def convtrans_with_kaiming_uniform(
    in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
) -> ConvTranspose2d:
    """Build a biased ``ConvTranspose2d`` with Kaiming-uniform weights and
    zero bias.

    Parameters
    ----------

    in_channels : int
        number of input channels

    out_channels : int
        number of output channels

    kernel_size : int
        kernel size of the transposed convolution

    stride : int
        stride of the transposed convolution

    padding : int
        padding of the transposed convolution

    dilation : int
        dilation of the transposed convolution

    Returns
    -------

    conv : torch.nn.ConvTranspose2d
        the freshly initialised layer

    """

    conv = ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    # Caffe2 implementation uses XavierFill, which in fact
    # corresponds to kaiming_uniform_ in PyTorch.
    # ``a=1`` makes the leaky-relu gain sqrt(2 / (1 + a**2)) equal to 1.
    torch.nn.init.kaiming_uniform_(conv.weight, a=1)
    torch.nn.init.constant_(conv.bias, 0)
    return conv
class UpsampleCropBlock(torch.nn.Module):
    """
    Combines Conv2d, ConvTransposed2d and Cropping. Simulates the caffe2 crop
    layer in the forward function.

    Used for DRIU and HED.

    Parameters
    ----------

    in_channels : int
        number of channels of intermediate layer

    out_channels : int
        number of output channels

    up_kernel_size : int
        kernel size for transposed convolution

    up_stride : int
        stride for transposed convolution

    up_padding : int
        padding for transposed convolution

    pixelshuffle : bool
        if set, upsample with ``PixelShuffle_ICNR`` (using ``up_stride`` as
        its scale) instead of a transposed convolution; ``up_kernel_size``
        and ``up_padding`` are not used in that case

    """

    def __init__(
        self,
        in_channels,
        out_channels,
        up_kernel_size,
        up_stride,
        up_padding,
        pixelshuffle=False,
    ):
        super().__init__()
        # NOTE: Kaiming init, replace with torch.nn.Conv2d and
        # torch.nn.ConvTranspose2d to get original DRIU impl.
        self.conv = conv_with_kaiming_uniform(
            in_channels, out_channels, 3, 1, 1
        )
        if pixelshuffle:
            self.upconv = PixelShuffle_ICNR(
                out_channels, out_channels, scale=up_stride
            )
        else:
            self.upconv = convtrans_with_kaiming_uniform(
                out_channels,
                out_channels,
                up_kernel_size,
                up_stride,
                up_padding,
            )

    def forward(self, x, input_res):
        """Forward pass of UpsampleBlock.

        Upsampled feature maps are center-cropped to the resolution of the
        input image.

        Parameters
        ----------

        x : torch.Tensor
            feature maps; dimensions 2 and 3 of the upsampled result are
            treated as height and width

        input_res : tuple
            Resolution of the input image format ``(height, width)``

        Returns
        -------

        x : torch.Tensor
            upsampled feature maps cropped to ``input_res``

        """

        img_h = input_res[0]
        img_w = input_res[1]

        x = self.conv(x)
        x = self.upconv(x)

        # determine center crop; when the surplus is odd, the extra
        # row/column is removed from the bottom/right side
        # NOTE(review): if the upsampled map is smaller than ``input_res``
        # the offsets turn negative and the slice silently returns an
        # unintended region -- confirm callers always upsample enough.
        # height
        up_h = x.shape[2]
        h_crop = up_h - img_h
        h_s = h_crop // 2
        h_e = up_h - (h_crop - h_s)
        # width
        up_w = x.shape[3]
        w_crop = up_w - img_w
        w_s = w_crop // 2
        w_e = up_w - (w_crop - w_s)

        # perform crop
        # needs explicit ranges for onnx export
        x = x[:, :, h_s:h_e, w_s:w_e]  # crop to input size

        return x
def ifnone(a, b):
    """Return ``a`` if ``a`` is not None, otherwise ``b``.

    Only ``None`` triggers the fallback; other falsy values of ``a`` (``0``,
    ``""``, ``[]``...) are returned unchanged.
    """

    return b if a is None else a
def icnr(x, scale=2, init=torch.nn.init.kaiming_normal_):
    """https://docs.fast.ai/layers.html#PixelShuffle_ICNR

    ICNR init of ``x``, with ``scale`` and ``init`` function.

    A kernel with ``x.shape[0] // scale**2`` output channels is drawn with
    ``init`` and every one of its output channels is repeated ``scale**2``
    times in a row, so that consecutive groups of ``scale**2`` output
    channels of ``x`` start out identical.

    Parameters
    ----------

    x : torch.Tensor
        4-D convolution weight to overwrite in place; its first dimension
        must be divisible by ``scale**2``

    scale : int
        upscaling factor of the pixel shuffle that follows the convolution

    init : callable
        in-place initialiser applied to the reduced kernel; it must return
        the initialised tensor

    Raises
    ------

    RuntimeError
        if the first dimension of ``x`` is not divisible by ``scale**2``
        (the replicated kernel then cannot be copied into ``x``)

    """

    ni, nf, h, w = x.shape
    ni2 = ni // (scale ** 2)
    k = init(torch.zeros([ni2, nf, h, w]))
    # Repeating each output channel ``scale**2`` times back to back is what
    # the former transpose/view/repeat/view/transpose chain produced.
    k = k.repeat_interleave(scale ** 2, dim=0)
    # initialisation must not be recorded by autograd
    with torch.no_grad():
        x.copy_(k)
class PixelShuffle_ICNR(torch.nn.Module):
    """https://docs.fast.ai/layers.html#PixelShuffle_ICNR

    Upsample by ``scale`` from ``ni`` filters to ``nf`` (default ``ni``), using
    a 1x1 convolution with ``icnr`` init, ReLU, ``torch.nn.PixelShuffle`` and
    a blur (replication pad + 2x2 average pooling with stride 1).

    Parameters
    ----------

    ni : int
        number of input channels

    nf : int
        number of output channels; ``None`` means "same as ``ni``"

    scale : int
        upscaling factor

    """

    def __init__(self, ni: int, nf: Optional[int] = None, scale: int = 2):
        super().__init__()
        nf = ifnone(nf, ni)
        self.conv = conv_with_kaiming_uniform(ni, nf * (scale ** 2), 1)
        # The ICNR replication factor must match the shuffle factor; relying
        # on icnr's default of 2 is wrong for any other ``scale`` and can
        # even fail when ``nf * scale**2`` is not a multiple of 4.
        icnr(self.conv.weight, scale=scale)
        self.shuf = torch.nn.PixelShuffle(scale)
        # Blurring over (h*w) kernel
        # "Super-Resolution using Convolutional Neural Networks without Any
        # Checkerboard Artifacts" - https://arxiv.org/abs/1806.02658
        # Pad one pixel on the left and on the top so that the 2x2 pooling
        # below keeps the spatial size unchanged.
        self.pad = torch.nn.ReplicationPad2d((1, 0, 1, 0))
        self.blur = torch.nn.AvgPool2d(2, stride=1)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        """Upsample ``x`` by ``scale`` and blur the result.

        Parameters
        ----------

        x : torch.Tensor
            feature maps with ``ni`` channels

        Returns
        -------

        x : torch.Tensor
            feature maps with ``nf`` channels, height and width multiplied
            by ``scale``

        """

        x = self.shuf(self.relu(self.conv(x)))
        x = self.blur(self.pad(x))
        return x
class UnetBlock(torch.nn.Module):
    """Decoder block of a U-Net: upsample, concatenate with the skip
    connection, then apply two 3x3 transposed convolutions with ReLU.

    Parameters
    ----------

    up_in_c : int
        number of channels of the feature maps to be upsampled

    x_in_c : int
        number of channels of the skip-connection feature maps

    pixel_shuffle : bool
        if set, upsample with ``PixelShuffle_ICNR`` instead of a 2x2
        transposed convolution with stride 2

    middle_block : bool
        if set, the upsampling keeps ``up_in_c`` channels instead of halving
        them (middle block for VGG based U-Net)

    """

    def __init__(
        self, up_in_c, x_in_c, pixel_shuffle=False, middle_block=False
    ):
        super().__init__()

        # middle block for VGG based U-Net
        if middle_block:
            up_out_c = up_in_c
        else:
            up_out_c = up_in_c // 2
        cat_channels = x_in_c + up_out_c
        inner_channels = cat_channels // 2

        if pixel_shuffle:
            self.upsample = PixelShuffle_ICNR(up_in_c, up_out_c)
        else:
            self.upsample = convtrans_with_kaiming_uniform(
                up_in_c, up_out_c, 2, 2
            )
        self.convtrans1 = convtrans_with_kaiming_uniform(
            cat_channels, inner_channels, 3, 1, 1
        )
        self.convtrans2 = convtrans_with_kaiming_uniform(
            inner_channels, inner_channels, 3, 1, 1
        )
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, up_in, x_in):
        """Upsample ``up_in``, merge it with ``x_in`` and refine the result.

        Parameters
        ----------

        up_in : torch.Tensor
            feature maps coming from the deeper layer

        x_in : torch.Tensor
            skip-connection feature maps; they are concatenated after the
            upsampled maps along dimension 1, so all other dimensions must
            match those of the upsampled ``up_in``

        Returns
        -------

        x : torch.Tensor
            output of the second transposed convolution, after ReLU

        """

        up_out = self.upsample(up_in)
        cat_x = torch.cat([up_out, x_in], dim=1)
        x = self.relu(self.convtrans1(cat_x))
        x = self.relu(self.convtrans2(x))
        return x