Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4from collections import OrderedDict
6import torch
7import torch.nn
9from .backbones.vgg import vgg16_for_segmentation
10from .driu import ConcatFuseBlock
11from .make_layers import UpsampleCropBlock
class DRIUPIX(torch.nn.Module):
    """
    DRIUPIX head module: DRIU with pixel-shuffle upsampling instead of
    transposed convolutions (``ConvTranspose2d``).

    Parameters
    ----------

    in_channels_list : list
        Channel counts of the four feature maps returned by the backbone, in
        order: the map fed to the 3x3 ``conv1_2_16`` convolution, then the
        maps consumed by the ``upsample2``, ``upsample4`` and ``upsample8``
        blocks.

    Raises
    ------

    TypeError
        If ``in_channels_list`` is ``None`` or not iterable.

    ValueError
        If ``in_channels_list`` does not contain exactly four entries.
    """

    def __init__(self, in_channels_list=None):
        super().__init__()
        # The ``None`` default is kept for signature compatibility only; the
        # head cannot be built without the backbone channel counts, so fail
        # with an actionable message instead of an opaque unpacking error.
        if in_channels_list is None:
            raise TypeError(
                "in_channels_list is required: pass the channel counts of the "
                "four backbone feature maps, e.g. [64, 128, 256, 512]"
            )
        channels = tuple(in_channels_list)
        if len(channels) != 4:
            raise ValueError(
                f"in_channels_list must have exactly 4 entries, got "
                f"{len(channels)}: {channels!r}"
            )
        in_conv_1_2_16, in_upsample_2, in_upsample_4, in_upsample_8 = channels

        # NOTE: attribute names below define the state_dict keys; do not
        # rename them or previously saved checkpoints will no longer load.
        self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 16, 3, 1, 1)

        # Upsample layers. The positional arguments after the channel counts
        # are presumably kernel size, stride and padding; confirm against
        # ``UpsampleCropBlock`` in ``make_layers``.
        self.upsample2 = UpsampleCropBlock(
            in_upsample_2, 16, 4, 2, 0, pixelshuffle=True
        )
        self.upsample4 = UpsampleCropBlock(
            in_upsample_4, 16, 8, 4, 0, pixelshuffle=True
        )
        self.upsample8 = UpsampleCropBlock(
            in_upsample_8, 16, 16, 8, 0, pixelshuffle=True
        )

        # Concat and Fuse
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """
        Parameters
        ----------

        x : list
            Output of the backbone network.
            First element: height and width of the input image (passed to
            the upsample blocks, presumably for cropping back to input size).
            Remaining four elements: feature maps for each feature level.

        Returns
        -------

        :py:class:`torch.Tensor`
            Output of the ``ConcatFuseBlock`` applied to the four 16-channel
            side outputs.
        """
        hw = x[0]
        conv1_2_16 = self.conv1_2_16(x[1])  # conv1_2_16
        upsample2 = self.upsample2(x[2], hw)  # side-multi2-up
        upsample4 = self.upsample4(x[3], hw)  # side-multi3-up
        upsample8 = self.upsample8(x[4], hw)  # side-multi4-up
        return self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8)
def driu_pix(pretrained_backbone=True, progress=True):
    """Assemble the DRIU network variant that upsamples with pixel shuffle.

    The returned model chains a VGG-16 segmentation backbone with a
    :py:class:`DRIUPIX` head. When a pre-trained backbone is requested, a
    ``TorchVisionNormalizer`` stage is placed in front of the backbone.

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        When ``True``, the backbone (only, not the head) is initialised from
        VGG-16 weights trained for ImageNet classification.

    progress : :py:class:`bool`, Optional
        When ``True`` and ``pretrained_backbone`` is set, show a progress bar
        while the backbone weights are downloaded (if a download is
        necessary).

    Returns
    -------

    module : :py:class:`torch.nn.Module`
        A ``torch.nn.Sequential`` model for DRIU (vessel segmentation) with
        pixel shuffle, whose ``name`` attribute is set to ``"driu-pix"``.
    """
    # Build the trainable parts first; this matches the original
    # construction order, which matters for RNG-dependent initialisation.
    feature_extractor = vgg16_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        return_features=[3, 8, 14, 22],
    )
    decoder = DRIUPIX([64, 128, 256, 512])

    stages = OrderedDict()
    if pretrained_backbone:
        # Imported lazily; only needed when ImageNet weights are in use.
        from .normalizer import TorchVisionNormalizer

        stages["normalizer"] = TorchVisionNormalizer()
    stages["backbone"] = feature_extractor
    stages["head"] = decoder

    network = torch.nn.Sequential(stages)
    network.name = "driu-pix"
    return network