# Source code for bob.med.tb.models.densenet_rs

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch.nn as nn
import torchvision.models as models
from collections import OrderedDict
from .normalizer import TorchVisionNormalizer


class DensenetRS(nn.Module):
    """DenseNet-121 network used for radiological sign extraction.

    The torchvision DenseNet-121 backbone is kept as is, and its final
    classifier is replaced by a freshly initialised linear layer that produces
    ``num_outputs`` raw scores (14 by default).

    Parameters
    ----------

    pretrained : bool
        If ``True`` (default), initialise the backbone with torchvision's
        default pretrained weights (``DenseNet121_Weights.DEFAULT``), which
        may trigger a download on first use.  If ``False``, the backbone is
        randomly initialised.  This is useful when weights are restored from
        a checkpoint afterwards, or when no network access is available.

    num_outputs : int
        Number of output features of the replacement classifier (default: 14).
    """

    def __init__(self, *, pretrained=True, num_outputs=14):
        super().__init__()

        # Load the backbone, optionally with pretrained weights
        weights = models.DenseNet121_Weights.DEFAULT if pretrained else None
        self.model_ft = models.densenet121(weights=weights)

        # Replace the classifier so the network emits ``num_outputs`` scores;
        # the input width is taken from the original head to stay in sync
        # with the backbone's feature size.
        num_ftrs = self.model_ft.classifier.in_features
        self.model_ft.classifier = nn.Linear(num_ftrs, num_outputs)

    def forward(self, x):
        """Run the network on a batch of images.

        Parameters
        ----------

        x : :py:class:`torch.Tensor`
            Input batch, forwarded unchanged to torchvision's DenseNet-121.
            Presumably shaped ``(N, 3, H, W)``; confirm against the data
            pipeline.

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            Raw (no activation applied) scores of shape
            ``(N, num_outputs)``, as produced by the final linear layer.
        """
        return self.model_ft(x)
def build_densenetrs():
    """Assemble the DensenetRS network behind a TorchVision input normalizer.

    Returns
    -------

    module : :py:class:`torch.nn.Module`
        An :py:class:`torch.nn.Sequential` with two named stages:
        ``normalizer`` (a :py:class:`TorchVisionNormalizer`) followed by
        ``model`` (a :py:class:`DensenetRS`).  The container carries the
        attribute ``name = "DensenetRS"``.
    """
    # The backbone is instantiated first, matching the original
    # construction order.
    backbone = DensenetRS()

    stages = OrderedDict()
    stages["normalizer"] = TorchVisionNormalizer()
    stages["model"] = backbone

    network = nn.Sequential(stages)
    network.name = "DensenetRS"
    return network