Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1674079587905/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.10/site-packages/bob/med/tb/models/densenet_rs.py: 94%

18 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-18 22:14 +0000

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4import torch.nn as nn 

5import torchvision.models as models 

6from collections import OrderedDict 

7from .normalizer import TorchVisionNormalizer 

8 

9 

class DensenetRS(nn.Module):
    """
    Densenet121 module for radiological signs extraction

    Wraps a torchvision DenseNet-121 initialised with its default pretrained
    weights (``DenseNet121_Weights.DEFAULT``). Its classifier head is
    replaced by a fully-connected layer producing ``num_classes`` outputs.

    Parameters
    ----------

    num_classes : :py:class:`int`, optional
        Number of output features of the replacement classifier head.
        Defaults to 14, the value this module always used (presumably one
        output per radiological sign; confirm against the training data).

    """

    def __init__(self, num_classes=14):
        super().__init__()

        # Load pretrained model (torchvision may need to download the weights
        # on first use). Keep the attribute name ``model_ft``: it appears in
        # the state_dict keys, so renaming it would break saved checkpoints.
        self.model_ft = models.densenet121(
            weights=models.DenseNet121_Weights.DEFAULT
        )

        # Adapt output features: swap the pretrained classifier for a new,
        # randomly initialised linear head with ``num_classes`` outputs.
        num_ftrs = self.model_ft.classifier.in_features
        self.model_ft.classifier = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        """
        Run the adapted DenseNet-121 on a batch of inputs

        Parameters
        ----------

        x : :py:class:`torch.Tensor`
            Batch of input images, forwarded unchanged to the DenseNet-121
            backbone (presumably shaped ``(N, 3, H, W)``; confirm with the
            caller).

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            Raw outputs of the linear classifier head. No activation is
            applied here. Presumably shaped ``(N, num_classes)``.

        """

        return self.model_ft(x)

45 

46 

def build_densenetrs():
    """
    Build the DensenetRS CNN preceded by a TorchVision input normalizer

    Returns
    -------

    module : :py:class:`torch.nn.Module`
        A :py:class:`torch.nn.Sequential` with two named stages:
        ``normalizer`` (:py:class:`TorchVisionNormalizer`) runs first, then
        ``model`` (:py:class:`DensenetRS`). Its ``name`` attribute is set to
        ``"DensenetRS"``.

    """

    # The backbone is instantiated before the normalizer, as before, so any
    # random-number consumption happens in the same order.
    backbone = DensenetRS()
    preprocessing = TorchVisionNormalizer()

    # The stage names become state_dict key prefixes; keep them stable.
    stages = OrderedDict()
    stages["normalizer"] = preprocessing
    stages["model"] = backbone

    network = nn.Sequential(stages)
    network.name = "DensenetRS"
    return network