1#!/usr/bin/env python
2# coding=utf-8
3
4
5"""Tests for Extended Shenzhen dataset"""
6
7from ..data.shenzhen_RS import dataset
8
9
def test_protocol_consistency():
    """Check split sizes, sample keys and labels of every dataset protocol.

    Covers the ``default`` protocol and the ten cross-validation protocols
    ``fold_0`` .. ``fold_9``.  For each protocol the test asserts that:

    * ``dataset.subsets(<protocol>)`` holds exactly three splits
      (``train``, ``validation`` and ``test``);
    * each split has the expected number of samples;
    * sample keys start with ``CXR_png/CHNCXR_0`` for the splits listed
      for that protocol;
    * every sample label is either ``0.0`` or ``1.0``.
    """

    key_prefix = "CXR_png/CHNCXR_0"
    all_splits = ("train", "validation", "test")

    def _check_protocol(name, expected_sizes, prefixed_splits):
        # ``expected_sizes`` is a sequence of (split name, sample count)
        # pairs; a sequence (not a dict) keeps the check order explicit.
        subset = dataset.subsets(name)
        assert len(subset) == len(expected_sizes)

        for split, size in expected_sizes:
            assert split in subset
            assert len(subset[split]) == size
            if split in prefixed_splits:
                for s in subset[split]:
                    assert s.key.startswith(key_prefix)

        # Check labels
        for split, _ in expected_sizes:
            for s in subset[split]:
                assert s.label in [0.0, 1.0]

    # Default protocol
    # NOTE(review): only the "test" split has its key prefix checked here,
    # unlike the fold protocols below - kept as-is; confirm whether "train"
    # and "validation" should be checked as well.
    _check_protocol(
        "default",
        (("train", 422), ("validation", 107), ("test", 133)),
        ("test",),
    )

    # Cross-validation folds 0-1
    for f in range(2):
        _check_protocol(
            "fold_" + str(f),
            (("train", 476), ("validation", 119), ("test", 67)),
            all_splits,
        )

    # Cross-validation folds 2-9
    for f in range(2, 10):
        _check_protocol(
            "fold_" + str(f),
            (("train", 476), ("validation", 120), ("test", 66)),
            all_splits,
        )
96
97
def test_loading():
    """Load the first training samples of the default protocol and check them.

    Each loaded sample's ``data`` attribute must be a dict with exactly two
    entries: ``"data"``, holding 14 values (the radiological signs), and
    ``"label"``, which must be 0 or 1.
    """

    # Only the first ``max_samples`` samples are loaded to keep the test
    # fast; set to None to go through the whole training split.
    max_samples = 30

    train_samples = dataset.subsets("default")["train"]
    for sample in train_samples[:max_samples]:
        # Read ``.data`` once per sample and reuse the result below.
        payload = sample.data

        assert isinstance(payload, dict)
        assert len(payload) == 2

        # Radiological signs
        assert "data" in payload
        assert len(payload["data"]) == 14

        # Label
        assert "label" in payload
        assert payload["label"] in [0, 1]