1#!/usr/bin/env python
2# coding=utf-8
3
4
5"""Tests for HIV-TB dataset"""
6
7import pytest
8from ..data.hivtb import dataset
9
10
def test_protocol_consistency():
    """Check the split layout of every cross-validation fold.

    For each of the ten folds (``fold_0`` .. ``fold_9``) this verifies that
    ``dataset.subsets()`` yields exactly the ``train``, ``validation`` and
    ``test`` entries, that each entry has the expected number of samples,
    that every sample key starts with the dataset directory prefix, and
    that every sample label is either ``0.0`` or ``1.0``.
    """

    key_prefix = "HIV-TB_Algorithm_study_X-rays/"

    # Expected number of samples per split.  Folds 0-2 carry one more test
    # sample (and one fewer training sample) than folds 3-9; the total is
    # 243 samples in both cases.
    expected_sizes = [
        (range(3), {"train": 174, "validation": 44, "test": 25}),
        (range(3, 10), {"train": 175, "validation": 44, "test": 24}),
    ]

    for folds, sizes in expected_sizes:
        for f in folds:
            subset = dataset.subsets("fold_" + str(f))

            # No split other than the expected ones may be present
            assert len(subset) == len(sizes)

            for split, size in sizes.items():
                assert split in subset
                assert len(subset[split]) == size

                for s in subset[split]:
                    # Check keys
                    assert s.key.startswith(key_prefix)
                    # Check labels
                    assert s.label in [0.0, 1.0]
72
73
@pytest.mark.skip_if_rc_var_not_set("bob.med.tb.hivtb.datadir")
def test_loading():
    """Load the first training samples of ``fold_0`` and validate them.

    Each loaded sample must be a two-entry dictionary holding an image
    under ``"data"`` and a binary label under ``"label"``.  The image must
    have one of the two accepted sizes and mode ``"L"``.
    """

    # Both orientations of the same resolution are accepted
    accepted_sizes = ((2048, 2500), (2500, 2048))

    def _has_accepted_size(size):
        """Tell whether ``size`` equals one of the accepted image sizes."""
        return size in accepted_sizes

    def _validate(sample):
        """Assert the loaded content of ``sample`` has the expected form."""

        loaded = sample.data
        assert isinstance(loaded, dict)
        assert len(loaded) == 2

        assert "data" in loaded
        image = loaded["data"]
        assert _has_accepted_size(image.size)  # Check size
        assert image.mode == "L"  # Check colors

        assert "label" in loaded
        assert loaded["label"] in [0, 1]  # Check labels

    # Only the first ``max_samples`` images are loaded to keep the test
    # short; set to None to go through the whole training split.
    max_samples = 30

    training_samples = dataset.subsets("fold_0")["train"]
    for sample in training_samples[:max_samples]:
        _validate(sample)
105
106
@pytest.mark.skip_if_rc_var_not_set("bob.med.tb.hivtb.datadir")
def test_check():
    """Assert that ``dataset.check()`` returns 0.

    NOTE(review): the return value is presumably a count of detected
    problems -- confirm against the dataset implementation.
    """
    result = dataset.check()
    assert result == 0