Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1674079587905/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.10/site-packages/bob/med/tb/models/densenet.py: 94%

18 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-18 22:14 +0000

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4import torch.nn as nn 

5import torchvision.models as models 

6from collections import OrderedDict 

7from .normalizer import TorchVisionNormalizer 

8 

class Densenet(nn.Module):
    """DenseNet-121 backbone whose classifier is replaced by a single-output head.

    Note: only usable with a normalized dataset

    Parameters
    ----------

    pretrained : bool
        when true, the backbone is initialized with torchvision's
        ``DenseNet121_Weights.DEFAULT``; otherwise no pretrained weights
        are loaded.
    """

    def __init__(self, pretrained=False):
        super().__init__()

        # Backbone: torchvision DenseNet-121, optionally with its default
        # pretrained weights.
        backbone_weights = (
            models.DenseNet121_Weights.DEFAULT if pretrained else None
        )
        self.model_ft = models.densenet121(weights=backbone_weights)

        # Replace the stock classifier with a 1024 -> 256 -> 1 head
        # (1024 is the feature width DenseNet-121 feeds its classifier).
        # NOTE(review): there is no non-linearity between the two Linear
        # layers, so this head is equivalent to a single affine map.  It is
        # kept as two layers so parameter names (classifier.0 / classifier.1)
        # stay stable for existing checkpoints.
        head = nn.Sequential(
            nn.Linear(1024, 256),
            nn.Linear(256, 1),
        )
        self.model_ft.classifier = head

    def forward(self, x):
        """Run the wrapped DenseNet-121 on an input batch.

        Parameters
        ----------

        x : :py:class:`torch.Tensor`
            input batch, passed unchanged to the torchvision model;
            presumably shaped ``(N, 3, H, W)`` and already normalized --
            TODO confirm against the data pipeline.

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            output of the replaced classifier head, whose last dimension
            has size 1.
        """
        output = self.model_ft(x)
        return output

45 

46 

def build_densenet(pretrained=False, nb_channels=3):
    """Build the Densenet CNN preceded by an input normalizer.

    Parameters
    ----------

    pretrained : bool
        forwarded to :py:class:`Densenet`; when true the backbone starts
        from torchvision's default DenseNet-121 weights.

    nb_channels : int
        forwarded to ``TorchVisionNormalizer``.

    Returns
    -------

    module : :py:class:`torch.nn.Module`
        a :py:class:`torch.nn.Sequential` with two named stages,
        ``normalizer`` followed by ``model``, whose ``name`` attribute is
        set to ``"Densenet"``.
    """
    # The backbone is constructed before the normalizer; keep this order so
    # any random initialization consumes the RNG in the same sequence.
    backbone = Densenet(pretrained=pretrained)
    normalizer = TorchVisionNormalizer(nb_channels=nb_channels)

    # Keyword order is preserved, so the normalizer runs first.
    stages = OrderedDict(normalizer=normalizer, model=backbone)
    network = nn.Sequential(stages)

    network.name = "Densenet"
    return network