1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import torch
5import torch.nn as nn
6import torchvision.models as models
7from collections import OrderedDict
8from .normalizer import TorchVisionNormalizer
9
class DensenetRS(nn.Module):
    """
    Densenet121 module for radiological extraction

    Wraps torchvision's ``densenet121`` and replaces its final classifier
    with a freshly initialised ``nn.Linear`` producing ``num_classes``
    outputs.

    Parameters
    ----------

    num_classes : int, optional
        Number of output features of the new classifier head. Defaults to
        14, the value this module always used. Presumably this is one output
        per radiological finding -- confirm against the training labels.

    pretrained : bool, optional
        If ``True`` (default), initialise the backbone with torchvision's
        ImageNet weights (``IMAGENET1K_V1``, the same weights the legacy
        ``pretrained=True`` flag selected). If ``False``, the backbone is
        randomly initialised.

    """

    def __init__(self, num_classes=14, pretrained=True):
        super(DensenetRS, self).__init__()

        # Load the backbone. torchvision >= 0.13 deprecates the
        # ``pretrained=`` keyword in favour of the ``weights=`` enum, so use
        # the enum when it exists and fall back to the legacy keyword on
        # older torchvision releases.
        weights_enum = getattr(models, "DenseNet121_Weights", None)
        if weights_enum is not None:
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            self.model_ft = models.densenet121(weights=weights)
        else:
            self.model_ft = models.densenet121(pretrained=pretrained)

        # Swap the ImageNet head for one sized for this task. Reuse the
        # backbone's feature width so the new layer plugs in directly.
        num_ftrs = self.model_ft.classifier.in_features
        self.model_ft.classifier = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        """
        Run the DenseNet backbone and its classifier head.

        Parameters
        ----------

        x : :py:class:`torch.Tensor`
            Batch of input images, passed unchanged to ``densenet121``.
            Assumed to be shaped ``(N, 3, H, W)`` as torchvision's DenseNet
            expects -- TODO confirm against the data pipeline.

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            Raw outputs of the final ``nn.Linear`` layer, one row of
            ``num_classes`` scores per input. No sigmoid or softmax is
            applied here.

        """

        return self.model_ft(x)
42
43
def build_densenetrs():
    """
    Assemble the DensenetRS network behind an input normaliser.

    The returned container first applies ``TorchVisionNormalizer`` and then
    :py:class:`DensenetRS`. Its two children are registered under the names
    ``"normalizer"`` and ``"model"``. Those names become part of the
    state-dict keys, so they must stay stable for saved checkpoints to load.

    Returns
    -------

    module : :py:class:`torch.nn.Module`
        An ``nn.Sequential`` whose ``name`` attribute is set to
        ``"DensenetRS"``.

    """

    # Build the backbone before the normaliser (same order as always).
    backbone = DensenetRS()

    stages = OrderedDict()
    stages["normalizer"] = TorchVisionNormalizer()
    stages["model"] = backbone

    network = nn.Sequential(stages)
    network.name = "DensenetRS"
    return network