1#!/usr/bin/env python
2# coding=utf-8
3
4
5"""Tests for TB-POC_RS dataset"""
6
7from ..data.tbpoc_RS import dataset
8
9
def _check_fold(subset, n_train, n_validation, n_test):
    """Validate the splits of one cross-validation fold.

    Checks that ``subset`` holds exactly the ``train``, ``validation`` and
    ``test`` splits, that each split has the expected number of samples, and
    that every sample has a well-formed key and a binary label.

    Parameters
    ----------
    subset
        Result of ``dataset.subsets("fold_<k>")``. It must support ``len``,
        ``in`` and item lookup by split name.
    n_train, n_validation, n_test : int
        Expected number of samples in the corresponding split.
    """
    expected_sizes = {
        "train": n_train,
        "validation": n_validation,
        "test": n_test,
    }

    # No extra splits beyond train/validation/test
    assert len(subset) == len(expected_sizes)

    for split, size in expected_sizes.items():
        assert split in subset
        assert len(subset[split]) == size
        for s in subset[split]:
            # Key prefix is compared case-insensitively (key is upper-cased)
            assert s.key.upper().startswith("TBPOC_CXR/TBPOC-")
            # Labels must be binary
            assert s.label in [0.0, 1.0]


def test_protocol_consistency():
    """Check split sizes, sample keys and labels of folds ``fold_0``..``fold_9``."""

    # Folds 0-6: 292 train / 74 validation / 41 test samples
    for f in range(7):
        _check_fold(dataset.subsets("fold_" + str(f)), 292, 74, 41)

    # Folds 7-9: 293 train / 74 validation / 40 test samples
    for f in range(7, 10):
        _check_fold(dataset.subsets("fold_" + str(f)), 293, 74, 40)
71
72
def test_loading():
    """Load leading training samples of ``fold_0`` and validate their payload."""

    # How many of the first training samples to load; None loads them all
    max_samples = 30

    train_samples = dataset.subsets("fold_0")["train"]

    for sample in train_samples[:max_samples]:
        payload = sample.data

        # Each loaded sample must be a dict with exactly two entries
        assert isinstance(payload, dict)
        assert len(payload) == 2

        # "data" holds the radiological signs: 14 entries expected
        assert "data" in payload
        assert len(payload["data"]) == 14

        # "label" must be binary
        assert "label" in payload
        assert payload["label"] in [0, 1]