Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4from collections import OrderedDict
6import torch
7import torch.nn
9from .backbones.vgg import vgg16_for_segmentation
10from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
class ConcatFuseBlock(torch.nn.Module):
    """
    Fuses four 16-channel feature maps into a single-channel map.

    The four inputs are concatenated along the channel dimension
    (``dim=1``) and the result is passed through a 1x1 convolution with a
    single output channel.
    """

    def __init__(self):
        super().__init__()
        fused_channels = 4 * 16  # four inputs, 16 channels each
        self.conv = conv_with_kaiming_uniform(fused_channels, 1, 1, 1, 0)

    def forward(self, x1, x2, x3, x4):
        stacked = torch.cat((x1, x2, x3, x4), dim=1)
        return self.conv(stacked)
class DRIU(torch.nn.Module):
    """
    DRIU head module

    Based on paper by [MANINIS-2016]_.

    Parameters
    ----------

    in_channels_list : list, Optional
        number of channels for each of the four feature maps returned from
        the backbone, in the order they are consumed by ``conv1_2_16``,
        ``upsample2``, ``upsample4`` and ``upsample8``.  Any iterable with
        exactly four entries is accepted.  If ``None`` (the default),
        ``(64, 128, 256, 512)`` is used -- the same values :py:func:`driu`
        passes for its VGG-16 backbone.

    Raises
    ------

    ValueError
        If ``in_channels_list`` does not contain exactly four entries.

    """

    # Channel counts used when ``in_channels_list`` is None; identical to the
    # list ``driu()`` passes to this head.
    _DEFAULT_IN_CHANNELS = (64, 128, 256, 512)

    def __init__(self, in_channels_list=None):
        super().__init__()
        if in_channels_list is None:
            in_channels_list = self._DEFAULT_IN_CHANNELS
        # Materialise once so that generators are accepted and the length can
        # be validated with a clear message before unpacking.
        channels = tuple(in_channels_list)
        if len(channels) != 4:
            raise ValueError(
                "DRIU expects 4 backbone channel counts "
                f"(got {len(channels)}: {channels!r})"
            )
        in_conv_1_2_16, in_upsample_2, in_upsample_4, in_upsample_8 = channels

        self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 16, 3, 1, 1)
        # Upsample layers; each emits 16 channels, as ConcatFuseBlock expects
        self.upsample2 = UpsampleCropBlock(in_upsample_2, 16, 4, 2, 0)
        self.upsample4 = UpsampleCropBlock(in_upsample_4, 16, 8, 4, 0)
        self.upsample8 = UpsampleCropBlock(in_upsample_8, 16, 16, 8, 0)

        # Concat and Fuse
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """
        Fuses the backbone feature maps into the DRIU output.

        Parameters
        ----------

        x : list
            list of tensors as returned from the backbone network. First
            element: height and width of input image. Remaining elements:
            feature maps for each feature level.

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            output of the concat-and-fuse block (one output channel)

        """
        hw = x[0]  # target size handed to every upsample/crop block
        conv1_2_16 = self.conv1_2_16(x[1])  # conv1_2_16
        upsample2 = self.upsample2(x[2], hw)  # side-multi2-up
        upsample4 = self.upsample4(x[3], hw)  # side-multi3-up
        upsample8 = self.upsample8(x[4], hw)  # side-multi4-up
        return self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8)
def driu(pretrained_backbone=True, progress=True):
    """Assemble the DRIU vessel-segmentation network from backbone and head.

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        When ``True``, the VGG-16 backbone (but not the DRIU head) is loaded
        with weights pre-trained for ImageNet classification, and a
        ``TorchVisionNormalizer`` stage is placed in front of it.

    progress : :py:class:`bool`, Optional
        When ``True`` and ``pretrained_backbone`` is also ``True``, a progress
        bar is shown while the backbone weights are downloaded (only if a
        download is necessary).


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for DRIU (vessel segmentation): a
        :py:class:`torch.nn.Sequential` whose stages are, in order,
        ``normalizer`` (only with a pretrained backbone), ``backbone`` and
        ``head``, with its ``name`` attribute set to ``"driu"``.

    """
    # Keyword arguments are evaluated and stored in order, so the backbone is
    # built before the head and "backbone" precedes "head" in the mapping.
    stages = OrderedDict(
        backbone=vgg16_for_segmentation(
            pretrained=pretrained_backbone,
            progress=progress,
            return_features=[3, 8, 14, 22],
        ),
        head=DRIU([64, 128, 256, 512]),
    )

    if pretrained_backbone:
        # Only imported when a pretrained backbone is requested
        from .normalizer import TorchVisionNormalizer

        stages["normalizer"] = TorchVisionNormalizer()
        # The normalizer must be the first stage, ahead of the backbone
        stages.move_to_end("normalizer", last=False)

    model = torch.nn.Sequential(stages)
    model.name = "driu"
    return model