Source code for bob.ip.binseg.models.driu_od

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from collections import OrderedDict

import torch
import torch.nn

from .backbones.vgg import vgg16_for_segmentation
from .driu import ConcatFuseBlock
from .make_layers import UpsampleCropBlock


class DRIUOD(torch.nn.Module):
    """
    DRIU for optic disc segmentation head module

    Parameters
    ----------
    in_channels_list : list
        number of channels for each feature map that is returned from
        backbone.  Exactly four values are required, one for each of the
        four upsampling branches (in the order x2, x4, x8, x16).

    Raises
    ------
    TypeError
        if ``in_channels_list`` is ``None`` (the default) or not iterable.
    ValueError
        if ``in_channels_list`` does not contain exactly four entries.
    """

    def __init__(self, in_channels_list=None):
        # The ``None`` default only exists for signature compatibility; fail
        # with an explicit message instead of an opaque unpacking error.
        if in_channels_list is None:
            raise TypeError(
                "in_channels_list is required: pass the 4 channel counts "
                "of the backbone feature maps (e.g. [128, 256, 512, 512])"
            )

        # Materialise once so that any iterable (including generators)
        # keeps working exactly as it did with plain tuple unpacking.
        channels = tuple(in_channels_list)
        if len(channels) != 4:
            raise ValueError(
                "in_channels_list must contain exactly 4 entries, "
                "got %d" % len(channels)
            )

        super().__init__()

        in_upsample2, in_upsample4, in_upsample8, in_upsample16 = channels

        # Upsample layers
        # NOTE(review): positional arguments presumably are (in_channels,
        # out_channels, kernel size, stride, padding) -- confirm against
        # ``make_layers.UpsampleCropBlock``.
        self.upsample2 = UpsampleCropBlock(in_upsample2, 16, 4, 2, 0)
        self.upsample4 = UpsampleCropBlock(in_upsample4, 16, 8, 4, 0)
        self.upsample8 = UpsampleCropBlock(in_upsample8, 16, 16, 8, 0)
        self.upsample16 = UpsampleCropBlock(in_upsample16, 16, 32, 16, 0)

        # Concat and Fuse
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """

        Parameters
        ----------
        x : list
            list of tensors as returned from the backbone network.
            First element: height and width of input image.
            Remaining elements: feature maps for each feature level.

        Returns
        -------
        :py:class:`torch.Tensor`
            output of the concat-and-fuse block applied to the four
            upsampled feature maps

        """
        hw = x[0]
        upsample2 = self.upsample2(x[1], hw)  # side-multi2-up
        upsample4 = self.upsample4(x[2], hw)  # side-multi3-up
        upsample8 = self.upsample8(x[3], hw)  # side-multi4-up
        upsample16 = self.upsample16(x[4], hw)  # side-multi5-up
        out = self.concatfuse(upsample2, upsample4, upsample8, upsample16)
        return out
def driu_od(pretrained_backbone=True, progress=True):
    """Builds DRIU for Optic Disc by adding backbone and head together

    Parameters
    ----------

    pretrained_backbone : :py:class:`bool`, Optional
        If set to ``True``, then loads a pre-trained version of the backbone
        (not the head) for the DRIU network using VGG-16 trained for ImageNet
        classification.  In that case a ``TorchVisionNormalizer`` is also
        placed in front of the backbone.

    progress : :py:class:`bool`, Optional
        If set to ``True``, and you decided to use a ``pretrained_backbone``,
        then, shows a progress bar of the backbone model downloading if
        download is necessary.


    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Network model for DRIU (optic disc segmentation)

    """
    stages = OrderedDict()

    stages["backbone"] = vgg16_for_segmentation(
        pretrained=pretrained_backbone,
        progress=progress,
        return_features=[8, 14, 22, 29],
    )
    stages["head"] = DRIUOD([128, 256, 512, 512])

    if pretrained_backbone:
        # Imported lazily: only needed when a pre-trained backbone is used.
        from .normalizer import TorchVisionNormalizer

        # The normalizer is created last but must run first in the pipeline.
        stages["normalizer"] = TorchVisionNormalizer()
        stages.move_to_end("normalizer", last=False)

    network = torch.nn.Sequential(stages)
    network.name = "driu-od"
    return network