Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4import torchvision.models.vgg 

5 

6 

class VGG4Segmentation(torchvision.models.vgg.VGG):
    """Adaptation of base VGG functionality to U-Net style segmentation

    This version of VGG is slightly modified so it can be used through
    torchvision's API.  It outputs intermediate features which are normally
    not output by the base VGG implementation, but are required for
    segmentation operations.


    Parameters
    ==========

    return_features : :py:class:`list`, Optional
        A list of integers indicating the indexes of the layers in
        ``self.features`` whose outputs are returned by :py:meth:`forward`.
        If omitted (or ``None``), no intermediate feature is returned and
        :py:meth:`forward` only yields the spatial size of its input.

    All other positional and keyword arguments are forwarded unchanged to
    :py:class:`torchvision.models.vgg.VGG`.

    """

    def __init__(self, *args, **kwargs):
        # ``return_features`` is not a torchvision ``VGG`` argument, so it must
        # be removed from ``kwargs`` before delegating.  It is documented as
        # optional, hence the ``None`` default instead of a ``KeyError``.
        return_features = kwargs.pop("return_features", None)
        super().__init__(*args, **kwargs)
        self._return_features = (
            [] if return_features is None else return_features
        )

    def forward(self, x):
        """Runs ``x`` through ``self.features``, collecting selected outputs

        Parameters
        ==========

        x : torch.Tensor
            Input batch.  Only ``x.shape[2:4]`` is inspected directly;
            presumably this is an ``(N, C, H, W)`` image batch -- confirm
            against callers.


        Returns
        =======

        outputs : :py:class:`list`
            The first element is ``x.shape[2:4]`` of the input.  It is
            followed by the output of every layer of ``self.features`` whose
            index is listed in ``return_features``, in layer order.

        """
        # hardwiring of input: the original spatial size is returned first
        # (presumably so that a decoder can resize back to it -- confirm)
        outputs = [x.shape[2:4]]
        for index, module in enumerate(self.features):
            x = module(x)
            # extract requested intermediate layers
            if index in self._return_features:
                outputs.append(x)
        return outputs

39 

40 

def _vgg_for_segmentation(
    arch, cfg, batch_norm, pretrained, progress, **kwargs
):
    """Builds a :py:class:`VGG4Segmentation` with its classification head removed

    Parameters
    ==========

    arch : str
        Key into ``torchvision.models.vgg.model_urls`` used to locate the
        pre-trained weights.  It is only read when ``pretrained`` is true.

    cfg : str
        Key into ``torchvision.models.vgg.cfgs`` selecting the layer layout
        passed to ``torchvision.models.vgg.make_layers``.

    batch_norm : bool
        Forwarded to ``torchvision.models.vgg.make_layers``.

    pretrained : bool
        If true, sets ``init_weights=False`` and loads the state dict fetched
        from ``model_urls[arch]``.

    progress : bool
        Forwarded to ``load_state_dict_from_url``.

    **kwargs
        Forwarded to :py:class:`VGG4Segmentation` (e.g. ``return_features``).


    Returns
    =======

    model : VGG4Segmentation
        The model, with its ``classifier`` and ``avgpool`` attributes deleted.

    """
    # NOTE(review): ``model_urls`` and ``load_state_dict_from_url`` are looked
    # up on ``torchvision.models.vgg``.  Newer torchvision releases may no
    # longer expose them there -- confirm against the pinned version.
    vgg = torchvision.models.vgg

    if pretrained:
        # weights are overwritten by the checkpoint below, so random
        # initialisation would be wasted work
        kwargs["init_weights"] = False

    model = VGG4Segmentation(
        vgg.make_layers(vgg.cfgs[cfg], batch_norm=batch_norm), **kwargs
    )

    if pretrained:
        # Load while the head still exists.  ``load_state_dict`` is strict by
        # default, and the torchvision checkpoint presumably includes the
        # head's weights.
        model.load_state_dict(
            vgg.load_state_dict_from_url(vgg.model_urls[arch], progress=progress)
        )

    # the classification head is not used for segmentation
    del model.classifier
    del model.avgpool

    return model

66 

67 

def vgg16_for_segmentation(pretrained=False, progress=True, **kwargs):
    # "D" layer configuration under the "vgg16" weights key, without
    # batch normalisation
    return _vgg_for_segmentation(
        arch="vgg16",
        cfg="D",
        batch_norm=False,
        pretrained=pretrained,
        progress=progress,
        **kwargs
    )


# reuse torchvision's documentation for the equivalent classification model
vgg16_for_segmentation.__doc__ = torchvision.models.vgg16.__doc__

75 

76 

def vgg16_bn_for_segmentation(pretrained=False, progress=True, **kwargs):
    # "D" layer configuration under the "vgg16_bn" weights key, with
    # batch normalisation
    return _vgg_for_segmentation(
        arch="vgg16_bn",
        cfg="D",
        batch_norm=True,
        pretrained=pretrained,
        progress=progress,
        **kwargs
    )


# reuse torchvision's documentation for the equivalent classification model
vgg16_bn_for_segmentation.__doc__ = torchvision.models.vgg16_bn.__doc__