Coverage for src/deepdraw/models/backbones/vgg.py: 74%
34 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5import torchvision.models
7try:
8 # pytorch >= 1.12
9 from torch.hub import load_state_dict_from_url
10except ImportError:
11 # pytorch < 1.12
12 from torchvision.models.utils import load_state_dict_from_url
class VGG4Segmentation(torchvision.models.vgg.VGG):
    """Adaptation of base VGG functionality to U-Net style segmentation.

    This version of VGG is slightly modified so it can be used through
    torchvision's API. It outputs intermediate features which are normally not
    output by the base VGG implementation, but are required for segmentation
    operations.


    Parameters
    ==========

    return_features : :py:class:`list`
        **Required** keyword argument. A list of integers indicating which
        layers of ``self.features`` (by their index in that sequence) have
        their outputs returned by :py:meth:`forward`. All other positional
        and keyword arguments are forwarded unchanged to
        :py:class:`torchvision.models.vgg.VGG`.


    Raises
    ======

    KeyError
        If the ``return_features`` keyword argument is not provided.
    """

    def __init__(self, *args, **kwargs):
        try:
            return_features = kwargs.pop("return_features")
        except KeyError:
            # Keep the original exception type (callers may rely on it), but
            # replace the opaque ``KeyError('return_features')`` with an
            # actionable message.
            raise KeyError(
                "VGG4Segmentation requires the 'return_features' keyword "
                "argument (a list of layer indices into self.features)"
            ) from None
        # Assigned before the base-class constructor runs, as in the original
        # implementation; a plain list is stored as a regular attribute.
        self._return_features = return_features
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """Run ``self.features`` and collect selected intermediate outputs.

        Parameters
        ==========

        x : :py:class:`torch.Tensor`
            Input batch. NOTE(review): ``x.shape[2:4]`` is recorded, which
            presumably assumes an ``(N, C, H, W)`` layout -- confirm against
            callers.


        Returns
        =======

        outputs : :py:class:`list`
            ``outputs[0]`` is ``x.shape[2:4]`` of the *input*. It is followed
            by the output of every layer in ``self.features`` whose index is
            in ``return_features``, in layer order.
        """
        # Record the input size first ("hardwiring of input"); presumably
        # consumed by the segmentation head to restore the input resolution.
        outputs = [x.shape[2:4]]
        for index, layer in enumerate(self.features):
            x = layer(x)
            if index in self._return_features:
                outputs.append(x)
        return outputs
def _make_vgg16_typeD_for_segmentation(
    pretrained, batch_norm, progress, **kwargs
):
    """Build a :py:class:`VGG4Segmentation` using VGG configuration ``"D"``.

    Parameters
    ==========

    pretrained : bool
        If set, torchvision's default VGG-16 checkpoint (the ``_BN`` variant
        when ``batch_norm`` is set) is downloaded and loaded into the model.
        ``init_weights=False`` is then passed to the base class, since the
        loaded weights replace any initialisation.

    batch_norm : bool
        Forwarded to ``torchvision.models.vgg.make_layers``. It also selects
        which checkpoint is loaded when ``pretrained`` is set.

    progress : bool
        Forwarded to ``load_state_dict_from_url`` to toggle the download
        progress bar.

    **kwargs
        Forwarded to :py:class:`VGG4Segmentation`, which requires
        ``return_features``.


    Returns
    =======

    model : VGG4Segmentation
        The model with its ``classifier`` and ``avgpool`` attributes removed.
    """
    if pretrained:
        # Random initialisation would be overwritten by the checkpoint.
        kwargs["init_weights"] = False

    vgg = torchvision.models.vgg
    layers = vgg.make_layers(vgg.cfgs["D"], batch_norm=batch_norm)
    model = VGG4Segmentation(layers, **kwargs)

    if pretrained:
        if batch_norm:
            checkpoint_url = vgg.VGG16_BN_Weights.DEFAULT.url
        else:
            checkpoint_url = vgg.VGG16_Weights.DEFAULT.url
        # Loaded while the classification head still exists, so the
        # checkpoint's classifier entries have matching keys in the model.
        model.load_state_dict(
            load_state_dict_from_url(checkpoint_url, progress=progress)
        )

    # The classification head is not used for segmentation: drop it.
    for head_name in ("classifier", "avgpool"):
        delattr(model, head_name)

    return model
def vgg16_for_segmentation(pretrained=False, progress=True, **kwargs):
    """VGG-16 (configuration ``"D"``, no batch-norm) backbone for segmentation.

    Builds a :py:class:`VGG4Segmentation` whose :py:meth:`forward` returns
    the input size followed by the selected intermediate feature maps. The
    classification head (``classifier`` and ``avgpool``) is removed.


    Parameters
    ==========

    pretrained : bool
        If ``True``, loads torchvision's default ``VGG16_Weights`` checkpoint.

    progress : bool
        If ``True``, shows a progress bar while downloading the checkpoint.

    **kwargs
        Forwarded to :py:class:`VGG4Segmentation`. Must include
        ``return_features``, the list of layer indices whose outputs are
        returned.


    Returns
    =======

    model : VGG4Segmentation
        The configured backbone.
    """
    # A dedicated docstring replaces the previously copied
    # ``torchvision.models.vgg16.__doc__``, which documents a ``weights``
    # parameter this function does not accept.
    return _make_vgg16_typeD_for_segmentation(
        pretrained=pretrained, batch_norm=False, progress=progress, **kwargs
    )
def vgg16_bn_for_segmentation(pretrained=False, progress=True, **kwargs):
    """VGG-16 (configuration ``"D"``, with batch-norm) backbone for segmentation.

    Builds a :py:class:`VGG4Segmentation` whose :py:meth:`forward` returns
    the input size followed by the selected intermediate feature maps. The
    classification head (``classifier`` and ``avgpool``) is removed.


    Parameters
    ==========

    pretrained : bool
        If ``True``, loads torchvision's default ``VGG16_BN_Weights``
        checkpoint.

    progress : bool
        If ``True``, shows a progress bar while downloading the checkpoint.

    **kwargs
        Forwarded to :py:class:`VGG4Segmentation`. Must include
        ``return_features``, the list of layer indices whose outputs are
        returned.


    Returns
    =======

    model : VGG4Segmentation
        The configured backbone.
    """
    # A dedicated docstring replaces the previously copied
    # ``torchvision.models.vgg16_bn.__doc__``, which documents a ``weights``
    # parameter this function does not accept.
    return _make_vgg16_typeD_for_segmentation(
        pretrained=pretrained, batch_norm=True, progress=progress, **kwargs
    )