Coverage for /scratch/builds/bob/bob.ip.binseg/miniconda/conda-bld/bob.ip.binseg_1673966692152/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_p/lib/python3.10/site-packages/bob/ip/common/test/test_models.py: 100%

101 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 15:03 +0000

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4"""Tests model loading""" 

5 

6 

7from ...binseg.models.backbones.vgg import VGG4Segmentation 

8from ...binseg.models.normalizer import TorchVisionNormalizer 

9 

10 

def test_driu():
    """Check the module layout returned by the ``driu`` model factory.

    With ``pretrained_backbone=True`` the factory must return three modules
    (input normalizer, VGG backbone, DRIU head). With
    ``pretrained_backbone=False`` the normalizer is omitted and only the
    backbone and head remain.
    """

    from ...binseg.models.driu import DRIU, driu

    # NOTE(review): pretrained_backbone=True presumably fetches pretrained
    # weights (progress=True suggests a download progress bar), so this path
    # may need network access or a warm weight cache -- confirm.
    model = driu(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` keeps the exact-class check (subclasses must not pass)
    # while avoiding the ``type(...) ==`` comparison flagged by E721.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is DRIU  # head

    model = driu(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is DRIU  # head

25 

26 

def test_driu_bn():
    """Check the module layout returned by the ``driu_bn`` model factory.

    With ``pretrained_backbone=True`` the factory must return three modules
    (input normalizer, VGG backbone, DRIUBN head). With
    ``pretrained_backbone=False`` only the backbone and head remain.
    """

    from ...binseg.models.driu_bn import DRIUBN, driu_bn

    # NOTE(review): the pretrained path presumably downloads backbone
    # weights and may need network access -- confirm.
    model = driu_bn(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # Exact-class checks via ``is`` (E721) -- subclasses must not pass.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is DRIUBN  # head

    model = driu_bn(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is DRIUBN  # head

41 

42 

def test_driu_od():
    """Check the module layout returned by the ``driu_od`` model factory.

    With ``pretrained_backbone=True`` the factory must return three modules
    (input normalizer, VGG backbone, DRIUOD head). With
    ``pretrained_backbone=False`` only the backbone and head remain.
    """

    from ...binseg.models.driu_od import DRIUOD, driu_od

    # NOTE(review): the pretrained path presumably downloads backbone
    # weights and may need network access -- confirm.
    model = driu_od(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # Exact-class checks via ``is`` (E721) -- subclasses must not pass.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is DRIUOD  # head

    model = driu_od(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is DRIUOD  # head

57 

58 

def test_driu_pix():
    """Check the module layout returned by the ``driu_pix`` model factory.

    With ``pretrained_backbone=True`` the factory must return three modules
    (input normalizer, VGG backbone, DRIUPIX head). With
    ``pretrained_backbone=False`` only the backbone and head remain.
    """

    from ...binseg.models.driu_pix import DRIUPIX, driu_pix

    # NOTE(review): the pretrained path presumably downloads backbone
    # weights and may need network access -- confirm.
    model = driu_pix(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # Exact-class checks via ``is`` (E721) -- subclasses must not pass.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is DRIUPIX  # head

    model = driu_pix(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is DRIUPIX  # head

73 

74 

def test_unet():
    """Check the module layout returned by the ``unet`` model factory.

    With ``pretrained_backbone=True`` the factory must return three modules
    (input normalizer, VGG backbone, UNet head). With
    ``pretrained_backbone=False`` only the backbone and head remain.
    """

    from ...binseg.models.unet import UNet, unet

    # NOTE(review): the pretrained path presumably downloads backbone
    # weights and may need network access -- confirm.
    model = unet(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # Exact-class checks via ``is`` (E721) -- subclasses must not pass.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is UNet  # head

    model = unet(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is UNet  # head

89 

90 

def test_hed():
    """Check the module layout returned by the ``hed`` model factory.

    With ``pretrained_backbone=True`` the factory must return three modules
    (input normalizer, VGG backbone, HED head). With
    ``pretrained_backbone=False`` only the backbone and head remain.
    """

    from ...binseg.models.hed import HED, hed

    # NOTE(review): the pretrained path presumably downloads backbone
    # weights and may need network access -- confirm.
    model = hed(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # Exact-class checks via ``is`` (E721) -- subclasses must not pass.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is HED  # head

    model = hed(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is HED  # head

105 

106 

def test_m2unet():
    """Check the module layout returned by the ``m2unet`` model factory.

    With ``pretrained_backbone=True`` the factory must return three modules
    (input normalizer, MobileNetV2 backbone, M2UNet head). With
    ``pretrained_backbone=False`` only the backbone and head remain.
    """

    from ...binseg.models.backbones.mobilenetv2 import MobileNetV24Segmentation
    from ...binseg.models.m2unet import M2UNet, m2unet

    # NOTE(review): the pretrained path presumably downloads backbone
    # weights and may need network access -- confirm.
    model = m2unet(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # Exact-class checks via ``is`` (E721) -- subclasses must not pass.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is MobileNetV24Segmentation  # backbone
    assert type(model[2]) is M2UNet  # head

    model = m2unet(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is MobileNetV24Segmentation  # backbone
    assert type(model[1]) is M2UNet  # head

122 

123 

def test_resunet50():
    """Check the module layout returned by the ``resunet50`` model factory.

    With ``pretrained_backbone=True`` the factory must return three modules
    (input normalizer, ResNet backbone, ResUNet head). With
    ``pretrained_backbone=False`` only the backbone and head remain.
    """

    from ...binseg.models.backbones.resnet import ResNet4Segmentation
    from ...binseg.models.resunet import ResUNet, resunet50

    # NOTE(review): the pretrained path presumably downloads backbone
    # weights and may need network access -- confirm.
    model = resunet50(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # Exact-class checks via ``is`` (E721) -- subclasses must not pass.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is ResNet4Segmentation  # backbone
    assert type(model[2]) is ResUNet  # head

    model = resunet50(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is ResNet4Segmentation  # backbone
    assert type(model[1]) is ResUNet  # head
    # (A leftover debug ``print(model)`` was removed here: it asserted
    # nothing and only dumped the whole network into the test output.)

140 

141 

def test_fasterrcnn():
    """Check the structure of the model built by ``faster_rcnn()``.

    The returned object must be exactly a torchvision ``FasterRCNN``. Its
    backbone must be a ``BackboneWithFPN``, and its ROI box predictor must be
    a ``FastRCNNPredictor``.
    """
    import torchvision

    from ...detect.models.faster_rcnn import faster_rcnn

    model = faster_rcnn()
    # Exact-class checks via ``is`` (E721) -- subclasses must not pass.
    assert type(model) is torchvision.models.detection.faster_rcnn.FasterRCNN
    assert (
        type(model.backbone)
        is torchvision.models.detection.backbone_utils.BackboneWithFPN
    )
    assert (
        type(model.roi_heads.box_predictor)
        is torchvision.models.detection.faster_rcnn.FastRCNNPredictor
    )