1#!/usr/bin/env python
2# coding=utf-8
3
4
5"""Tests for Montgomery dataset"""
6
7import pytest
8from ..data.montgomery import dataset
9
def _check_protocol(protocol, expected_sizes):
    """Verify the split layout, sample keys and labels of one protocol.

    ``protocol`` is the name handed to ``dataset.subsets``.
    ``expected_sizes`` is a sequence of ``(split_name, sample_count)`` pairs;
    the protocol must expose exactly that many splits.
    """

    key_prefix = "CXR_png/MCUCXR_0"
    valid_labels = [0.0, 1.0]

    subset = dataset.subsets(protocol)
    assert len(subset) == len(expected_sizes)

    # First pass: every split exists, has the expected number of samples
    # and only contains keys with the expected prefix
    for name, size in expected_sizes:
        assert name in subset
        assert len(subset[name]) == size
        for s in subset[name]:
            assert s.key.startswith(key_prefix)

    # Second pass: check labels
    for name, _ in expected_sizes:
        for s in subset[name]:
            assert s.label in valid_labels


def test_protocol_consistency():
    """Check split sizes, sample keys and labels for all tested protocols.

    Covers the ``default`` protocol and the cross-validation protocols
    ``fold_0`` to ``fold_9``.  Each expected layout below adds up to 138
    samples spread over the train/validation/test splits.
    """

    # Default protocol
    _check_protocol(
        "default",
        (("train", 88), ("validation", 22), ("test", 28)),
    )

    # Cross-validation fold 0-7
    for f in range(8):
        _check_protocol(
            "fold_" + str(f),
            (("train", 99), ("validation", 25), ("test", 14)),
        )

    # Cross-validation fold 8-9 (one sample moves from test to train)
    for f in range(8, 10):
        _check_protocol(
            "fold_" + str(f),
            (("train", 100), ("validation", 25), ("test", 13)),
        )
100
@pytest.mark.skip_if_rc_var_not_set("bob.med.tb.montgomery.datadir")
def test_loading():
    """Load the first training samples of the default protocol and check them.

    Every loaded sample must be a two-entry dictionary with an image under
    ``"data"`` and a label under ``"label"``.  The marker presumably skips
    this test when the dataset directory is not configured -- confirm
    against the marker's implementation.
    """

    portrait = (4020, 4892)
    landscape = (4892, 4020)

    def _has_known_size(size):
        # Both orientations of the same resolution are accepted
        return size == portrait or size == landscape

    def _verify(sample):
        loaded = sample.data
        assert isinstance(loaded, dict)
        assert len(loaded) == 2

        assert "data" in loaded
        image = loaded["data"]
        assert _has_known_size(image.size)  # Check size
        # Check colors ("L" is grayscale, assuming a PIL image -- TODO confirm)
        assert image.mode == "L"

        assert "label" in loaded
        assert loaded["label"] in [0, 1]  # Check labels

    # Only the first ``limit`` training samples are loaded; set this to
    # None to go through all of them
    limit = 30

    for sample in dataset.subsets("default")["train"][:limit]:
        _verify(sample)
132
@pytest.mark.skip_if_rc_var_not_set("bob.med.tb.montgomery.datadir")
def test_check():
    """Run the dataset's own ``check()`` and expect it to return 0.

    NOTE(review): the return value is presumably an error count -- confirm
    against the implementation of ``dataset.check``.
    """
    outcome = dataset.check()
    assert outcome == 0