Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# coding=utf-8
4import torchvision.models.resnet
class ResNet4Segmentation(torchvision.models.resnet.ResNet):
    """Adaptation of base ResNet functionality to U-Net style segmentation

    This version of ResNet is slightly modified so it can be used through
    torchvision's API.  It outputs intermediate features which are normally
    not output by the base ResNet implementation, but are required for
    segmentation operations.


    Parameters
    ==========

    return_features : :py:class:`list`, Optional
        A list of integers indicating the feature stages to be returned from
        the original module.  Stages are indexed in forward order:
        ``0`` conv1, ``1`` bn1, ``2`` relu, ``3`` maxpool, ``4`` layer1,
        ``5`` layer2, ``6`` layer3, ``7`` layer4.  If omitted, no
        intermediate feature maps are returned (only the input shape).

    All other positional and keyword arguments are forwarded unchanged to
    :py:class:`torchvision.models.resnet.ResNet`.

    """

    def __init__(self, *args, **kwargs):
        # Must be removed before calling the parent constructor, which does
        # not know about this keyword.
        return_features = kwargs.pop("return_features", None)
        super(ResNet4Segmentation, self).__init__(*args, **kwargs)
        self._return_features = (
            [] if return_features is None else list(return_features)
        )

    def forward(self, x):
        """Run the backbone, collecting the requested intermediate features.

        Parameters
        ==========

        x : torch.Tensor
            Input batch.  ``x.shape[2:4]`` is recorded as the first output
            (presumably height and width of an NCHW tensor -- confirm).

        Returns
        =======

        outputs : :py:class:`list`
            ``x.shape[2:4]`` of the input, followed by the output of every
            stage whose index is listed in ``return_features``, in forward
            order.  The classification head (avgpool/fc) is never applied.

        """
        outputs = []
        # hardwiring of input: its spatial size is reported first
        # (presumably so a decoder can restore the input resolution)
        outputs.append(x.shape[2:4])

        # ``ResNet`` has no ``features`` container (unlike e.g. VGG), so the
        # feature-extraction stages are listed explicitly, in forward order.
        # Listing them (rather than iterating children()) also guarantees
        # avgpool/fc are skipped even if the head was not deleted.
        stages = (
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
        )
        wanted = set(self._return_features)
        for index, stage in enumerate(stages):
            x = stage(x)
            # extract layers
            if index in wanted:
                outputs.append(x)
        return outputs
def _resnet_for_segmentation(
    arch, block, layers, pretrained, progress, **kwargs
):
    """Build a :py:class:`ResNet4Segmentation` without its classifier head.

    Parameters
    ==========

    arch : str
        Key into ``torchvision.models.resnet.model_urls`` locating the
        pre-trained weights.  Only consulted when ``pretrained`` is set.

    block : type
        Residual block class passed to the ResNet constructor
        (e.g. ``torchvision.models.resnet.Bottleneck``).

    layers : :py:class:`list`
        Number of blocks in each residual stage, passed to the ResNet
        constructor.

    pretrained : bool
        If set, load pre-trained weights before the head is removed.

    progress : bool
        Forwarded as ``progress`` to ``load_state_dict_from_url``.

    **kwargs
        Forwarded to :py:class:`ResNet4Segmentation`.


    Returns
    =======

    model : ResNet4Segmentation
        The network, with its ``avgpool`` and ``fc`` modules deleted.

    """
    model = ResNet4Segmentation(block, layers, **kwargs)
    if pretrained:
        # The canonical home of this helper is ``torch.hub``.  Newer
        # torchvision releases no longer re-export it from
        # ``torchvision.models.resnet``.
        from torch.hub import load_state_dict_from_url

        # NOTE(review): ``model_urls`` is deprecated/removed in newer
        # torchvision releases; consider migrating to the weights API.
        state_dict = load_state_dict_from_url(
            torchvision.models.resnet.model_urls[arch], progress=progress
        )
        # Load while the head still exists so that a full classification
        # checkpoint matches the model's keys under strict loading.
        model.load_state_dict(state_dict)

    # erase ResNet head (for classification), not used for segmentation
    delattr(model, "avgpool")
    delattr(model, "fc")

    return model
def resnet50_for_segmentation(pretrained=False, progress=True, **kwargs):
    """ResNet-50 backbone adapted for segmentation (classifier head removed).

    Builds a :py:class:`ResNet4Segmentation` using
    ``torchvision.models.resnet.Bottleneck`` blocks with ``[3, 4, 6, 3]``
    blocks per stage.  Its ``avgpool`` and ``fc`` modules are deleted.


    Parameters
    ==========

    pretrained : bool
        If set, load the pre-trained ``"resnet50"`` weights before the
        classification head is removed.

    progress : bool
        Forwarded to the weight download helper; only relevant when
        ``pretrained`` is set.

    **kwargs
        Forwarded to :py:class:`ResNet4Segmentation` (e.g.
        ``return_features``) and from there to
        :py:class:`torchvision.models.resnet.ResNet`.


    Returns
    =======

    model : ResNet4Segmentation

    """
    return _resnet_for_segmentation(
        "resnet50",
        torchvision.models.resnet.Bottleneck,
        [3, 4, 6, 3],
        pretrained,
        progress,
        **kwargs
    )