Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1637571489937/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.8/site-packages/bob/med/tb/models/densenet_rs.py: 95%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

19 statements  

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4import torch 

5import torch.nn as nn 

6import torchvision.models as models 

7from collections import OrderedDict 

8from .normalizer import TorchVisionNormalizer 

9 

class DensenetRS(nn.Module):
    """
    Densenet121 module for radiological extraction

    Wraps a torchvision DenseNet-121 and swaps its final classifier for a
    freshly created linear layer with ``num_outputs`` output features.


    Parameters
    ----------

    pretrained : :py:class:`bool`, Optional
        Forwarded as-is to :py:func:`torchvision.models.densenet121`.  The
        default (``True``) keeps the historical behaviour of this class.

    num_outputs : :py:class:`int`, Optional
        Number of output features of the replacement classifier.  The
        default (``14``) is the value that used to be hard-coded here.

    """

    def __init__(self, pretrained=True, num_outputs=14):
        super().__init__()

        # Load the backbone network
        self.model_ft = models.densenet121(pretrained=pretrained)

        # Adapt output features: keep the classifier input width, replace the
        # layer so that it emits ``num_outputs`` values
        num_ftrs = self.model_ft.classifier.in_features
        self.model_ft.classifier = nn.Linear(num_ftrs, num_outputs)

    def forward(self, x):
        """

        Parameters
        ----------

        x : :py:class:`torch.Tensor`
            Input handed unchanged to the wrapped network.
            NOTE(review): this used to be documented as a "list of tensors";
            the wrapped network presumably expects one batched tensor --
            confirm against the callers.

        Returns
        -------

        tensor : :py:class:`torch.Tensor`
            Whatever the wrapped network returns for ``x``.

        """

        return self.model_ft(x)

42 

43 

def build_densenetrs():
    """
    Build DensenetRS CNN

    The returned container chains two named stages: ``"normalizer"`` (a
    ``TorchVisionNormalizer``) followed by ``"model"`` (a ``DensenetRS``).
    Its ``name`` attribute is set to ``"DensenetRS"``.

    Returns
    -------

    module : :py:class:`torch.nn.Module`
        The sequential container described above.

    """

    # The network is instantiated before the normalizer
    backbone = DensenetRS()

    stages = OrderedDict()
    stages["normalizer"] = TorchVisionNormalizer()
    stages["model"] = backbone

    network = nn.Sequential(stages)
    network.name = "DensenetRS"
    return network

61 return model