1#!/usr/bin/env python
2# coding=utf-8
3
4
5"""Tests for Shenzhen dataset"""
6
7from ..data.shenzhen import dataset
8import pytest
9
10
def _check_protocol(protocol, n_train, n_validation, n_test):
    """Check the split sizes, sample keys and labels of one protocol.

    Asserts that ``dataset.subsets(protocol)`` contains exactly the three
    subsets ``train``, ``validation`` and ``test`` with the expected number
    of samples. It also asserts that every sample key starts with
    ``"CXR_png/CHNCXR_0"`` and that every sample label is 0.0 or 1.0.

    Parameters
    ----------
    protocol : str
        Protocol name passed to ``dataset.subsets()``, e.g. ``"default"`` or
        ``"fold_3"``.
    n_train, n_validation, n_test : int
        Expected sizes of the ``train``, ``validation`` and ``test`` subsets.
    """
    subset = dataset.subsets(protocol)
    assert len(subset) == 3, "{}: expected 3 subsets, got {}".format(
        protocol, len(subset)
    )

    # A tuple of pairs (not a dict) keeps the check order deterministic on
    # every Python version.
    expected_sizes = (
        ("train", n_train),
        ("validation", n_validation),
        ("test", n_test),
    )
    for name, size in expected_sizes:
        assert name in subset, "{}: missing subset '{}'".format(protocol, name)
        samples = subset[name]
        assert len(samples) == size, "{}/{}: expected {} samples, got {}".format(
            protocol, name, size, len(samples)
        )
        for s in samples:
            assert s.key.startswith("CXR_png/CHNCXR_0"), s.key
            assert s.label in [0.0, 1.0], (s.key, s.label)


def test_protocol_consistency():
    """Verify the layout of the default protocol and of the 10 CV folds.

    Every protocol checked here splits the same 662 samples
    (422+107+133 = 476+119+67 = 476+120+66) into train/validation/test.
    """
    # Default protocol
    _check_protocol("default", 422, 107, 133)

    # Cross-validation folds 0-1: 119 validation / 67 test samples
    for f in range(2):
        _check_protocol("fold_" + str(f), 476, 119, 67)

    # Cross-validation folds 2-9: 120 validation / 66 test samples
    for f in range(2, 10):
        _check_protocol("fold_" + str(f), 476, 120, 66)
101
102
@pytest.mark.skip_if_rc_var_not_set("bob.med.tb.shenzhen.datadir")
def test_loading():
    """Load the first training samples of the default protocol and check them.

    Each checked sample must expose a two-entry ``data`` dict. Its
    ``"data"`` entry must have a size within the accepted bounds and
    mode ``"L"``. Its ``"label"`` entry must equal 0 or 1.

    The ``skip_if_rc_var_not_set`` marker is defined outside this file.
    Presumably it skips the test when ``bob.med.tb.shenzhen.datadir`` is not
    configured; confirm against the project's conftest.
    """

    def _check_size(size):
        # Accept [1130, 3001] along the first axis and [948, 3001] along the
        # second (presumably PIL's (width, height); confirm).
        return 1130 <= size[0] <= 3001 and 948 <= size[1] <= 3001

    def _check_sample(s):
        data = s.data
        assert isinstance(data, dict)
        assert len(data) == 2

        assert "data" in data
        assert _check_size(data["data"].size)  # Check size
        # Check colors; presumably a PIL image, where "L" means 8-bit grayscale.
        assert data["data"].mode == "L"

        assert "label" in data
        assert data["label"] in [0, 1]  # Check labels

    # Only check the first images to keep the test fast; set to None to check
    # every training sample.
    limit = 30

    subset = dataset.subsets("default")
    for s in subset["train"][:limit]:
        _check_sample(s)
133
134
@pytest.mark.skip_if_rc_var_not_set("bob.med.tb.shenzhen.datadir")
def test_check():
    """Require the dataset's own ``check()`` routine to return 0.

    Presumably a return value of 0 means no problems were found; confirm
    against ``dataset.check``.
    """
    status = dataset.check()
    assert status == 0