Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1637571489937/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.8/site-packages/bob/med/tb/models/densenet.py: 94%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

18 statements  

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4import torch 

5import torch.nn as nn 

6import torchvision.models as models 

7from collections import OrderedDict 

8from .normalizer import TorchVisionNormalizer 

9 

class Densenet(nn.Module):
    """
    DenseNet-121 backbone with a small single-output head.

    The torchvision ``densenet121`` network is instantiated (optionally with
    pretrained weights) and its ``classifier`` attribute is replaced by two
    stacked linear layers mapping 1024 features to 256 and then to 1.

    Note: only usable with a normalized dataset

    Parameters
    ----------

    pretrained : bool
        Forwarded to ``torchvision.models.densenet121``; when ``True`` the
        backbone starts from torchvision's pretrained weights, otherwise it
        is randomly initialised.

    """

    def __init__(self, pretrained=False):
        super().__init__()

        # Backbone: only pretrained when ``pretrained`` is True.
        backbone = models.densenet121(pretrained=pretrained)

        # Replace the stock head with a 1024 -> 256 -> 1 projection.
        # NOTE(review): there is no non-linearity between the two Linear
        # layers, so together they reduce to a single affine map. Kept as-is
        # so checkpoints keyed on ``classifier.0`` / ``classifier.1`` load.
        hidden_width = 256
        backbone.classifier = nn.Sequential(
            nn.Linear(1024, hidden_width),
            nn.Linear(hidden_width, 1),
        )

        self.model_ft = backbone

    def forward(self, x):
        """
        Run the wrapped DenseNet on an input batch.

        Parameters
        ----------

        x : :py:class:`torch.Tensor`
            Input batch, passed unchanged to the wrapped torchvision model.
            Presumably ``(N, C, H, W)`` images -- TODO confirm against the
            data pipeline.

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            Output of the replaced classifier head (last dimension of size 1).

        """
        output = self.model_ft(x)
        return output

45 

46 

def build_densenet(pretrained=False, nb_channels=3):
    """
    Build the Densenet CNN preceded by an input normalizer.

    Parameters
    ----------

    pretrained : bool
        Forwarded to :py:class:`Densenet`; when ``True`` the DenseNet-121
        backbone is created with torchvision's pretrained weights.

    nb_channels : int
        Forwarded to ``TorchVisionNormalizer`` as its ``nb_channels``
        argument.

    Returns
    -------

    module : :py:class:`torch.nn.Module`
        A :py:class:`torch.nn.Sequential` whose submodules are, in order,
        ``normalizer`` and ``model``, with its ``name`` attribute set to
        ``"Densenet"``.

    """
    # The network is created before the normalizer, matching the original
    # construction order (and hence any random-initialisation order).
    network = Densenet(pretrained=pretrained)
    normalizer = TorchVisionNormalizer(nb_channels=nb_channels)

    # The keys become the submodule names, and therefore the state-dict
    # prefixes, so they must stay exactly "normalizer" and "model".
    stages = OrderedDict([
        ("normalizer", normalizer),
        ("model", network),
    ])
    pipeline = nn.Sequential(stages)

    pipeline.name = "Densenet"
    return pipeline