1#!/usr/bin/env python
2# coding=utf-8
3
4"""Tests model loading"""
5
6
7from ...binseg.models.backbones.vgg import VGG4Segmentation
8from ...binseg.models.normalizer import TorchVisionNormalizer
9
10
def test_driu():
    """Check the module layout produced by the ``driu`` model factory.

    With ``pretrained_backbone=True`` the model must contain three modules:
    an input ``TorchVisionNormalizer``, a ``VGG4Segmentation`` backbone and a
    ``DRIU`` head. With ``pretrained_backbone=False`` the normalizer is
    absent and only the backbone and head remain.
    """

    from ...binseg.models.driu import DRIU, driu

    # NOTE(review): pretrained_backbone=True presumably fetches backbone
    # weights (hence progress=True), so this part may need network access.
    model = driu(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` (pycodestyle E721) keeps the exact-class check:
    # subclasses are rejected, just as with the former ``==`` comparison.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is DRIU  # head

    model = driu(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is DRIU  # head
25
26
def test_driu_bn():
    """Check the module layout produced by the ``driu_bn`` model factory.

    With ``pretrained_backbone=True`` the model must contain three modules:
    an input ``TorchVisionNormalizer``, a ``VGG4Segmentation`` backbone and a
    ``DRIUBN`` head. With ``pretrained_backbone=False`` the normalizer is
    absent and only the backbone and head remain.
    """

    from ...binseg.models.driu_bn import DRIUBN, driu_bn

    # NOTE(review): pretrained_backbone=True presumably fetches backbone
    # weights (hence progress=True), so this part may need network access.
    model = driu_bn(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` (pycodestyle E721) keeps the exact-class check:
    # subclasses are rejected, just as with the former ``==`` comparison.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is DRIUBN  # head

    model = driu_bn(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is DRIUBN  # head
41
42
def test_driu_od():
    """Check the module layout produced by the ``driu_od`` model factory.

    With ``pretrained_backbone=True`` the model must contain three modules:
    an input ``TorchVisionNormalizer``, a ``VGG4Segmentation`` backbone and a
    ``DRIUOD`` head. With ``pretrained_backbone=False`` the normalizer is
    absent and only the backbone and head remain.
    """

    from ...binseg.models.driu_od import DRIUOD, driu_od

    # NOTE(review): pretrained_backbone=True presumably fetches backbone
    # weights (hence progress=True), so this part may need network access.
    model = driu_od(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` (pycodestyle E721) keeps the exact-class check:
    # subclasses are rejected, just as with the former ``==`` comparison.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is DRIUOD  # head

    model = driu_od(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is DRIUOD  # head
57
58
def test_driu_pix():
    """Check the module layout produced by the ``driu_pix`` model factory.

    With ``pretrained_backbone=True`` the model must contain three modules:
    an input ``TorchVisionNormalizer``, a ``VGG4Segmentation`` backbone and a
    ``DRIUPIX`` head. With ``pretrained_backbone=False`` the normalizer is
    absent and only the backbone and head remain.
    """

    from ...binseg.models.driu_pix import DRIUPIX, driu_pix

    # NOTE(review): pretrained_backbone=True presumably fetches backbone
    # weights (hence progress=True), so this part may need network access.
    model = driu_pix(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` (pycodestyle E721) keeps the exact-class check:
    # subclasses are rejected, just as with the former ``==`` comparison.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is DRIUPIX  # head

    model = driu_pix(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is DRIUPIX  # head
73
74
def test_unet():
    """Check the module layout produced by the ``unet`` model factory.

    With ``pretrained_backbone=True`` the model must contain three modules:
    an input ``TorchVisionNormalizer``, a ``VGG4Segmentation`` backbone and a
    ``UNet`` head. With ``pretrained_backbone=False`` the normalizer is
    absent and only the backbone and head remain.
    """

    from ...binseg.models.unet import UNet, unet

    # NOTE(review): pretrained_backbone=True presumably fetches backbone
    # weights (hence progress=True), so this part may need network access.
    model = unet(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` (pycodestyle E721) keeps the exact-class check:
    # subclasses are rejected, just as with the former ``==`` comparison.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is UNet  # head

    model = unet(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is UNet  # head
89
90
def test_hed():
    """Check the module layout produced by the ``hed`` model factory.

    With ``pretrained_backbone=True`` the model must contain three modules:
    an input ``TorchVisionNormalizer``, a ``VGG4Segmentation`` backbone and a
    ``HED`` head. With ``pretrained_backbone=False`` the normalizer is
    absent and only the backbone and head remain.
    """

    from ...binseg.models.hed import HED, hed

    # NOTE(review): pretrained_backbone=True presumably fetches backbone
    # weights (hence progress=True), so this part may need network access.
    model = hed(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` (pycodestyle E721) keeps the exact-class check:
    # subclasses are rejected, just as with the former ``==`` comparison.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is HED  # head

    model = hed(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is HED  # head
105
106
def test_m2unet():
    """Check the module layout produced by the ``m2unet`` model factory.

    With ``pretrained_backbone=True`` the model must contain three modules:
    an input ``TorchVisionNormalizer``, a ``MobileNetV24Segmentation``
    backbone and an ``M2UNet`` head. With ``pretrained_backbone=False`` the
    normalizer is absent and only the backbone and head remain.
    """

    from ...binseg.models.backbones.mobilenetv2 import MobileNetV24Segmentation
    from ...binseg.models.m2unet import M2UNet, m2unet

    # NOTE(review): pretrained_backbone=True presumably fetches backbone
    # weights (hence progress=True), so this part may need network access.
    model = m2unet(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` (pycodestyle E721) keeps the exact-class check:
    # subclasses are rejected, just as with the former ``==`` comparison.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is MobileNetV24Segmentation  # backbone
    assert type(model[2]) is M2UNet  # head

    model = m2unet(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is MobileNetV24Segmentation  # backbone
    assert type(model[1]) is M2UNet  # head
122
123
def test_resunet50():
    """Check the module layout produced by the ``resunet50`` model factory.

    With ``pretrained_backbone=True`` the model must contain three modules:
    an input ``TorchVisionNormalizer``, a ``ResNet4Segmentation`` backbone and
    a ``ResUNet`` head. With ``pretrained_backbone=False`` the normalizer is
    absent and only the backbone and head remain.
    """

    from ...binseg.models.backbones.resnet import ResNet4Segmentation
    from ...binseg.models.resunet import ResUNet, resunet50

    # NOTE(review): pretrained_backbone=True presumably fetches backbone
    # weights (hence progress=True), so this part may need network access.
    model = resunet50(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` (pycodestyle E721) keeps the exact-class check:
    # subclasses are rejected, just as with the former ``==`` comparison.
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is ResNet4Segmentation  # backbone
    assert type(model[2]) is ResUNet  # head

    model = resunet50(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is ResNet4Segmentation  # backbone
    assert type(model[1]) is ResUNet  # head
140
141
def test_fasterrcnn():
    """Check the structure of the model returned by ``faster_rcnn()``.

    The model must be exactly a torchvision ``FasterRCNN``. Its ``backbone``
    must be a ``BackboneWithFPN``, and its ``roi_heads.box_predictor`` must be
    a ``FastRCNNPredictor``.
    """
    import torchvision

    from ...detect.models.faster_rcnn import faster_rcnn

    # Local alias for the torchvision detection namespace, so each check fits
    # on one readable line.
    detection = torchvision.models.detection

    model = faster_rcnn()
    # ``type(...) is`` (pycodestyle E721) keeps the exact-class check:
    # subclasses are rejected, just as with the former ``==`` comparison.
    assert type(model) is detection.faster_rcnn.FasterRCNN
    assert type(model.backbone) is detection.backbone_utils.BackboneWithFPN
    assert (
        type(model.roi_heads.box_predictor)
        is detection.faster_rcnn.FastRCNNPredictor
    )