Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1#!/usr/bin/env python 

2# vim: set fileencoding=utf-8 : 

3 

4import torchvision.models.mobilenetv2 

5 

6 

class MobileNetV24Segmentation(torchvision.models.mobilenetv2.MobileNetV2):
    """Adaptation of base MobileNetV2 functionality to U-Net style segmentation

    This version of MobileNetV2 is slightly modified so it can be used through
    torchvision's API.  It outputs intermediate features which are normally
    not output by the base MobileNetV2 implementation, but are required for
    segmentation operations.  The classification head (``self.classifier``)
    is never applied by :py:meth:`forward`.


    Parameters
    ==========

    *args, **kwargs
        Forwarded unchanged to
        :py:class:`torchvision.models.mobilenetv2.MobileNetV2`.

    return_features : :py:class:`list`, Optional
        A list of integers indicating the feature layers (indices into
        ``self.features``) to be returned from the original module.  If
        omitted or ``None``, no intermediate feature layers are returned.

    """

    def __init__(self, *args, **kwargs):
        # ``return_features`` is not a MobileNetV2 argument, so it must be
        # removed from ``kwargs`` before delegating to the base class.  It is
        # optional (as documented): a missing or ``None`` value means "return
        # no intermediate layers" instead of raising ``KeyError`` here or
        # ``TypeError`` later in ``forward``.
        return_features = kwargs.pop("return_features", None)
        super(MobileNetV24Segmentation, self).__init__(*args, **kwargs)
        self._return_features = (
            return_features if return_features is not None else []
        )

    def forward(self, x):
        """Runs ``self.features`` and collects selected intermediate outputs

        Parameters
        ==========

        x : torch.Tensor
            Input batch.  Presumably laid out as ``(N, C, H, W)`` -- TODO
            confirm against callers.


        Returns
        =======

        outputs : :py:class:`list`
            ``[x.shape[2:4], x, f_i, ...]``: dimensions 2 and 3 of the input
            shape, the input itself, then the output of every layer of
            ``self.features`` whose index is listed in ``return_features``,
            in layer order.

        """
        # hw of input, needed for DRIU and HED
        outputs = [x.shape[2:4], x]
        for index, layer in enumerate(self.features):
            x = layer(x)
            # extract layers
            if index in self._return_features:
                outputs.append(x)
        return outputs

40 

41 

def _mobilenet_v2_pretrained_state_dict(progress):
    """Fetches torchvision's ImageNet MobileNetV2 weights as a state dict

    Older torchvision releases expose ``model_urls`` and
    ``load_state_dict_from_url`` in :py:mod:`torchvision.models.mobilenetv2`.
    Newer releases replaced them with the ``MobileNet_V2_Weights``
    enumeration.  Both layouts are supported here.
    """
    module = torchvision.models.mobilenetv2

    model_urls = getattr(module, "model_urls", None)
    if model_urls is not None:
        url = model_urls["mobilenet_v2"]
    else:
        url = module.MobileNet_V2_Weights.IMAGENET1K_V1.url

    loader = getattr(module, "load_state_dict_from_url", None)
    if loader is None:
        import torch.hub

        loader = torch.hub.load_state_dict_from_url

    return loader(url, progress=progress)


def mobilenet_v2_for_segmentation(pretrained=False, progress=True, **kwargs):
    """Builds a :py:class:`MobileNetV24Segmentation` segmentation backbone

    Parameters
    ==========

    pretrained : :py:class:`bool`, Optional
        If ``True``, loads torchvision's ImageNet-pretrained MobileNetV2
        weights into the model before its classification head is removed.

    progress : :py:class:`bool`, Optional
        Forwarded to the weight downloader (progress-bar display).

    **kwargs
        Forwarded to :py:class:`MobileNetV24Segmentation` (and from there to
        torchvision's MobileNetV2), including the optional
        ``return_features`` list of feature-layer indices.


    Returns
    =======

    model : MobileNetV24Segmentation
        Model without its ``classifier`` attribute.  When ``return_features``
        is a non-empty list, ``model.features`` is truncated just after the
        deepest requested layer.

    """
    model = MobileNetV24Segmentation(**kwargs)

    if pretrained:
        # the checkpoint includes the classifier weights, so it must be
        # loaded while ``model.classifier`` still exists
        model.load_state_dict(_mobilenet_v2_pretrained_state_dict(progress))

    # erase MobileNetV2 head (for classification), not used for segmentation
    delattr(model, "classifier")

    # ``**kwargs`` gave the constructor its own dict, so its ``pop`` did not
    # remove ``return_features`` from this function's ``kwargs``
    return_features = kwargs.get("return_features")
    if return_features:  # also skips ``[]``, where ``max`` would raise
        # layers past the deepest requested one never reach the output
        model.features = model.features[: max(return_features) + 1]

    return model

60 

61 

# Fall back to torchvision's documentation for the upstream ``mobilenet_v2``
# factory only when this wrapper carries no docstring of its own.
# Unconditionally overwriting would hide the segmentation-specific behaviour
# (``return_features``, classifier removal).  Newer torchvision releases also
# document a ``weights`` argument there, which this wrapper does not accept.
if not mobilenet_v2_for_segmentation.__doc__:
    mobilenet_v2_for_segmentation.__doc__ = (
        torchvision.models.mobilenetv2.mobilenet_v2.__doc__
    )