Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4from collections import OrderedDict
6import torch
7import torch.nn
9from .backbones.vgg import vgg16_for_segmentation
10from .driu import ConcatFuseBlock
11from .make_layers import UpsampleCropBlock
class DRIUOD(torch.nn.Module):
    """
    DRIU for optic disc segmentation head module

    Upsamples four backbone feature maps to the input resolution (16 channels
    each) and merges them with a :py:class:`ConcatFuseBlock`.

    Parameters
    ----------
    in_channels_list : list, Optional
        number of channels for each of the four feature maps that are returned
        from the backbone.  If ``None`` (default), the channel counts used by
        :py:func:`driu_od` for its VGG-16 backbone, ``(128, 256, 512, 512)``,
        are used.

    Raises
    ------
    ValueError
        If ``in_channels_list`` does not have exactly four entries.
    """

    # Channel counts of the VGG-16 feature maps selected by ``driu_od``
    _DEFAULT_IN_CHANNELS = (128, 256, 512, 512)

    def __init__(self, in_channels_list=None):
        super(DRIUOD, self).__init__()

        # Previously a ``None`` default crashed at unpacking; fall back to the
        # configuration used by the ``driu_od`` factory instead.
        if in_channels_list is None:
            in_channels_list = self._DEFAULT_IN_CHANNELS

        in_channels_list = list(in_channels_list)
        if len(in_channels_list) != 4:
            raise ValueError(
                "DRIUOD expects exactly 4 input channel counts, got %d"
                % len(in_channels_list)
            )

        (
            in_upsample_2,
            in_upsample_4,
            in_upsample_8,
            in_upsample_16,
        ) = in_channels_list

        # Upsample layers: each maps one feature level to 16 channels; the
        # remaining arguments (kernel size, stride, padding) grow with the
        # upsampling factor.
        self.upsample2 = UpsampleCropBlock(in_upsample_2, 16, 4, 2, 0)
        self.upsample4 = UpsampleCropBlock(in_upsample_4, 16, 8, 4, 0)
        self.upsample8 = UpsampleCropBlock(in_upsample_8, 16, 16, 8, 0)
        self.upsample16 = UpsampleCropBlock(in_upsample_16, 16, 32, 16, 0)

        # Concat and Fuse
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """
        Parameters
        ----------
        x : list
            list of tensors as returned from the backbone network.
            First element: height and width of input image.
            Remaining elements: feature maps for each feature level.

        Returns
        -------
        :py:class:`torch.Tensor`
            output of the concat-and-fuse block applied to the four upsampled
            side outputs
        """
        hw = x[0]
        upsample2 = self.upsample2(x[1], hw)  # side-multi2-up
        upsample4 = self.upsample4(x[2], hw)  # side-multi3-up
        upsample8 = self.upsample8(x[3], hw)  # side-multi4-up
        upsample16 = self.upsample16(x[4], hw)  # side-multi5-up
        out = self.concatfuse(upsample2, upsample4, upsample8, upsample16)
        return out
def driu_od(pretrained_backbone=True, progress=True):
    """Assemble the DRIU optic disc segmentation network (backbone + head)

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        When ``True``, the VGG-16 backbone (but not the DRIU head) is
        initialized from weights pre-trained for ImageNet classification, and
        a :py:class:`TorchVisionNormalizer` is prepended to the model.

    progress : :py:class:`bool`, Optional
        When ``True`` and ``pretrained_backbone`` is set, display a progress
        bar while the backbone weights are downloaded, if a download is
        necessary.


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        :py:class:`torch.nn.Sequential` network for DRIU (optic disc
        segmentation), with its ``name`` attribute set to ``"driu-od"``

    """

    # The backbone is built first, then the head (same order as before).
    feature_extractor = vgg16_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        return_features=[8, 14, 22, 29],
    )
    segmentation_head = DRIUOD([128, 256, 512, 512])

    stages = []
    if pretrained_backbone:
        # Normalizer is only prepended together with pre-trained weights,
        # presumably to match the input statistics those weights expect.
        from .normalizer import TorchVisionNormalizer

        stages.append(("normalizer", TorchVisionNormalizer()))
    stages.extend(
        (
            ("backbone", feature_extractor),
            ("head", segmentation_head),
        )
    )

    network = torch.nn.Sequential(OrderedDict(stages))
    network.name = "driu-od"
    return network