1#!/usr/bin/env python
2# coding=utf-8
3
4
5"""Tests for TB-POC dataset"""
6
7import pytest
8from ..data.tbpoc import dataset
9
10
def _check_fold(fold, expected_sizes):
    """Validate the splits of one cross-validation fold.

    Parameters
    ----------
    fold : int
        Fold index; the protocol queried is ``"fold_" + str(fold)``.
    expected_sizes : tuple
        Pairs of ``(split_name, expected_number_of_samples)``, checked in
        the order given.
    """
    key_prefix = "TBPOC_CXR/TBPOC-"

    subset = dataset.subsets("fold_" + str(fold))
    assert len(subset) == 3

    # Presence, size and sample keys of every split
    for split, size in expected_sizes:
        assert split in subset
        assert len(subset[split]) == size
        for s in subset[split]:
            # Upper-casing makes the prefix comparison case-insensitive
            assert s.key.upper().startswith(key_prefix)

    # Check labels
    for split, _ in expected_sizes:
        for s in subset[split]:
            assert s.label in [0.0, 1.0]


def test_protocol_consistency():
    """Check split sizes, sample keys and labels of all ten folds."""

    # Cross-validation fold 0-6
    first_folds = (("train", 292), ("validation", 74), ("test", 41))
    for f in range(7):
        _check_fold(f, first_folds)

    # Cross-validation fold 7-9: one more training sample, one less for test
    last_folds = (("train", 293), ("validation", 74), ("test", 40))
    for f in range(7, 10):
        _check_fold(f, last_folds)
72
73
@pytest.mark.skip_if_rc_var_not_set("bob.med.tb.tbpoc.datadir")
def test_loading():
    """Load the first training samples of fold 0 and validate their contents.

    Each sample's ``data`` must be a dictionary with exactly two entries:
    ``"data"`` (an object exposing ``size`` and ``mode`` -- presumably a PIL
    image, to be confirmed against the dataset loader) and ``"label"``.
    """

    # Accepted image sizes: portrait and landscape orientations
    image_size_portrait = (2048, 2500)
    image_size_landscape = (2500, 2048)

    def _check_size(size):
        """Return True if ``size`` matches one of the two accepted sizes."""
        return size in (image_size_portrait, image_size_landscape)

    def _check_sample(s):
        """Assert structure, image size, color mode and label of a sample."""

        data = s.data
        assert isinstance(data, dict)
        assert len(data) == 2

        assert "data" in data
        assert _check_size(data["data"].size)  # Check size
        # Compare explicitly: ``assert mode, "L"`` would use "L" as the
        # failure message and pass for any non-empty mode string.
        assert data["data"].mode == "L"  # Check colors

        assert "label" in data
        assert data["label"] in [0, 1]  # Check labels

    limit = 30  # use this to limit testing to first images only, else None

    subset = dataset.subsets("fold_0")
    for s in subset["train"][:limit]:
        _check_sample(s)
105
106
@pytest.mark.skip_if_rc_var_not_set("bob.med.tb.tbpoc.datadir")
def test_check():
    """Run the dataset self-check and require that it returns zero.

    NOTE(review): the return value is presumably the number of problems
    found by ``dataset.check()`` -- confirm against its implementation.
    """
    outcome = dataset.check()
    assert outcome == 0