1#!/usr/bin/env python
2# coding=utf-8
3
4import importlib
5
6import pytest
7import torch
8
9from . import mock_dataset
10
# Mocked STARE data; used by the STARE tests in this module that carry no
# rc-variable skip marker (see the "hack to allow testing on the CI" notes)
(stare_datadir, stare_dataset) = mock_dataset()

# Upper bound on how many samples some subset checks iterate over. Dataset
# loading itself is already covered by the individual dataset tests; here we
# only test the extra tooling that wraps each dataset.
N = 10
17
18
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
def test_drive():
    """Check the DRIVE ``default`` and ``second_annotator`` configurations.

    Verifies the number of subsets, each subset's size and, for every
    sample, its length, key type, tensor shapes/dtypes and image range.
    """

    def _check_subset(samples, size):
        assert len(samples) == size
        for s in samples:
            assert len(s) == 4
            assert isinstance(s[0], str)
            # compare with ``==``: ``assert a, b`` treats ``b`` as the message
            # and never fails for a non-empty shape
            assert s[1].shape == (3, 544, 544)  # planes, height, width
            assert s[1].dtype == torch.float32
            assert s[2].shape == (1, 544, 544)  # planes, height, width
            assert s[2].dtype == torch.float32
            assert s[3].shape == (1, 544, 544)  # planes, height, width
            assert s[3].dtype == torch.float32
            assert s[1].max() <= 1.0
            assert s[1].min() >= 0.0

    from ..configs.datasets.drive.default import dataset

    assert len(dataset) == 4
    _check_subset(dataset["__train__"], 20)
    _check_subset(dataset["__valid__"], 20)
    _check_subset(dataset["train"], 20)
    _check_subset(dataset["test"], 20)

    from ..configs.datasets.drive.second_annotator import dataset

    assert len(dataset) == 1
    _check_subset(dataset["test"], 20)
47
48
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_drive_mtest():
    """Check the DRIVE multi-test (``mtest``) configuration.

    Its ``train``/``test`` subsets must equal those of the DRIVE default
    configuration, and every sample of every subset is checked.
    """

    from ..configs.datasets.drive.mtest import dataset

    assert len(dataset) == 10

    from ..configs.datasets.drive.default import dataset as baseline

    assert dataset["train"] == baseline["train"]
    assert dataset["test"] == baseline["test"]

    for subset in dataset:
        for sample in dataset[subset]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # compare with ``==``: ``assert a, b`` never fails here
            assert sample[1].shape == (3, 544, 544)  # planes, height, width
            assert sample[1].dtype == torch.float32
            assert sample[2].shape == (1, 544, 544)
            assert sample[2].dtype == torch.float32
            assert sample[3].shape == (1, 544, 544)
            assert sample[3].dtype == torch.float32
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
76
77
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_drive_covd():
    """Check the DRIVE ``covd`` configuration.

    ``train`` must equal ``__valid__``, ``test`` must equal the DRIVE default
    test set, and both ``__train__`` and ``train`` hold 123 samples which
    are all checked.
    """

    from ..configs.datasets.drive.covd import dataset

    assert len(dataset) == 4

    from ..configs.datasets.drive.default import dataset as baseline

    assert dataset["train"] == dataset["__valid__"]
    assert dataset["test"] == baseline["test"]

    # these are the only different sets from the baseline
    for key in ("__train__", "train"):
        assert len(dataset[key]) == 123
        # iterate the subset under test (previously always "__train__", so
        # "train" samples were never inspected)
        for sample in dataset[key]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # compare with ``==``: ``assert a, b`` never fails here
            assert sample[1].shape == (3, 544, 544)  # planes, height, width
            assert sample[1].dtype == torch.float32
            assert sample[2].shape == (1, 544, 544)  # planes, height, width
            assert sample[2].dtype == torch.float32
            assert sample[3].shape == (1, 544, 544)
            assert sample[3].dtype == torch.float32
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
106
107
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_drive_ssl():
    """Check the DRIVE ``ssl`` configuration.

    All subsets except ``__train__`` must match the DRIVE ``covd``
    configuration; ``__train__`` samples carry 6 elements (two extra: a
    string key and a 3-plane image tensor).
    """

    from ..configs.datasets.drive.ssl import dataset

    assert len(dataset) == 4

    from ..configs.datasets.drive.covd import dataset as covd

    assert dataset["train"] == covd["train"]
    assert dataset["train"] == dataset["__valid__"]
    assert dataset["test"] == covd["test"]
    assert dataset["__valid__"] == covd["__valid__"]

    # these are the only different from the baseline
    assert len(dataset["__train__"]) == 123
    for sample in dataset["__train__"]:
        assert len(sample) == 6
        assert isinstance(sample[0], str)
        # compare with ``==``: ``assert a, b`` never fails here
        assert sample[1].shape == (3, 544, 544)  # planes, height, width
        assert sample[1].dtype == torch.float32
        assert sample[2].shape == (1, 544, 544)  # planes, height, width
        assert sample[2].dtype == torch.float32
        assert sample[3].shape == (1, 544, 544)  # planes, height, width
        assert sample[3].dtype == torch.float32
        assert isinstance(sample[4], str)
        assert sample[5].shape == (3, 544, 544)  # planes, height, width
        assert sample[5].dtype == torch.float32
        assert sample[1].max() <= 1.0
        assert sample[1].min() >= 0.0
141
142
def test_stare_augmentation_manipulation():
    """Check that only ``__train__`` receives the extra augmentation transforms.

    Builds the ``ah`` protocol over the mocked STARE dataset: ``__train__``
    must carry exactly 4 more transforms than ``test``, while ``train``
    carries the same number as ``test``.
    """

    # hack to allow testing on the CI
    from ..configs.datasets.stare import _maker

    subsets = _maker("ah", stare_dataset)

    n_test = len(subsets["test"]._transforms.transforms)
    n_augmented = len(subsets["__train__"]._transforms.transforms)
    n_train = len(subsets["train"]._transforms.transforms)

    assert n_augmented == n_test + 4
    assert n_train == n_test
160
161
def test_stare():
    """Check both STARE protocols (``ah`` and ``vk``) on the mocked dataset.

    Each protocol must expose 4 subsets; ``__train__``, ``train`` and
    ``test`` each hold 10 samples which are all checked.
    """

    def _check_subset(samples, size):
        assert len(samples) == size
        for s in samples:
            assert len(s) == 4
            assert isinstance(s[0], str)
            # compare with ``==``: ``assert a, b`` treats ``b`` as the message
            # and never fails for a non-empty shape
            assert s[1].shape == (3, 608, 704)  # planes, height, width
            assert s[1].dtype == torch.float32
            assert s[2].shape == (1, 608, 704)  # planes, height, width
            assert s[2].dtype == torch.float32
            assert s[3].shape == (1, 608, 704)  # planes, height, width
            assert s[3].dtype == torch.float32
            assert s[1].max() <= 1.0
            assert s[1].min() >= 0.0

    # hack to allow testing on the CI
    from ..configs.datasets.stare import _maker

    for protocol in "ah", "vk":
        dataset = _maker(protocol, stare_dataset)
        assert len(dataset) == 4
        _check_subset(dataset["__train__"], 10)
        _check_subset(dataset["train"], 10)
        _check_subset(dataset["test"], 10)
186
187
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_stare_mtest():
    """Check the STARE multi-test (``mtest``) configuration.

    Its ``train``/``test`` subsets must equal those of the STARE ``ah``
    configuration, and every sample of every subset is checked.
    """

    from ..configs.datasets.stare.mtest import dataset

    assert len(dataset) == 10

    from ..configs.datasets.stare.ah import dataset as baseline

    assert dataset["train"] == baseline["train"]
    assert dataset["test"] == baseline["test"]

    for subset in dataset:
        for sample in dataset[subset]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # compare with ``==``: ``assert a, b`` never fails here
            assert sample[1].shape == (3, 608, 704)  # planes,height,width
            assert sample[1].dtype == torch.float32
            assert sample[2].shape == (1, 608, 704)  # planes,height,width
            assert sample[2].dtype == torch.float32
            assert sample[3].shape == (1, 608, 704)
            assert sample[3].dtype == torch.float32
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
215
216
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_stare_covd():
    """Check the STARE ``covd`` configuration.

    ``train`` must equal ``__valid__``, ``test`` must equal the STARE ``ah``
    test set, and both ``__train__`` and ``train`` hold 143 samples which
    are all checked.
    """

    from ..configs.datasets.stare.covd import dataset

    assert len(dataset) == 4

    from ..configs.datasets.stare.ah import dataset as baseline

    assert dataset["train"] == dataset["__valid__"]
    assert dataset["test"] == baseline["test"]

    # these are the only different sets from the baseline
    for key in ("__train__", "train"):
        assert len(dataset[key]) == 143
        for sample in dataset[key]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # compare with ``==``: ``assert a, b`` never fails here
            assert sample[1].shape == (3, 608, 704)  # planes, height, width
            assert sample[1].dtype == torch.float32
            assert sample[2].shape == (1, 608, 704)  # planes, height, width
            assert sample[2].dtype == torch.float32
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
            assert sample[3].shape == (1, 608, 704)
            assert sample[3].dtype == torch.float32
246
247
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
def test_chasedb1():
    """Check the CHASE-DB1 ``first_annotator`` and ``second_annotator`` configs.

    Each exposes 4 subsets: ``__train__``, ``__valid__`` and ``train`` with
    8 samples and ``test`` with 20; all samples are checked.
    """

    def _check_subset(samples, size):
        assert len(samples) == size
        for s in samples:
            assert len(s) == 4
            assert isinstance(s[0], str)
            # compare with ``==``: ``assert a, b`` treats ``b`` as the message
            # and never fails for a non-empty shape
            assert s[1].shape == (3, 960, 960)  # planes, height, width
            assert s[1].dtype == torch.float32
            assert s[2].shape == (1, 960, 960)  # planes, height, width
            assert s[2].dtype == torch.float32
            assert s[3].shape == (1, 960, 960)  # planes, height, width
            assert s[3].dtype == torch.float32
            assert s[1].max() <= 1.0
            assert s[1].min() >= 0.0

    for m in ("first_annotator", "second_annotator"):
        d = importlib.import_module(
            f"...configs.datasets.chasedb1.{m}", package=__name__
        ).dataset
        assert len(d) == 4
        _check_subset(d["__train__"], 8)
        _check_subset(d["__valid__"], 8)
        _check_subset(d["train"], 8)
        _check_subset(d["test"], 20)
273
274
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_chasedb1_mtest():
    """Check the CHASE-DB1 multi-test (``mtest``) configuration.

    Its ``train``/``test`` subsets must equal those of the
    ``first_annotator`` configuration, and every sample is checked.
    """

    from ..configs.datasets.chasedb1.mtest import dataset

    assert len(dataset) == 10

    from ..configs.datasets.chasedb1.first_annotator import dataset as baseline

    assert dataset["train"] == baseline["train"]
    assert dataset["test"] == baseline["test"]

    for subset in dataset:
        for sample in dataset[subset]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # compare with ``==``: ``assert a, b`` never fails here
            assert sample[1].shape == (3, 960, 960)  # planes,height,width
            assert sample[1].dtype == torch.float32
            assert sample[2].shape == (1, 960, 960)  # planes,height,width
            assert sample[2].dtype == torch.float32
            assert sample[3].shape == (1, 960, 960)
            assert sample[3].dtype == torch.float32
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
302
303
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_chasedb1_covd():
    """Check the CHASE-DB1 ``covd`` configuration.

    ``train`` must equal ``__valid__``, ``test`` must equal the
    ``first_annotator`` test set, and both ``__train__`` and ``train`` hold
    135 samples which are all checked.
    """

    from ..configs.datasets.chasedb1.covd import dataset

    assert len(dataset) == 4

    from ..configs.datasets.chasedb1.first_annotator import dataset as baseline

    assert dataset["train"] == dataset["__valid__"]
    assert dataset["test"] == baseline["test"]

    # these are the only different sets from the baseline
    for key in ("__train__", "train"):
        assert len(dataset[key]) == 135
        for sample in dataset[key]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # compare with ``==``: ``assert a, b`` never fails here
            assert sample[1].shape == (3, 960, 960)  # planes, height, width
            assert sample[1].dtype == torch.float32
            assert sample[2].shape == (1, 960, 960)  # planes, height, width
            assert sample[2].dtype == torch.float32
            assert sample[3].shape == (1, 960, 960)
            assert sample[3].dtype == torch.float32
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
333
334
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
def test_hrf():
    """Check the HRF ``default`` configuration.

    Verifies the regular subsets (1168x1648 tensors) and the two
    ``(full resolution)`` subsets (2336x3296 tensors), sample by sample.
    """

    def _check_subset(samples, size):
        assert len(samples) == size
        for s in samples:
            assert len(s) == 4
            assert isinstance(s[0], str)
            # compare with ``==``: ``assert a, b`` treats ``b`` as the message
            # and never fails for a non-empty shape
            assert s[1].shape == (3, 1168, 1648)  # planes, height, width
            assert s[1].dtype == torch.float32
            assert s[2].shape == (1, 1168, 1648)  # planes, height, width
            assert s[2].dtype == torch.float32
            assert s[3].shape == (1, 1168, 1648)  # planes, height, width
            assert s[3].dtype == torch.float32
            assert s[1].max() <= 1.0
            assert s[1].min() >= 0.0

    def _check_subset_fullres(samples, size):
        assert len(samples) == size
        for s in samples:
            assert len(s) == 4
            assert isinstance(s[0], str)
            assert s[1].shape == (3, 2336, 3296)  # planes, height, width
            assert s[1].dtype == torch.float32
            assert s[2].shape == (1, 2336, 3296)  # planes, height, width
            assert s[2].dtype == torch.float32
            assert s[3].shape == (1, 2336, 3296)  # planes, height, width
            assert s[3].dtype == torch.float32
            assert s[1].max() <= 1.0
            assert s[1].min() >= 0.0

    from ..configs.datasets.hrf.default import dataset

    assert len(dataset) == 6
    _check_subset(dataset["__train__"], 15)
    _check_subset(dataset["train"], 15)
    _check_subset(dataset["test"], 30)
    _check_subset_fullres(dataset["train (full resolution)"], 15)
    _check_subset_fullres(dataset["test (full resolution)"], 30)
373
374
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_hrf_mtest():
    """Check the HRF multi-test (``mtest``) configuration.

    Its ``train``/``test`` subsets must equal those of the HRF default
    configuration. Subsets whose name contains ``full resolution`` are
    expected at 2336x3296, all others at 1168x1648.
    """

    from ..configs.datasets.hrf.mtest import dataset

    assert len(dataset) == 12

    from ..configs.datasets.hrf.default import dataset as baseline

    assert dataset["train"] == baseline["train"]
    assert dataset["test"] == baseline["test"]

    for subset in dataset:
        for sample in dataset[subset]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # compare with ``==``: ``assert a, b`` never fails here
            if "full resolution" in subset:
                assert sample[1].shape == (3, 2336, 3296)
                assert sample[1].dtype == torch.float32
                assert sample[2].shape == (1, 2336, 3296)
                assert sample[2].dtype == torch.float32
                assert sample[3].shape == (1, 2336, 3296)
                assert sample[3].dtype == torch.float32
            else:
                assert sample[1].shape == (3, 1168, 1648)
                assert sample[1].dtype == torch.float32
                assert sample[2].shape == (1, 1168, 1648)
                assert sample[2].dtype == torch.float32
                assert sample[3].shape == (1, 1168, 1648)
                assert sample[3].dtype == torch.float32
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
410
411
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_hrf_covd():
    """Check the HRF ``covd`` configuration.

    ``train`` must equal ``__valid__``, ``test`` must equal the HRF default
    test set, and both ``__train__`` and ``train`` hold 118 samples which
    are all checked.
    """

    from ..configs.datasets.hrf.covd import dataset

    assert len(dataset) == 6

    from ..configs.datasets.hrf.default import dataset as baseline

    assert dataset["train"] == dataset["__valid__"]
    assert dataset["test"] == baseline["test"]

    # these are the only different sets from the baseline
    for key in ("__train__", "train"):
        assert len(dataset[key]) == 118
        for sample in dataset[key]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # compare with ``==``: ``assert a, b`` never fails here
            assert sample[1].shape == (3, 1168, 1648)
            assert sample[1].dtype == torch.float32
            assert sample[2].shape == (1, 1168, 1648)
            assert sample[2].dtype == torch.float32
            assert sample[3].shape == (1, 1168, 1648)
            assert sample[3].dtype == torch.float32
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
441
442
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_iostar():
    """Check the IOSTAR ``vessel`` and ``optic_disc`` configurations.

    Each exposes 4 subsets; ``__train__`` and ``train`` hold 20 samples and
    ``test`` 10, all of which are checked.
    """

    def _check_subset(samples, size):
        assert len(samples) == size
        for s in samples:
            assert len(s) == 4
            assert isinstance(s[0], str)
            # compare with ``==``: ``assert a, b`` treats ``b`` as the message
            # and never fails for a non-empty shape
            assert s[1].shape == (3, 1024, 1024)  # planes, height, width
            assert s[1].dtype == torch.float32
            assert s[2].shape == (1, 1024, 1024)  # planes, height, width
            assert s[2].dtype == torch.float32
            assert s[3].shape == (1, 1024, 1024)  # planes, height, width
            assert s[3].dtype == torch.float32
            assert s[1].max() <= 1.0
            assert s[1].min() >= 0.0

    for m in ("vessel", "optic_disc"):
        d = importlib.import_module(
            f"...configs.datasets.iostar.{m}", package=__name__
        ).dataset
        assert len(d) == 4
        _check_subset(d["__train__"], 20)
        _check_subset(d["train"], 20)
        _check_subset(d["test"], 10)
467
468
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_iostar_mtest():
    """Check the IOSTAR ``vessel_mtest`` configuration.

    Its ``train``/``test`` subsets must equal those of the IOSTAR ``vessel``
    configuration, and every sample of every subset is checked.
    """

    from ..configs.datasets.iostar.vessel_mtest import dataset

    assert len(dataset) == 10

    from ..configs.datasets.iostar.vessel import dataset as baseline

    assert dataset["train"] == baseline["train"]
    assert dataset["test"] == baseline["test"]

    for subset in dataset:
        for sample in dataset[subset]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # compare with ``==``: ``assert a, b`` never fails here
            assert sample[1].shape == (3, 1024, 1024)  # planes,height,width
            assert sample[1].dtype == torch.float32
            assert sample[2].shape == (1, 1024, 1024)  # planes,height,width
            assert sample[2].dtype == torch.float32
            assert sample[3].shape == (1, 1024, 1024)
            assert sample[3].dtype == torch.float32
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
496
497
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_iostar_covd():
    """Check the IOSTAR ``covd`` configuration.

    ``train`` must equal ``__valid__``, ``test`` must equal the IOSTAR
    ``vessel`` test set, and both ``__train__`` and ``train`` hold 133
    samples which are all checked.
    """

    from ..configs.datasets.iostar.covd import dataset

    assert len(dataset) == 4

    from ..configs.datasets.iostar.vessel import dataset as baseline

    assert dataset["train"] == dataset["__valid__"]
    assert dataset["test"] == baseline["test"]

    # these are the only different sets from the baseline
    for key in ("__train__", "train"):
        assert len(dataset[key]) == 133
        for sample in dataset[key]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # compare with ``==``: ``assert a, b`` never fails here
            assert sample[1].shape == (3, 1024, 1024)
            assert sample[1].dtype == torch.float32
            assert sample[2].shape == (1, 1024, 1024)
            assert sample[2].dtype == torch.float32
            assert sample[3].shape == (1, 1024, 1024)
            assert sample[3].dtype == torch.float32
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
527
528
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.refuge.datadir")
def test_refuge():
    """Check the REFUGE ``disc`` and ``cup`` configurations.

    Each exposes 5 subsets; every checked subset holds 400 samples, of which
    only the first ``N`` are inspected.
    """

    def _check_subset(samples, size):
        assert len(samples) == size
        for s in samples[:N]:
            assert len(s) == 3
            assert isinstance(s[0], str)
            # compare with ``==``: ``assert a, b`` treats ``b`` as the message
            # and never fails for a non-empty shape
            assert s[1].shape == (3, 1632, 1632)  # planes, height, width
            assert s[1].dtype == torch.float32
            assert s[2].shape == (1, 1632, 1632)  # planes, height, width
            assert s[2].dtype == torch.float32
            assert s[1].max() <= 1.0
            assert s[1].min() >= 0.0

    for m in ("disc", "cup"):
        d = importlib.import_module(
            f"...configs.datasets.refuge.{m}", package=__name__
        ).dataset
        assert len(d) == 5
        _check_subset(d["__train__"], 400)
        _check_subset(d["train"], 400)
        _check_subset(d["validation"], 400)
        _check_subset(d["test"], 400)
552
553
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drishtigs1.datadir")
def test_drishtigs1():
    """Check the four Drishti-GS1 disc/cup configurations.

    Each exposes 4 subsets; ``__train__`` and ``train`` hold 50 samples and
    ``test`` 51, of which only the first ``N`` are inspected.
    """

    def _check_subset(samples, size):
        assert len(samples) == size
        for s in samples[:N]:
            assert len(s) == 3
            assert isinstance(s[0], str)
            # compare with ``==``: ``assert a, b`` treats ``b`` as the message
            # and never fails for a non-empty shape
            assert s[1].shape == (3, 1760, 2048)  # planes, height, width
            assert s[1].dtype == torch.float32
            assert s[2].shape == (1, 1760, 2048)  # planes, height, width
            assert s[2].dtype == torch.float32
            assert s[1].max() <= 1.0
            assert s[1].min() >= 0.0

    for m in ("disc_all", "cup_all", "disc_any", "cup_any"):
        d = importlib.import_module(
            f"...configs.datasets.drishtigs1.{m}", package=__name__
        ).dataset
        assert len(d) == 4
        _check_subset(d["__train__"], 50)
        _check_subset(d["train"], 50)
        _check_subset(d["test"], 51)
576
577
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.rimoner3.datadir")
def test_rimoner3():
    """Check the four RIM-ONE r3 disc/cup configurations.

    Each exposes 4 subsets; ``__train__`` and ``train`` hold 99 samples and
    ``test`` 60, of which only the first ``N`` are inspected.
    """

    def _check_subset(samples, size):
        assert len(samples) == size
        for s in samples[:N]:
            assert len(s) == 3
            assert isinstance(s[0], str)
            # compare with ``==``: ``assert a, b`` treats ``b`` as the message
            # and never fails for a non-empty shape
            assert s[1].shape == (3, 1440, 1088)  # planes, height, width
            assert s[1].dtype == torch.float32
            assert s[2].shape == (1, 1440, 1088)  # planes, height, width
            assert s[2].dtype == torch.float32
            assert s[1].max() <= 1.0
            assert s[1].min() >= 0.0

    for m in ("disc_exp1", "cup_exp1", "disc_exp2", "cup_exp2"):
        d = importlib.import_module(
            f"...configs.datasets.rimoner3.{m}", package=__name__
        ).dataset
        assert len(d) == 4
        _check_subset(d["__train__"], 99)
        _check_subset(d["train"], 99)
        _check_subset(d["test"], 60)
600
601
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drionsdb.datadir")
def test_drionsdb():
    """Check the DRIONS-DB ``expert1`` and ``expert2`` configurations.

    Each exposes 4 subsets; ``__train__`` and ``train`` hold 60 samples and
    ``test`` 50, of which only the first ``N`` are inspected.
    """

    def _check_subset(samples, size):
        assert len(samples) == size
        for s in samples[:N]:
            assert len(s) == 3
            assert isinstance(s[0], str)
            # compare with ``==``: ``assert a, b`` treats ``b`` as the message
            # and never fails for a non-empty shape
            assert s[1].shape == (3, 416, 608)  # planes, height, width
            assert s[1].dtype == torch.float32
            assert s[2].shape == (1, 416, 608)  # planes, height, width
            assert s[2].dtype == torch.float32
            assert s[1].max() <= 1.0
            assert s[1].min() >= 0.0

    for m in ("expert1", "expert2"):
        d = importlib.import_module(
            f"...configs.datasets.drionsdb.{m}", package=__name__
        ).dataset
        assert len(d) == 4
        _check_subset(d["__train__"], 60)
        _check_subset(d["train"], 60)
        _check_subset(d["test"], 50)