1#!/usr/bin/env python
2# coding=utf-8
3
4
5"""Tests for HIV-TB_RS dataset"""
6
7from ..data.hivtb_RS import dataset
8
def test_protocol_consistency():
    """Verify split sizes, sample keys and labels for folds 0 through 9.

    Every fold must expose exactly three splits ("train", "validation" and
    "test").  Folds 0-2 are expected to hold 174/44/25 samples in those
    splits, folds 3-9 are expected to hold 175/44/24.  Every sample key must
    start with the X-ray directory prefix and every label must be 0.0 or 1.0.
    """
    key_prefix = "HIV-TB_Algorithm_study_X-rays/"

    # Each entry pairs a group of cross-validation folds with the number of
    # samples expected in each of their splits.
    fold_groups = (
        # Cross-validation fold 0-2
        (range(3), (("train", 174), ("validation", 44), ("test", 25))),
        # Cross-validation fold 3-9
        (range(3, 10), (("train", 175), ("validation", 44), ("test", 24))),
    )

    for fold_indices, expected_sizes in fold_groups:
        for fold in fold_indices:
            splits = dataset.subsets("fold_" + str(fold))
            assert len(splits) == 3

            # First pass: presence, size and key prefix of each split
            for split_name, expected_count in expected_sizes:
                assert split_name in splits
                assert len(splits[split_name]) == expected_count
                for sample in splits[split_name]:
                    assert sample.key.startswith(key_prefix)

            # Second pass: check labels
            for split_name, _ in expected_sizes:
                for sample in splits[split_name]:
                    assert sample.label in [0.0, 1.0]

def test_loading():
    """Load the leading "train" samples of fold_0 and validate their content.

    Each loaded sample must be a dict with exactly two entries: "data",
    holding 14 values (the radiological signs), and "label", equal to 0 or 1.
    """
    # Cap on the number of samples that get loaded; use None to load them all
    max_samples = 30

    train_samples = dataset.subsets("fold_0")["train"]
    for sample in train_samples[:max_samples]:
        loaded = sample.data

        assert isinstance(loaded, dict)
        assert len(loaded) == 2

        # Check radiological signs
        assert "data" in loaded
        assert len(loaded["data"]) == 14

        # Check labels
        assert "label" in loaded
        assert loaded["label"] in (0, 1)