1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import torch
5import torch.nn as nn
6import torchvision.models as models
7from collections import OrderedDict
8from .normalizer import TorchVisionNormalizer
9
class Alexnet(nn.Module):
    """
    AlexNet backbone from torchvision with a reduced classifier head.

    Note: only usable with a normalized dataset

    Two entries of the torchvision classifier are swapped out: index 4
    becomes ``nn.Linear(4096, 512)`` and index 6 becomes
    ``nn.Linear(512, 1)``, so the final layer has a single output feature.

    Parameters
    ----------

    pretrained : bool
        Forwarded unchanged to ``torchvision.models.alexnet``.

    """

    def __init__(self, pretrained=False):
        super().__init__()

        # ``model_ft`` is kept as the attribute name: it prefixes every
        # parameter key in this module's state_dict.
        self.model_ft = models.alexnet(pretrained=pretrained)

        # Replace the head layers in place (index 4 first, then index 6).
        head = self.model_ft.classifier
        head[4] = nn.Linear(4096, 512)
        head[6] = nn.Linear(512, 1)

    def forward(self, x):
        """
        Run the adapted AlexNet on ``x``.

        Parameters
        ----------

        x : :py:class:`torch.Tensor`
            input batch passed straight to the wrapped torchvision model
            (presumably N x 3 x H x W images -- TODO confirm with callers)

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            output of the modified classifier head

        """
        output = self.model_ft(x)
        return output
44
45
def build_alexnet(pretrained=False):
    """
    Build the AlexNet CNN preceded by a TorchVision input normalizer.

    Parameters
    ----------

    pretrained : bool
        Passed through to :py:class:`Alexnet`.

    Returns
    -------

    module : :py:class:`torch.nn.Module`
        an ``nn.Sequential`` whose children are ``normalizer`` followed by
        ``model``; its ``name`` attribute is set to ``"AlexNet"``.

    """
    # The backbone is created before the normalizer, matching the original
    # construction order.
    backbone = Alexnet(pretrained=pretrained)
    network = nn.Sequential(
        OrderedDict(
            [
                ("normalizer", TorchVisionNormalizer()),
                ("model", backbone),
            ]
        )
    )

    network.name = "AlexNet"
    return network