# Source code for bob.med.tb.models.densenet

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch.nn as nn
import torchvision.models as models
from collections import OrderedDict
from .normalizer import TorchVisionNormalizer

class Densenet(nn.Module):
    """
    Densenet module

    A DenseNet-121 backbone (``torchvision.models.densenet121``) whose
    classification head is replaced by two stacked linear layers producing a
    single output value per sample.

    Note: only usable with a normalized dataset

    Parameters
    ----------

    pretrained : bool
        If ``True``, load torchvision's ``DenseNet121_Weights.DEFAULT``
        weights; otherwise the backbone is randomly initialized.
    """

    def __init__(self, pretrained=False):
        super().__init__()

        # Load pretrained model
        weights = models.DenseNet121_Weights.DEFAULT if pretrained else None
        self.model_ft = models.densenet121(weights=weights)

        # Adapt output features. The head's input width is read from the stock
        # classifier rather than hard-coded (it is 1024 for DenseNet-121), so
        # the head always matches the backbone's feature size.
        num_features = self.model_ft.classifier.in_features
        # NOTE(review): there is no activation between these two Linear layers,
        # so together they compute a single affine map. Inserting one would
        # shift the Sequential indices (and thus the state_dict keys) and break
        # loading of existing checkpoints, so the structure is kept as-is.
        self.model_ft.classifier = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        """
        Parameters
        ----------

        x : :py:class:`torch.Tensor`
            Batch of input images. Presumably shaped (N, 3, H, W) and already
            normalized -- TODO confirm against the data pipeline.

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            Raw (no sigmoid/softmax applied) output of the final linear layer,
            one value per sample.
        """
        return self.model_ft(x)
def build_densenet(pretrained=False, nb_channels=3):
    """
    Build Densenet CNN

    The returned network first passes its input through a
    :py:class:`TorchVisionNormalizer` and then through :py:class:`Densenet`.

    Parameters
    ----------

    pretrained : bool
        Forwarded to :py:class:`Densenet` as ``pretrained``.

    nb_channels : int
        Forwarded to :py:class:`TorchVisionNormalizer` as ``nb_channels``.

    Returns
    -------

    module : :py:class:`torch.nn.Module`
        A :py:class:`torch.nn.Sequential` with the children ``normalizer`` and
        ``model`` (in that order), carrying ``name == "Densenet"``.
    """
    # The backbone is constructed before the normalizer, matching the
    # original construction order (and therefore RNG consumption).
    backbone = Densenet(pretrained=pretrained)

    stages = OrderedDict()
    stages["normalizer"] = TorchVisionNormalizer(nb_channels=nb_channels)
    stages["model"] = backbone

    network = nn.Sequential(stages)
    network.name = "Densenet"
    return network