1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import torch.nn as nn
5import torchvision.models as models
6from collections import OrderedDict
7from .normalizer import TorchVisionNormalizer
8
class Alexnet(nn.Module):
    """
    AlexNet wrapper producing a single output value per sample.

    The torchvision AlexNet is instantiated (optionally with its default
    pretrained weights) and two entries of its classifier are swapped out so
    that the network ends in one output feature instead of the stock head.

    Note: only usable with a normalized dataset

    Parameters
    ----------

    pretrained : bool
        When truthy, load ``models.AlexNet_Weights.DEFAULT``; otherwise the
        network is created without pretrained weights.

    """
    def __init__(self, pretrained=False):
        super().__init__()

        # Select the weights by truthiness rather than ``is False``: an
        # identity check would treat falsy non-bool values (0, None,
        # numpy.False_) as a request for pretrained weights.
        weights = models.AlexNet_Weights.DEFAULT if pretrained else None
        self.model_ft = models.alexnet(weights=weights)

        # Adapt output features: classifier[4] now maps 4096 -> 512 and
        # classifier[6] maps 512 -> 1. Both are new nn.Linear layers, so they
        # are freshly initialised even when pretrained weights were loaded.
        self.model_ft.classifier[4] = nn.Linear(4096, 512)
        self.model_ft.classifier[6] = nn.Linear(512, 1)

    def forward(self, x):
        """
        Run the wrapped AlexNet on ``x``.

        Parameters
        ----------

        x : :py:class:`torch.Tensor`
            Input batch, passed unchanged to the torchvision AlexNet.
            NOTE(review): presumably shaped (N, 3, H, W) and already
            normalized -- confirm against the caller.

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            Output of the adapted network (last layer has 1 output feature).

        """

        return self.model_ft(x)
44
45
def build_alexnet(pretrained=False):
    """
    Assemble the AlexNet pipeline: normalizer followed by the network.

    Parameters
    ----------

    pretrained : bool
        Forwarded to :py:class:`Alexnet`.

    Returns
    -------

    module : :py:class:`torch.nn.Module`
        A ``nn.Sequential`` with the stages ``"normalizer"`` and ``"model"``
        (in that order) and a ``name`` attribute set to ``"AlexNet"``.

    """

    # The network is created before the normalizer, as in the original code.
    backbone = Alexnet(pretrained=pretrained)

    stages = OrderedDict()
    stages["normalizer"] = TorchVisionNormalizer()
    stages["model"] = backbone

    network = nn.Sequential(stages)
    network.name = "AlexNet"
    return network