Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4import torchvision.models.vgg
class VGG4Segmentation(torchvision.models.vgg.VGG):
    """Adaptation of base VGG functionality to U-Net style segmentation

    This version of VGG is slightly modified so it can be used through
    torchvision's API.  It outputs intermediate features which are normally not
    output by the base VGG implementation, but are required for segmentation
    operations.


    Parameters
    ==========

    return_features : :py:class:`list`, Optional
        A list of integers indicating the feature layers to be returned from
        the original module.  The integers are matched against the positions
        of the modules in ``self.features``.  When omitted (or ``None``), no
        intermediate feature is selected and :py:meth:`forward` returns only
        the input size entry.

    """

    def __init__(self, *args, **kwargs):
        # ``return_features`` is consumed here and never forwarded to the base
        # class constructor.  It is documented as optional, so a missing (or
        # ``None``) value selects no layer instead of raising ``KeyError``.
        selected = kwargs.pop("return_features", None)
        self._return_features = [] if selected is None else selected
        super(VGG4Segmentation, self).__init__(*args, **kwargs)

    def forward(self, x):
        """Run ``x`` through the feature layers, collecting selected outputs

        Parameters
        ==========

        x
            Input passed through each module of ``self.features`` in order.
            NOTE(review): presumably a 4D tensor laid out as (batch, channels,
            height, width) -- confirm against callers.


        Returns
        =======

        outputs : list
            First entry is ``x.shape[2:4]`` of the *input*; the following
            entries are the outputs of every feature module whose index is in
            ``return_features``, in the order the modules are applied.

        """
        # hardwiring of input: its size (dimensions 2 and 3) always comes first
        outputs = [x.shape[2:4]]
        for index, m in enumerate(self.features):
            x = m(x)
            # extract layers
            if index in self._return_features:
                outputs.append(x)
        return outputs
def _vgg_for_segmentation(
    arch, cfg, batch_norm, pretrained, progress, **kwargs
):
    """Build a :py:class:`VGG4Segmentation` model, optionally pre-trained

    Parameters
    ==========

    arch
        Key into ``torchvision.models.vgg.model_urls`` selecting the weights
        to download when ``pretrained`` is set.

    cfg
        Key into ``torchvision.models.vgg.cfgs`` selecting the layer
        configuration handed to ``make_layers``.

    batch_norm
        Forwarded to ``make_layers`` as its ``batch_norm`` argument.

    pretrained
        If true, ``init_weights=False`` is passed to the model constructor
        and a downloaded state dictionary is loaded into the model.

    progress
        Forwarded to ``load_state_dict_from_url`` as its ``progress``
        argument.

    **kwargs
        Extra keyword arguments for the :py:class:`VGG4Segmentation`
        constructor.


    Returns
    =======

    model : VGG4Segmentation
        The model, with its ``classifier`` and ``avgpool`` attributes removed.

    """
    vgg = torchvision.models.vgg

    # presumably skips an initialisation that the loaded weights would
    # overwrite anyway -- TODO confirm against the base class
    if pretrained:
        kwargs["init_weights"] = False

    feature_layers = vgg.make_layers(vgg.cfgs[cfg], batch_norm=batch_norm)
    net = VGG4Segmentation(feature_layers, **kwargs)

    if pretrained:
        weights = vgg.load_state_dict_from_url(
            vgg.model_urls[arch], progress=progress
        )
        # loaded while the classification head is still attached
        net.load_state_dict(weights)

    # erase VGG head (for classification), not used for segmentation
    for head in ("classifier", "avgpool"):
        delattr(net, head)

    return net
def vgg16_for_segmentation(pretrained=False, progress=True, **kwargs):
    # Architecture "vgg16", layer configuration "D", no batch normalisation.
    # No docstring here: it is replaced by torchvision's own right below.
    return _vgg_for_segmentation(
        arch="vgg16",
        cfg="D",
        batch_norm=False,
        pretrained=pretrained,
        progress=progress,
        **kwargs
    )


# reuse the documentation of the torchvision factory this one mirrors
vgg16_for_segmentation.__doc__ = torchvision.models.vgg16.__doc__
def vgg16_bn_for_segmentation(pretrained=False, progress=True, **kwargs):
    # Architecture "vgg16_bn", layer configuration "D", with batch
    # normalisation.  No docstring here: it is replaced by torchvision's own
    # right below.
    return _vgg_for_segmentation(
        arch="vgg16_bn",
        cfg="D",
        batch_norm=True,
        pretrained=pretrained,
        progress=progress,
        **kwargs
    )


# reuse the documentation of the torchvision factory this one mirrors
vgg16_bn_for_segmentation.__doc__ = torchvision.models.vgg16_bn.__doc__