Coverage for src/deepdraw/models/backbones/vgg.py: 74%

34 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> 

2# 

3# SPDX-License-Identifier: GPL-3.0-or-later 

4 

5import torchvision.models 

6 

7try: 

8 # pytorch >= 1.12 

9 from torch.hub import load_state_dict_from_url 

10except ImportError: 

11 # pytorch < 1.12 

12 from torchvision.models.utils import load_state_dict_from_url 

13 

14 

class VGG4Segmentation(torchvision.models.vgg.VGG):
    """Adaptation of base VGG functionality to U-Net style segmentation.

    This version of VGG is slightly modified so it can be used through
    torchvision's API.  It outputs intermediate features which are normally
    not output by the base VGG implementation, but are required for
    segmentation operations.


    Parameters
    ==========

    return_features : :py:class:`list`, Optional
        A list of integers indicating the feature layers to be returned from
        the original module.  When omitted (or ``None``), no intermediate
        feature is collected and ``forward()`` only returns the input size.
    """

    def __init__(self, *args, **kwargs):
        # ``return_features`` is consumed here, so it is removed from the
        # keyword arguments before they are forwarded to the base-class
        # constructor.  A default is supplied so that leaving it out (it is
        # documented as optional) does not raise a bare ``KeyError``.
        requested = kwargs.pop("return_features", None)
        self._return_features = () if requested is None else requested
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """Run ``x`` through ``self.features``, collecting selected outputs.

        Parameters
        ==========

        x
            Input to the first feature layer.  Its ``shape[2:4]`` is recorded
            as the input size (presumably height and width of a
            batch-first, channel-second tensor -- confirm against callers).


        Returns
        =======

        outputs : :py:class:`list`
            ``x.shape[2:4]`` as first element, followed by the output of
            every layer of ``self.features`` whose index is listed in
            ``return_features``, in layer order.
        """
        # hardwiring of input: the first entry is always the input size
        outputs = [x.shape[2:4]]
        for index, m in enumerate(self.features):
            x = m(x)
            # extract layers
            if index in self._return_features:
                outputs.append(x)
        return outputs

46 

47 

def _make_vgg16_typeD_for_segmentation(
    pretrained, batch_norm, progress, **kwargs
):
    """Build a VGG-16 (configuration "D") backbone for segmentation.

    Parameters
    ==========

    pretrained
        If true, the torchvision default VGG-16 weights are downloaded and
        loaded into the model.

    batch_norm
        Forwarded to ``torchvision.models.vgg.make_layers``; also selects
        between the ``VGG16_BN_Weights`` and ``VGG16_Weights`` checkpoints.

    progress
        Forwarded to ``load_state_dict_from_url``.

    **kwargs
        Forwarded to :py:class:`VGG4Segmentation`.


    Returns
    =======

    model : VGG4Segmentation
        The model, with its ``classifier`` and ``avgpool`` attributes removed.
    """
    if pretrained:
        # weights are loaded from a checkpoint further down, so the model is
        # asked not to initialise them itself
        kwargs["init_weights"] = False

    vgg = torchvision.models.vgg
    feature_layers = vgg.make_layers(vgg.cfgs["D"], batch_norm=batch_norm)
    model = VGG4Segmentation(feature_layers, **kwargs)

    if pretrained:
        if batch_norm:
            url = vgg.VGG16_BN_Weights.DEFAULT.url
        else:
            url = vgg.VGG16_Weights.DEFAULT.url
        model.load_state_dict(load_state_dict_from_url(url, progress=progress))

    # erase VGG head (for classification), not used for segmentation; this
    # happens after loading so the checkpoint still matches the full model
    for head in ("classifier", "avgpool"):
        delattr(model, head)

    return model

77 

78 

def vgg16_for_segmentation(pretrained=False, progress=True, **kwargs):
    # VGG-16 without batch normalisation (second positional argument is
    # ``batch_norm``); everything else is passed through unchanged.
    return _make_vgg16_typeD_for_segmentation(
        pretrained, False, progress, **kwargs
    )


# The user-facing documentation is borrowed from torchvision's own factory.
vgg16_for_segmentation.__doc__ = torchvision.models.vgg16.__doc__

86 

87 

def vgg16_bn_for_segmentation(pretrained=False, progress=True, **kwargs):
    # VGG-16 with batch normalisation (second positional argument is
    # ``batch_norm``); everything else is passed through unchanged.
    return _make_vgg16_typeD_for_segmentation(
        pretrained, True, progress, **kwargs
    )


# The user-facing documentation is borrowed from torchvision's own factory.
vgg16_bn_for_segmentation.__doc__ = torchvision.models.vgg16_bn.__doc__