Coverage for src/bob/bio/face/pytorch/facexzoo/resnest/resnest.py: 79%

29 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-06 22:15 +0100

1""" 

2@author: Jun Wang 

3@date: 20210301 

4@contact: jun21wangustc@gmail.com 

5""" 

6 

7# based on: 

8# https://github.com/zhanghang1989/ResNeSt/blob/master/resnest/torch/resnest.py 

9 

10import torch 

11import torch.nn as nn 

12 

13from .resnet import Bottleneck, ResNet 

14 

15 

class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) one into a single axis."""

    def forward(self, input):
        """Return ``input`` reshaped to ``(batch, -1)``.

        Implemented with ``Tensor.view``, so no data is copied; ``view`` raises
        if the tensor's memory layout cannot be viewed that way.
        """
        batch_size = input.shape[0]
        return input.view(batch_size, -1)

19 

20 

def l2_norm(input, axis=1, eps=1e-12):
    """L2-normalize ``input`` along ``axis``.

    Args:
        input: tensor to normalize.
        axis: dimension along which the Euclidean (p=2) norm is taken.
        eps: lower bound applied to the norm before dividing. Without it an
            all-zero slice would produce 0 / 0 = NaN; with it such a slice
            stays zero. Slices whose norm is >= ``eps`` are unaffected.

    Returns:
        Tensor with the same shape as ``input``; each slice along ``axis``
        has unit L2 norm, except slices whose norm is below ``eps``.
    """
    # keepdim=True so the norm broadcasts back against ``input``.
    norm = torch.norm(input, 2, axis, True)
    # Clamp to avoid division by zero (same approach as F.normalize).
    return torch.div(input, norm.clamp(min=eps))

25 

26 

class ResNeSt(nn.Module):
    """ResNeSt face-embedding network returning L2-normalized features.

    Pipeline: a 3x3 conv stem (3 -> 64 channels, BatchNorm, PReLU), a ResNeSt
    ``ResNet`` body built from ``Bottleneck`` blocks, and an embedding head
    (BatchNorm2d(2048) -> Dropout -> Flatten -> Linear -> BatchNorm1d) whose
    output is passed through ``l2_norm``.

    Args:
        num_layers: network depth; one of 50, 101, 200 or 269.
        drop_ratio: dropout probability applied before the linear projection.
        feat_dim: size of the output embedding.
        out_h, out_w: spatial size of the body's 2048-channel output feature
            map; they fix the linear layer's input size (2048 * out_h * out_w).

    Raises:
        ValueError: if ``num_layers`` is not one of the supported depths.
    """

    # Supported depth -> (blocks per stage, stem width). All other ResNet
    # hyper-parameters are identical across depths (see __init__).
    _ARCH_SETTINGS = {
        50: ((3, 4, 6, 3), 32),
        101: ((3, 4, 23, 3), 64),
        200: ((3, 24, 36, 3), 64),
        269: ((3, 30, 48, 8), 64),
    }

    def __init__(self, num_layers, drop_ratio, feat_dim, out_h=7, out_w=7):
        super(ResNeSt, self).__init__()
        # Fail fast: an unsupported depth previously left ``self.body``
        # undefined and only surfaced later as an AttributeError in forward().
        try:
            layers, stem_width = self._ARCH_SETTINGS[num_layers]
        except KeyError:
            raise ValueError(
                f"Unsupported num_layers={num_layers!r}; expected one of "
                f"{sorted(self._ARCH_SETTINGS)}"
            ) from None

        # Construction order (stem, head, body) is kept fixed because each
        # module's random weight init draws from the global RNG.
        self.input_layer = nn.Sequential(
            nn.Conv2d(3, 64, (3, 3), 1, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.PReLU(64),
        )
        self.output_layer = nn.Sequential(
            nn.BatchNorm2d(2048),
            nn.Dropout(drop_ratio),
            Flatten(),
            nn.Linear(2048 * out_h * out_w, feat_dim),
            nn.BatchNorm1d(feat_dim),
        )
        self.body = ResNet(
            Bottleneck,
            list(layers),
            radix=2,
            groups=1,
            bottleneck_width=64,
            deep_stem=True,
            stem_width=stem_width,
            avg_down=True,
            avd=True,
            avd_first=False,
        )

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: image batch; the stem conv expects 3 input channels
                (presumably (N, 3, H, W) face crops -- confirm that H, W
                yield an out_h x out_w body output).

        Returns:
            (N, feat_dim) embeddings, L2-normalized along dim 1.
        """
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        return l2_norm(x)