1#!/usr/bin/env python
2# coding=utf-8
3
4
5"""Tests for NIH CXR14 dataset"""
6
7from ..data.nih_cxr14_re import dataset
8import pytest
9
10
def test_protocol_consistency():
    """Check subset layout, sample keys and label values of each protocol.

    Both the ``default`` and ``idiap`` protocols are checked against the same
    expectations: exactly three subsets (``train``, ``validation``, ``test``)
    with fixed sizes, every sample key starting with ``images/000``, and
    every label entry equal to ``0.0`` or ``1.0``.
    """

    # Expected number of samples per subset; identical for both protocols.
    expected_sizes = {
        "train": 98637,
        "validation": 6350,
        "test": 4054,
    }

    for protocol in ("default", "idiap"):
        subset = dataset.subsets(protocol)
        assert len(subset) == len(expected_sizes), protocol

        for name, size in expected_sizes.items():
            where = f"{protocol}/{name}"
            assert name in subset, where
            assert len(subset[name]) == size, where

            # Single pass per subset: check key prefix and label values.
            for s in subset[name]:
                assert s.key.startswith("images/000"), f"{where}: {s.key}"
                # ``in`` on a tuple compares with ``==``, so this also works
                # for numeric scalar types that are not plain Python floats.
                assert all(v in (0.0, 1.0) for v in s.label), (
                    f"{where}: {s.key} has non-binary label {s.label}"
                )
76
77
@pytest.mark.skip_if_rc_var_not_set("bob.med.tb.nih_cxr14_re.datadir")
def test_loading():
    """Load the first training samples of the ``default`` protocol and check them.

    The custom marker presumably skips this test when the
    ``bob.med.tb.nih_cxr14_re.datadir`` RC variable is not configured
    (the raw images are needed to load ``s.data``) — confirm in conftest.
    """

    def _check_size(size):
        # Every loaded image is expected to be 1024x1024.
        return size == (1024, 1024)

    def _check_sample(s):
        """Assert that one sample's loaded data has the expected structure."""

        data = s.data
        assert isinstance(data, dict)
        # Exactly the two entries checked below: "data" and "label".
        assert len(data) == 2

        assert "data" in data
        # ``.size`` / ``.mode`` presumably come from a PIL image — verify.
        assert _check_size(data["data"].size), data["data"].size  # Check size
        assert data["data"].mode == "RGB", data["data"].mode  # Check colors

        assert "label" in data
        assert len(data["label"]) == 14  # Check labels

    limit = 30  # use this to limit testing to first images only, else None

    subset = dataset.subsets("default")
    for s in subset["train"][:limit]:
        _check_sample(s)