Coverage for src/deepdraw/models/backbones/mobilenetv2.py: 67%

30 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> 

2# 

3# SPDX-License-Identifier: GPL-3.0-or-later 

4 

5import torchvision.models 

6import torchvision.models.mobilenetv2 

7 

8try: 

9 # pytorch >= 1.12 

10 from torch.hub import load_state_dict_from_url 

11except ImportError: 

12 # pytorch < 1.12 

13 from torchvision.models.utils import load_state_dict_from_url 

14 

15 

class MobileNetV24Segmentation(torchvision.models.mobilenetv2.MobileNetV2):
    """Adaptation of base MobileNetV2 functionality to U-Net style
    segmentation.

    This version of MobileNetV2 is slightly modified so it can be used through
    torchvision's API. It outputs intermediate features which are normally not
    output by the base MobileNetV2 implementation, but are required for
    segmentation operations.


    Parameters
    ==========

    return_features : :py:class:`list`, Optional
        A list of integers indicating the feature layers (indices into
        ``self.features``) to be returned from the original module. If
        omitted or ``None``, no intermediate feature maps are returned.

    All other positional and keyword arguments are forwarded unchanged to
    torchvision's ``MobileNetV2`` constructor.
    """

    def __init__(self, *args, return_features=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Only ever used for membership tests in ``forward``; normalizing
        # ``None`` to an empty set keeps ``index in ...`` valid.
        self._return_features = frozenset(return_features or ())

    def forward(self, x):
        """Run the backbone and collect the maps needed by segmentation heads.

        Parameters
        ==========

        x : torch.Tensor
            Input batch; dimensions 2 and 3 are taken as height and width
            (presumably ``(N, C, H, W)`` -- confirm against callers).

        Returns
        =======

        outputs : list
            ``[x.shape[2:4], x, *feature_maps]``, where ``feature_maps`` are
            the outputs of ``self.features[i]`` for each layer index ``i``
            listed in ``return_features``, in layer order.
        """
        # hw of input, needed for DRIU and HED
        outputs = [x.shape[2:4], x]
        for index, layer in enumerate(self.features):
            x = layer(x)
            # extract layers
            if index in self._return_features:
                outputs.append(x)
        return outputs

49 

50 

def mobilenet_v2_for_segmentation(pretrained=False, progress=True, **kwargs):
    """Build a :py:class:`MobileNetV24Segmentation` backbone.

    Parameters
    ==========

    pretrained : :py:class:`bool`, Optional
        If ``True``, load the weights referenced by torchvision's
        ``MobileNet_V2_Weights.DEFAULT`` before the classification head is
        removed.

    progress : :py:class:`bool`, Optional
        Forwarded to ``load_state_dict_from_url``; controls whether a
        download progress bar is shown.

    **kwargs
        Forwarded to :py:class:`MobileNetV24Segmentation` (and from there to
        torchvision's ``MobileNetV2``). If ``return_features`` is given and
        non-empty, ``model.features`` is truncated right after the deepest
        requested layer.

    Returns
    =======

    model : MobileNetV24Segmentation
        The backbone, with its ``classifier`` attribute removed.
    """
    model = MobileNetV24Segmentation(**kwargs)

    if pretrained:
        # Loaded before the classifier is deleted, so the (strict) load sees
        # the full architecture the checkpoint was saved from.
        state_dict = load_state_dict_from_url(
            torchvision.models.mobilenetv2.MobileNet_V2_Weights.DEFAULT.url,
            progress=progress,
        )
        model.load_state_dict(state_dict)

    # erase MobileNetV2 head (for classification), not used for segmentation
    delattr(model, "classifier")

    # Layers deeper than the last requested feature never reach the outputs,
    # so drop them. An empty list would make ``max()`` raise ValueError; it is
    # treated like ``None`` (no truncation).
    return_features = kwargs.get("return_features")
    if return_features:
        model.features = model.features[: max(return_features) + 1]

    return model

69 

70 

# Borrow the documentation of torchvision's ``mobilenet_v2`` builder function
# (the ``mobilenetv2`` module's own ``__doc__`` describes the module, not the
# builder), but only when our builder carries no docstring of its own.
if not mobilenet_v2_for_segmentation.__doc__:
    mobilenet_v2_for_segmentation.__doc__ = (
        torchvision.models.mobilenetv2.mobilenet_v2.__doc__
    )