Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4"""Tests model loading""" 

5 

6 

7from ..models.backbones.vgg import VGG4Segmentation 

8from ..models.normalizer import TorchVisionNormalizer 

9 

10 

def test_driu():
    """Check the layer layout of models built by ``driu()``.

    With ``pretrained_backbone=True`` the model must contain exactly
    ``[TorchVisionNormalizer, VGG4Segmentation, DRIU]``; with
    ``pretrained_backbone=False`` the normalizer is absent and the layout
    is ``[VGG4Segmentation, DRIU]``.
    """

    from ..models.driu import DRIU, driu

    # NOTE(review): pretrained_backbone=True presumably fetches backbone
    # weights (progress=True suggests a download bar) — may need network.
    model = driu(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` keeps the exact-type check (no subclasses) while
    # avoiding the ``==`` type comparison flagged by pycodestyle E721.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is DRIU  # head

    model = driu(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is DRIU  # head


def test_driu_bn():
    """Check the layer layout of models built by ``driu_bn()``.

    With ``pretrained_backbone=True`` the model must contain exactly
    ``[TorchVisionNormalizer, VGG4Segmentation, DRIUBN]``; with
    ``pretrained_backbone=False`` the normalizer is absent and the layout
    is ``[VGG4Segmentation, DRIUBN]``.
    """

    from ..models.driu_bn import DRIUBN, driu_bn

    # NOTE(review): pretrained_backbone=True presumably fetches backbone
    # weights (progress=True suggests a download bar) — may need network.
    model = driu_bn(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` keeps the exact-type check while avoiding E721.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is DRIUBN  # head

    model = driu_bn(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is DRIUBN  # head


def test_driu_od():
    """Check the layer layout of models built by ``driu_od()``.

    With ``pretrained_backbone=True`` the model must contain exactly
    ``[TorchVisionNormalizer, VGG4Segmentation, DRIUOD]``; with
    ``pretrained_backbone=False`` the normalizer is absent and the layout
    is ``[VGG4Segmentation, DRIUOD]``.
    """

    from ..models.driu_od import DRIUOD, driu_od

    # NOTE(review): pretrained_backbone=True presumably fetches backbone
    # weights (progress=True suggests a download bar) — may need network.
    model = driu_od(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` keeps the exact-type check while avoiding E721.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is DRIUOD  # head

    model = driu_od(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is DRIUOD  # head


def test_driu_pix():
    """Check the layer layout of models built by ``driu_pix()``.

    With ``pretrained_backbone=True`` the model must contain exactly
    ``[TorchVisionNormalizer, VGG4Segmentation, DRIUPIX]``; with
    ``pretrained_backbone=False`` the normalizer is absent and the layout
    is ``[VGG4Segmentation, DRIUPIX]``.
    """

    from ..models.driu_pix import DRIUPIX, driu_pix

    # NOTE(review): pretrained_backbone=True presumably fetches backbone
    # weights (progress=True suggests a download bar) — may need network.
    model = driu_pix(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` keeps the exact-type check while avoiding E721.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is DRIUPIX  # head

    model = driu_pix(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is DRIUPIX  # head


def test_unet():
    """Check the layer layout of models built by ``unet()``.

    With ``pretrained_backbone=True`` the model must contain exactly
    ``[TorchVisionNormalizer, VGG4Segmentation, UNet]``; with
    ``pretrained_backbone=False`` the normalizer is absent and the layout
    is ``[VGG4Segmentation, UNet]``.
    """

    from ..models.unet import UNet, unet

    # NOTE(review): pretrained_backbone=True presumably fetches backbone
    # weights (progress=True suggests a download bar) — may need network.
    model = unet(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` keeps the exact-type check while avoiding E721.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is UNet  # head

    model = unet(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is UNet  # head


def test_hed():
    """Check the layer layout of models built by ``hed()``.

    With ``pretrained_backbone=True`` the model must contain exactly
    ``[TorchVisionNormalizer, VGG4Segmentation, HED]``; with
    ``pretrained_backbone=False`` the normalizer is absent and the layout
    is ``[VGG4Segmentation, HED]``.
    """

    from ..models.hed import HED, hed

    # NOTE(review): pretrained_backbone=True presumably fetches backbone
    # weights (progress=True suggests a download bar) — may need network.
    model = hed(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` keeps the exact-type check while avoiding E721.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is HED  # head

    model = hed(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is HED  # head


def test_m2unet():
    """Check the layer layout of models built by ``m2unet()``.

    With ``pretrained_backbone=True`` the model must contain exactly
    ``[TorchVisionNormalizer, MobileNetV24Segmentation, M2UNet]``; with
    ``pretrained_backbone=False`` the normalizer is absent and the layout
    is ``[MobileNetV24Segmentation, M2UNet]``.
    """

    from ..models.backbones.mobilenetv2 import MobileNetV24Segmentation
    from ..models.m2unet import M2UNet, m2unet

    # NOTE(review): pretrained_backbone=True presumably fetches backbone
    # weights (progress=True suggests a download bar) — may need network.
    model = m2unet(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` keeps the exact-type check while avoiding E721.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is MobileNetV24Segmentation  # backbone
    assert type(model[2]) is M2UNet  # head

    model = m2unet(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is MobileNetV24Segmentation  # backbone
    assert type(model[1]) is M2UNet  # head


def test_resunet50():
    """Check the layer layout of models built by ``resunet50()``.

    With ``pretrained_backbone=True`` the model must contain exactly
    ``[TorchVisionNormalizer, ResNet4Segmentation, ResUNet]``; with
    ``pretrained_backbone=False`` the normalizer is absent and the layout
    is ``[ResNet4Segmentation, ResUNet]``.
    """

    from ..models.backbones.resnet import ResNet4Segmentation
    from ..models.resunet import ResUNet, resunet50

    # NOTE(review): pretrained_backbone=True presumably fetches backbone
    # weights (progress=True suggests a download bar) — may need network.
    model = resunet50(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` keeps the exact-type check while avoiding E721.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is ResNet4Segmentation  # backbone
    assert type(model[2]) is ResUNet  # head

    model = resunet50(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is ResNet4Segmentation  # backbone
    assert type(model[1]) is ResUNet  # head
    # The former trailing ``print(model)`` was debug output and was removed;
    # the assertions above fully define the test's outcome.