1#!/usr/bin/env python
2# coding=utf-8
3
4import torch
5import pytest
6from torch.utils.data import ConcatDataset
7import numpy as np
8import contextlib
9from bob.extension import rc
10
11from ..configs.datasets import get_samples_weights, get_positive_weights
12
13from . import mock_dataset
14
# Download the mocked test data (if needed) and remember where it lives;
# used below to temporarily point the Montgomery datadir RC variable at it
montgomery_datadir = mock_dataset()

# we only iterate over the first N elements at most - dataset loading has
# already been checked on the individual dataset tests. Here, we are only
# testing for the extra tools wrapping the dataset
N = 10
22
23
@contextlib.contextmanager
def rc_context(**new_config):
    """Temporarily override entries of bob's global ``rc`` configuration.

    The previous configuration is snapshotted on entry and fully restored
    on exit, even if the managed block raises.
    """
    snapshot = rc.copy()
    rc.update(new_config)
    try:
        yield
    finally:
        # wipe and restore so keys added inside the context also disappear
        rc.clear()
        rc.update(snapshot)
33
34
@pytest.mark.skip_if_rc_var_not_set("bob.med.tb.montgomery.datadir")
def test_montgomery():
    """Sanity-checks split sizes and sample layout of the Montgomery dataset."""

    def _verify_split(split, expected_size):
        # only the first N samples are inspected - full loading is already
        # exercised by the per-dataset tests
        assert len(split) == expected_size
        for sample in split[:N]:
            assert len(sample) == 3
            assert isinstance(sample[0], str)  # key
            image = sample[1]
            assert image.shape == (1, 512, 512)  # planes, height, width
            assert image.dtype == torch.float32
            assert isinstance(sample[2], int)  # label
            # pixel values must be normalized into [0, 1]
            assert image.max() <= 1.0
            assert image.min() >= 0.0

    from ..configs.datasets.montgomery.default import dataset

    assert len(dataset) == 4
    _verify_split(dataset["__train__"], 110)
    _verify_split(dataset["__valid__"], 110)
    _verify_split(dataset["train"], 110)
    _verify_split(dataset["test"], 28)
56
def test_get_samples_weights():
    """Checks per-sample weights on the (binary) Montgomery train split."""

    # temporarily point the Montgomery datadir at the mocked download
    with rc_context(**{"bob.med.tb.montgomery.datadir": montgomery_datadir}):

        from ..configs.datasets.montgomery.default import dataset

        weights = get_samples_weights(dataset["__train__"]).numpy()

        values, occurrences = np.unique(weights, return_counts=True)

        # 51 samples of one class, 37 of the other; each sample is weighted
        # by the inverse of its class count
        np.testing.assert_equal(occurrences, np.array([51, 37]))
        np.testing.assert_equal(values, np.array(1 / occurrences, dtype=np.float32))
71
@pytest.mark.skip_if_rc_var_not_set('bob.med.tb.nih_cxr14_re.datadir')
def test_get_samples_weights_multi():
    """A multi-label dataset yields uniform (all-ones) sample weights."""

    from ..configs.datasets.nih_cxr14_re.default import dataset

    train_split = dataset['__train__']
    weights = get_samples_weights(train_split).numpy()

    np.testing.assert_equal(weights, np.ones(len(train_split)))
83
def test_get_samples_weights_concat():
    """Checks sample weights when a dataset is concatenated with itself."""

    # temporarily point the Montgomery datadir at the mocked download
    with rc_context(**{"bob.med.tb.montgomery.datadir": montgomery_datadir}):

        from ..configs.datasets.montgomery.default import dataset

        doubled = ConcatDataset((dataset['__train__'], dataset['__train__']))

        weights = get_samples_weights(doubled).numpy()

        values, occurrences = np.unique(weights, return_counts=True)

        # class counts double (2*51 and 2*37), so per-sample weights become
        # 2 / count for each class
        np.testing.assert_equal(occurrences, np.array([102, 74]))
        np.testing.assert_equal(values, np.array(2 / occurrences, dtype=np.float32))
100
@pytest.mark.skip_if_rc_var_not_set('bob.med.tb.nih_cxr14_re.datadir')
def test_get_samples_weights_multi_concat():
    """Uniform weighting is preserved when a multi-label dataset is doubled."""

    from ..configs.datasets.nih_cxr14_re.default import dataset

    train_split = dataset['__train__']
    doubled = ConcatDataset((train_split, train_split))

    weights = get_samples_weights(doubled).numpy()

    # torch.full is used (not np.full) so the reference stays float32,
    # matching the dtype produced by get_samples_weights
    size = len(train_split)
    half = torch.full((size,), 1. / size)
    reference = np.concatenate((half, half))

    np.testing.assert_equal(weights, reference)
116
def test_get_positive_weights():
    """Checks the positive-class weight on the binary Montgomery train split."""

    # temporarily point the Montgomery datadir at the mocked download
    with rc_context(**{"bob.med.tb.montgomery.datadir": montgomery_datadir}):

        from ..configs.datasets.montgomery.default import dataset

        weights = get_positive_weights(dataset['__train__']).numpy()

        # 51 negatives for 37 positives -> positive weight of 51/37
        reference = np.array([51.0 / 37.0], dtype=np.float32)
        np.testing.assert_equal(weights, reference)
132
@pytest.mark.skip_if_rc_var_not_set('bob.med.tb.nih_cxr14_re.datadir')
def test_get_positive_weights_multi():
    """Checks per-class positive weights on the multi-label NIH CXR14 splits."""

    def _assert_rounded_equal(actual, reference):
        # compare as float32 after rounding both sides to 4 decimals, exactly
        # like the historical torch.eq-based check
        lhs = torch.FloatTensor(np.around(actual, 4))
        rhs = torch.FloatTensor(np.around(reference, 4))
        assert torch.all(torch.eq(lhs, rhs))

    from ..configs.datasets.nih_cxr14_re.default import dataset

    train_reference = [
        0.9195434,
        0.9462068,
        0.8070095,
        0.94879204,
        0.767055,
        0.8944615,
        0.88212335,
        0.8227136,
        0.8943905,
        0.8864118,
        0.90026057,
        0.8888551,
        0.884739,
        0.84540284,
    ]
    valid_reference = [
        0.9366929,
        0.9535433,
        0.79543304,
        0.9530709,
        0.74834645,
        0.88708663,
        0.86661416,
        0.81496066,
        0.89480317,
        0.8888189,
        0.8933858,
        0.89795274,
        0.87181103,
        0.8266142,
    ]

    _assert_rounded_equal(
        get_positive_weights(dataset['__train__']).numpy(), train_reference
    )
    _assert_rounded_equal(
        get_positive_weights(dataset['__valid__']).numpy(), valid_reference
    )
194
def test_get_positive_weights_concat():
    """Positive weight is unchanged when a dataset is concatenated with itself."""

    # temporarily point the Montgomery datadir at the mocked download
    with rc_context(**{"bob.med.tb.montgomery.datadir": montgomery_datadir}):

        from ..configs.datasets.montgomery.default import dataset

        doubled = ConcatDataset((dataset['__train__'], dataset['__train__']))

        weights = get_positive_weights(doubled).numpy()

        # 102/74 reduces to 51/37: duplication does not alter class balance
        reference = np.array([51.0 / 37.0], dtype=np.float32)
        np.testing.assert_equal(weights, reference)
212
@pytest.mark.skip_if_rc_var_not_set('bob.med.tb.nih_cxr14_re.datadir')
def test_get_positive_weights_multi_concat():
    """Per-class positive weights are stable under self-concatenation."""

    def _assert_rounded_equal(actual, reference):
        # compare as float32 after rounding both sides to 4 decimals, exactly
        # like the historical torch.eq-based check
        lhs = torch.FloatTensor(np.around(actual, 4))
        rhs = torch.FloatTensor(np.around(reference, 4))
        assert torch.all(torch.eq(lhs, rhs))

    from ..configs.datasets.nih_cxr14_re.default import dataset

    doubled_train = ConcatDataset((dataset['__train__'], dataset['__train__']))
    doubled_valid = ConcatDataset((dataset['__valid__'], dataset['__valid__']))

    # duplication keeps class proportions, so the references are identical to
    # the single-copy expectations
    train_reference = [
        0.9195434,
        0.9462068,
        0.8070095,
        0.94879204,
        0.767055,
        0.8944615,
        0.88212335,
        0.8227136,
        0.8943905,
        0.8864118,
        0.90026057,
        0.8888551,
        0.884739,
        0.84540284,
    ]
    valid_reference = [
        0.9366929,
        0.9535433,
        0.79543304,
        0.9530709,
        0.74834645,
        0.88708663,
        0.86661416,
        0.81496066,
        0.89480317,
        0.8888189,
        0.8933858,
        0.89795274,
        0.87181103,
        0.8266142,
    ]

    _assert_rounded_equal(
        get_positive_weights(doubled_train).numpy(), train_reference
    )
    _assert_rounded_equal(
        get_positive_weights(doubled_valid).numpy(), valid_reference
    )