Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1637571489937/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.8/site-packages/bob/med/tb/models/alexnet.py: 95%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

19 statements  

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4import torch 

5import torch.nn as nn 

6import torchvision.models as models 

7from collections import OrderedDict 

8from .normalizer import TorchVisionNormalizer 

9 

class Alexnet(nn.Module):
    """
    Alexnet module

    Wraps :py:func:`torchvision.models.alexnet` and replaces classifier
    layers 4 and 6 so the head ends in ``4096 -> 512 -> 1`` features.

    Note: only usable with a normalized dataset

    Parameters
    ----------

    pretrained : bool
        Forwarded to :py:func:`torchvision.models.alexnet`. The two replaced
        classifier layers are created after loading, so they are always
        freshly initialized regardless of this flag.
    """

    def __init__(self, pretrained=False):
        super().__init__()

        # Load the (optionally pretrained) torchvision AlexNet.
        # NOTE(review): newer torchvision releases deprecate ``pretrained`` in
        # favour of ``weights=``; kept for compatibility with the torchvision
        # version this package targets -- revisit when upgrading.
        self.model_ft = models.alexnet(pretrained=pretrained)

        # Adapt output features: swap classifier layers 4 and 6 for new
        # Linear layers so the head narrows 4096 -> 512 -> 1.
        self.model_ft.classifier[4] = nn.Linear(4096, 512)
        self.model_ft.classifier[6] = nn.Linear(512, 1)

    def forward(self, x):
        """
        Runs the adapted AlexNet on a batch of inputs.

        Parameters
        ----------

        x : :py:class:`torch.Tensor`
            Input batch, passed unchanged to the wrapped AlexNet.
            Presumably shaped ``(N, 3, H, W)`` -- TODO confirm with callers.

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            Output of the wrapped network. With ``classifier[6]`` replaced by
            ``nn.Linear(512, 1)``, presumably shaped ``(N, 1)``.
        """
        return self.model_ft(x)

44 

45 

def build_alexnet(pretrained=False):
    """
    Build Alexnet CNN

    Creates an :py:class:`Alexnet` and places a
    :py:class:`TorchVisionNormalizer` in front of it inside a
    :py:class:`torch.nn.Sequential`, so inputs pass through the normalizer
    before reaching the network.

    Parameters
    ----------

    pretrained : bool
        Forwarded to :py:class:`Alexnet`.

    Returns
    -------

    module : :py:class:`torch.nn.Module`
        Sequential container with submodules ``normalizer`` then ``model``,
        carrying a ``name`` attribute set to ``"AlexNet"``.
    """
    # Construct the network before the normalizer so any RNG-driven weight
    # initialization happens in the same order as before.
    network = Alexnet(pretrained=pretrained)

    stages = OrderedDict()
    stages["normalizer"] = TorchVisionNormalizer()
    stages["model"] = network

    pipeline = nn.Sequential(stages)
    pipeline.name = "AlexNet"
    return pipeline