1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import torch
5import torch.nn as nn
6import torchvision.models as models
7from collections import OrderedDict
8from .normalizer import TorchVisionNormalizer
9
class Densenet(nn.Module):
    """
    Densenet module

    Wraps a torchvision DenseNet-121 whose classifier head is replaced by a
    small fully-connected head producing a single output value per sample.

    Note: only usable with a normalized dataset

    Parameters
    ----------

    pretrained : bool
        Forwarded to ``torchvision.models.densenet121`` to request
        pretrained weights.

    """

    def __init__(self, pretrained=False):
        super(Densenet, self).__init__()

        # Load pretrained model
        self.model_ft = models.densenet121(pretrained=pretrained)

        # Adapt output features. The head's input width is read from the
        # backbone's original classifier instead of being hard-coded (it is
        # 1024 for densenet121), so the head always matches the feature size.
        in_features = self.model_ft.classifier.in_features

        # NOTE(review): there is no non-linearity between the two Linear
        # layers, so together they are equivalent to a single affine map.
        # Kept as-is because inserting a layer would shift the
        # ``classifier.N`` state-dict keys and break existing checkpoints.
        self.model_ft.classifier = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        """
        Run the wrapped DenseNet on a batch of inputs.

        Parameters
        ----------

        x : :py:class:`torch.Tensor`
            Input batch, passed unchanged to the torchvision model
            (presumably an image batch of shape ``(N, C, H, W)`` -- TODO
            confirm against callers).

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            Output of the replacement classifier head (one value per sample,
            since its last layer is ``nn.Linear(256, 1)``).

        """
        return self.model_ft(x)
45
46
def build_densenet(pretrained=False, nb_channels=3):
    """
    Build Densenet CNN

    Assembles a sequential pipeline made of a ``TorchVisionNormalizer`` stage
    (registered under the name ``normalizer``) followed by a
    :py:class:`Densenet` (registered under the name ``model``).

    Parameters
    ----------

    pretrained : bool
        Forwarded to :py:class:`Densenet`.

    nb_channels : int
        Forwarded to ``TorchVisionNormalizer``.

    Returns
    -------

    module : :py:class:`torch.nn.Module`
        An ``nn.Sequential`` whose ``name`` attribute is ``"Densenet"``.

    """
    # Create the backbone first, before the normalizer, to keep the original
    # construction order (matters if either constructor consumes the global
    # RNG or performs I/O such as downloading weights).
    backbone = Densenet(pretrained=pretrained)

    stages = OrderedDict()
    stages["normalizer"] = TorchVisionNormalizer(nb_channels=nb_channels)
    stages["model"] = backbone

    network = nn.Sequential(stages)
    network.name = "Densenet"
    return network