Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# coding=utf-8
4"""Tests model loading"""
7from ..models.backbones.vgg import VGG4Segmentation
8from ..models.normalizer import TorchVisionNormalizer
def test_driu():
    """Check the module layout returned by the ``driu`` factory.

    The factory is called twice, with and without a pretrained backbone, and
    the length and exact stage types of the returned model are asserted.
    """

    from ..models.driu import DRIU, driu

    # Pretrained backbone: three stages are expected, the first one being the
    # input normalizer.
    model = driu(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` keeps the exact-type check (subclasses do not match)
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is DRIU  # head

    # No pretrained backbone: only backbone and head are expected.
    model = driu(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is DRIU  # head
def test_driu_bn():
    """Check the module layout returned by the ``driu_bn`` factory.

    The factory is called twice, with and without a pretrained backbone, and
    the length and exact stage types of the returned model are asserted.
    """

    from ..models.driu_bn import DRIUBN, driu_bn

    # Pretrained backbone: three stages are expected, the first one being the
    # input normalizer.
    model = driu_bn(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` keeps the exact-type check (subclasses do not match)
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is DRIUBN  # head

    # No pretrained backbone: only backbone and head are expected.
    model = driu_bn(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is DRIUBN  # head
def test_driu_od():
    """Check the module layout returned by the ``driu_od`` factory.

    The factory is called twice, with and without a pretrained backbone, and
    the length and exact stage types of the returned model are asserted.
    """

    from ..models.driu_od import DRIUOD, driu_od

    # Pretrained backbone: three stages are expected, the first one being the
    # input normalizer.
    model = driu_od(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` keeps the exact-type check (subclasses do not match)
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is DRIUOD  # head

    # No pretrained backbone: only backbone and head are expected.
    model = driu_od(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is DRIUOD  # head
def test_driu_pix():
    """Check the module layout returned by the ``driu_pix`` factory.

    The factory is called twice, with and without a pretrained backbone, and
    the length and exact stage types of the returned model are asserted.
    """

    from ..models.driu_pix import DRIUPIX, driu_pix

    # Pretrained backbone: three stages are expected, the first one being the
    # input normalizer.
    model = driu_pix(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` keeps the exact-type check (subclasses do not match)
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is DRIUPIX  # head

    # No pretrained backbone: only backbone and head are expected.
    model = driu_pix(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is DRIUPIX  # head
def test_unet():
    """Check the module layout returned by the ``unet`` factory.

    The factory is called twice, with and without a pretrained backbone, and
    the length and exact stage types of the returned model are asserted.
    """

    from ..models.unet import UNet, unet

    # Pretrained backbone: three stages are expected, the first one being the
    # input normalizer.
    model = unet(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` keeps the exact-type check (subclasses do not match)
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is UNet  # head

    # No pretrained backbone: only backbone and head are expected.
    model = unet(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is UNet  # head
def test_hed():
    """Check the module layout returned by the ``hed`` factory.

    The factory is called twice, with and without a pretrained backbone, and
    the length and exact stage types of the returned model are asserted.
    """

    from ..models.hed import HED, hed

    # Pretrained backbone: three stages are expected, the first one being the
    # input normalizer.
    model = hed(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` keeps the exact-type check (subclasses do not match)
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is VGG4Segmentation  # backbone
    assert type(model[2]) is HED  # head

    # No pretrained backbone: only backbone and head are expected.
    model = hed(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is VGG4Segmentation  # backbone
    assert type(model[1]) is HED  # head
def test_m2unet():
    """Check the module layout returned by the ``m2unet`` factory.

    The factory is called twice, with and without a pretrained backbone, and
    the length and exact stage types of the returned model are asserted.
    Unlike the VGG-based tests in this module, the backbone expected here is
    ``MobileNetV24Segmentation``.
    """

    from ..models.backbones.mobilenetv2 import MobileNetV24Segmentation
    from ..models.m2unet import M2UNet, m2unet

    # Pretrained backbone: three stages are expected, the first one being the
    # input normalizer.
    model = m2unet(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` keeps the exact-type check (subclasses do not match)
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is MobileNetV24Segmentation  # backbone
    assert type(model[2]) is M2UNet  # head

    # No pretrained backbone: only backbone and head are expected.
    model = m2unet(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is MobileNetV24Segmentation  # backbone
    assert type(model[1]) is M2UNet  # head
def test_resunet50():
    """Check the module layout returned by the ``resunet50`` factory.

    The factory is called twice, with and without a pretrained backbone, and
    the length and exact stage types of the returned model are asserted.
    Unlike the VGG-based tests in this module, the backbone expected here is
    ``ResNet4Segmentation``.
    """

    from ..models.backbones.resnet import ResNet4Segmentation
    from ..models.resunet import ResUNet, resunet50

    # Pretrained backbone: three stages are expected, the first one being the
    # input normalizer.
    model = resunet50(pretrained_backbone=True, progress=True)
    assert len(model) == 3
    # ``type(...) is`` keeps the exact-type check (subclasses do not match)
    assert type(model[0]) is TorchVisionNormalizer
    assert type(model[1]) is ResNet4Segmentation  # backbone
    assert type(model[2]) is ResUNet  # head

    # No pretrained backbone: only backbone and head are expected.
    model = resunet50(pretrained_backbone=False)
    assert len(model) == 2
    assert type(model[0]) is ResNet4Segmentation  # backbone
    assert type(model[1]) is ResUNet  # head
    # NOTE(review): presumably kept to exercise the model's string
    # representation; confirm whether this output is still wanted.
    print(model)