1#!/usr/bin/env python
2# coding=utf-8
3
4
5"""Tests for Extended Indian dataset"""
6
7from ..data.indian_RS import dataset
8
9
def _check_protocol(protocol, sizes, key_prefix=None):
    """Assert that one protocol of ``dataset`` has the expected layout.

    Parameters
    ----------
    protocol : str
        Protocol name passed to ``dataset.subsets()`` (e.g. ``"default"``,
        ``"fold_3"``).
    sizes : dict
        Mapping of split name (``"train"``, ``"validation"``, ``"test"``) to
        the exact number of samples expected in that split.  The protocol
        must contain exactly these splits.
    key_prefix : str, optional
        When given, every sample key in every split must start with it.
        ``None`` (default) skips the key check.
    """
    subset = dataset.subsets(protocol)
    assert len(subset) == len(sizes), (protocol, sorted(subset))

    for split, expected in sizes.items():
        assert split in subset, (protocol, split)
        samples = subset[split]
        assert len(samples) == expected, (protocol, split, len(samples))

        # Single pass per split: key prefix (optional) and binary label.
        # A list (not a set) is used for the label check so it also works
        # with label types whose hash differs from the plain float's.
        for s in samples:
            if key_prefix is not None:
                assert s.key.startswith(key_prefix), (protocol, split, s.key)
            assert s.label in [0.0, 1.0], (protocol, split, s.label)


def test_protocol_consistency():
    """Check split names, sizes, key prefixes and labels of every protocol."""

    # Default protocol (no key-prefix requirement)
    _check_protocol("default", {"train": 83, "validation": 20, "test": 52})

    # Cross-validation folds 0-4 and 5-9 each total 155 samples; folds 0-4
    # put one extra sample in test, folds 5-9 put it in train.
    for f in range(5):
        _check_protocol(
            "fold_" + str(f),
            {"train": 111, "validation": 28, "test": 16},
            key_prefix="DatasetA",
        )

    for f in range(5, 10):
        _check_protocol(
            "fold_" + str(f),
            {"train": 112, "validation": 28, "test": 15},
            key_prefix="DatasetA",
        )
94
95
def test_loading():
    """Spot-check the payload of the first training samples (default protocol)."""

    def _verify(sample):
        # The loaded payload is a dict holding exactly two entries.
        payload = sample.data
        assert isinstance(payload, dict)
        assert len(payload) == 2

        # Radiological signs: 14 values per sample.
        assert "data" in payload
        assert len(payload["data"]) == 14

        # Binary label.
        assert "label" in payload
        assert payload["label"] in [0, 1]

    # Only the first images are tested to keep the run short; set to None
    # to check every training sample.
    first_n = 30

    train = dataset.subsets("default")["train"]
    for sample in train[:first_n]:
        _verify(sample)