Coverage for src/deepdraw/models/backbones/mobilenetv2.py: 67%
30 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5import torchvision.models
6import torchvision.models.mobilenetv2
8try:
9 # pytorch >= 1.12
10 from torch.hub import load_state_dict_from_url
11except ImportError:
12 # pytorch < 1.12
13 from torchvision.models.utils import load_state_dict_from_url
class MobileNetV24Segmentation(torchvision.models.mobilenetv2.MobileNetV2):
    """Adaptation of base MobileNetV2 functionality to U-Net style
    segmentation.

    This version of MobileNetV2 is slightly modified so it can be used through
    torchvision's API. It outputs intermediate features which are normally not
    output by the base MobileNetV2 implementation, but are required for
    segmentation operations.


    Parameters
    ==========

    return_features : :py:class:`list`, Optional
        A list of integers indicating the feature layers to be returned from
        the original module.  If omitted (or ``None``), the output of every
        layer in ``self.features`` is returned.

    All other positional and keyword arguments are forwarded unchanged to
    :py:class:`torchvision.models.mobilenetv2.MobileNetV2`.
    """

    def __init__(self, *args, **kwargs):
        # Remove our own option before delegating: the torchvision base class
        # does not accept ``return_features``.  The ``None`` default makes the
        # argument genuinely optional, as documented above (previously an
        # omitted value raised ``KeyError``).
        return_features = kwargs.pop("return_features", None)
        super().__init__(*args, **kwargs)
        self._return_features = return_features

    def forward(self, x):
        """Run the backbone and collect intermediate feature maps.

        Parameters
        ==========

        x : :py:class:`torch.Tensor`
            Input batch; dimensions 2 and 3 are taken as its height and width
            (presumably an ``(N, C, H, W)`` layout -- TODO confirm).


        Returns
        =======

        outputs : :py:class:`list`
            ``[x.shape[2:4], x, f_i, ...]``: the input's spatial size, the
            input itself, then -- in layer order -- the output of each layer
            of ``self.features`` whose index is listed in ``return_features``
            (every layer when ``return_features`` is ``None``).
        """
        # hw of input, needed for DRIU and HED; followed by the raw input
        outputs = [x.shape[2:4], x]
        wanted = self._return_features
        for index, module in enumerate(self.features):
            x = module(x)
            # extract layers (``None`` means no restriction)
            if wanted is None or index in wanted:
                outputs.append(x)
        return outputs
def _mobilenet_v2_weights_url():
    """Return the URL of torchvision's ImageNet-pretrained MobileNetV2 weights.

    ``MobileNet_V2_Weights`` only exists in newer torchvision releases; older
    ones (targeted by the ``torchvision.models.utils`` import fallback of this
    module) expose a ``model_urls`` mapping instead.
    """
    module = torchvision.models.mobilenetv2
    weights = getattr(module, "MobileNet_V2_Weights", None)
    if weights is not None:
        return weights.DEFAULT.url
    return module.model_urls["mobilenet_v2"]


def mobilenet_v2_for_segmentation(pretrained=False, progress=True, **kwargs):
    """Build a :py:class:`MobileNetV24Segmentation` backbone for segmentation.


    Parameters
    ==========

    pretrained : :py:class:`bool`, Optional
        If ``True``, load torchvision's ImageNet-pretrained MobileNetV2
        weights into the feature extractor.

    progress : :py:class:`bool`, Optional
        Whether to display a progress bar while downloading the weights.

    **kwargs
        Forwarded to :py:class:`MobileNetV24Segmentation` (including
        ``return_features``).


    Returns
    =======

    model : MobileNetV24Segmentation
        Backbone without its classification head.  When ``return_features``
        is a non-empty list, ``model.features`` is truncated right after the
        deepest requested layer, since later layers are never returned.
    """
    model = MobileNetV24Segmentation(**kwargs)

    # erase MobileNetV2 head (for classification), not used for segmentation
    delattr(model, "classifier")

    if pretrained:
        state_dict = load_state_dict_from_url(
            _mobilenet_v2_weights_url(),
            progress=progress,
        )
        # The head was removed above, so drop its pretrained parameters too.
        # This also allows loading pretrained features when a non-default
        # ``num_classes`` changed the head's shape.  Keys are deleted in place
        # so the ``_metadata`` attribute read by ``load_state_dict`` survives.
        for key in [k for k in state_dict if k.startswith("classifier.")]:
            del state_dict[key]
        model.load_state_dict(state_dict)

    return_features = kwargs.get("return_features")
    # ``max`` of an empty list would raise; there is nothing to truncate then
    if return_features:
        model.features = model.features[: (max(return_features) + 1)]

    return model
# Fall back to torchvision's module documentation only when the factory has no
# docstring of its own, so a local docstring is never clobbered (the upstream
# module docstring may well be empty or ``None``).
if not mobilenet_v2_for_segmentation.__doc__:
    mobilenet_v2_for_segmentation.__doc__ = torchvision.models.mobilenetv2.__doc__