Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1674079587905/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.10/site-packages/bob/med/tb/models/alexnet.py: 95%

19 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-18 22:14 +0000

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4import torch.nn as nn 

5import torchvision.models as models 

6from collections import OrderedDict 

7from .normalizer import TorchVisionNormalizer 

8 

class Alexnet(nn.Module):
    """
    Alexnet module

    Wraps :py:func:`torchvision.models.alexnet` and replaces two layers of
    its ``classifier`` so that the head becomes ``Linear(4096, 512)``
    followed by ``Linear(512, 1)``, i.e. a single output value per sample.

    Note: only usable with a normalized dataset

    Parameters
    ----------

    pretrained : bool
        If truthy, initialize the torchvision AlexNet with
        ``AlexNet_Weights.DEFAULT``; otherwise use no pretrained weights.
        The two replaced classifier layers are always freshly initialized,
        because they are created after the weights are loaded.

    """

    def __init__(self, pretrained=False):
        super().__init__()

        # Load pretrained model.  Test truthiness instead of ``is False`` so
        # that falsy values such as ``None`` or ``0`` do not silently
        # trigger loading (and possibly downloading) pretrained weights.
        weights = models.AlexNet_Weights.DEFAULT if pretrained else None
        self.model_ft = models.alexnet(weights=weights)

        # Adapt output features.  Per torchvision's AlexNet definition,
        # classifier[4] and classifier[6] are its last two Linear layers;
        # swapping them shrinks the head down to one output unit.
        self.model_ft.classifier[4] = nn.Linear(4096, 512)
        self.model_ft.classifier[6] = nn.Linear(512, 1)

    def forward(self, x):
        """
        Run the adapted AlexNet on an input batch

        Parameters
        ----------

        x : :py:class:`torch.Tensor`
            Input passed unchanged to the wrapped torchvision AlexNet;
            presumably a ``(N, 3, H, W)`` image batch - TODO confirm
            against the data pipeline.

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            Output of ``model_ft``.  No activation (e.g. sigmoid) is
            applied in this module.

        """
        return self.model_ft(x)

44 

45 

def build_alexnet(pretrained=False):
    """
    Build the Alexnet CNN preceded by an input normalizer

    Parameters
    ----------

    pretrained : bool
        Forwarded unchanged to :py:class:`Alexnet`.

    Returns
    -------

    module : :py:class:`torch.nn.Module`
        A :py:class:`torch.nn.Sequential` whose first stage is named
        ``normalizer`` (a :py:class:`TorchVisionNormalizer`) and whose
        second stage is named ``model`` (the :py:class:`Alexnet`), with an
        extra ``name`` attribute set to ``"AlexNet"``.

    """
    # The backbone is instantiated before the normalizer, as in the original.
    backbone = Alexnet(pretrained=pretrained)

    stages = OrderedDict()
    stages["normalizer"] = TorchVisionNormalizer()
    stages["model"] = backbone

    network = nn.Sequential(stages)
    network.name = "AlexNet"
    return network