Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# coding=utf-8
4import importlib
6import pytest
7import torch
9from . import mock_dataset
11stare_datadir, stare_dataset = mock_dataset()
# We only iterate over the first N elements at most - dataset loading has
# already been checked by the individual dataset tests.  Here, we are only
# testing the extra tools that wrap the dataset.
16N = 10
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
def test_drive():
    """Check the DRIVE ``default`` and ``second_annotator`` configurations.

    Each subset must contain the expected number of samples.  Each sample
    must be a 4-element sequence: a string followed by three ``float32``
    tensors with 3, 1 and 1 planes of 544x544 pixels.  Values of the first
    tensor must lie within ``[0, 1]``.
    """

    def _check_subset(samples, size):
        """Validate the length of ``samples`` and the layout of each sample."""
        assert len(samples) == size
        for s in samples:
            assert len(s) == 4
            assert isinstance(s[0], str)
            # Shapes are compared with ``==``.  The form
            # ``assert shape, (...)`` only tests the truthiness of the shape
            # and can never fail.
            assert s[1].shape == (3, 544, 544)  # planes, height, width
            assert s[1].dtype == torch.float32
            assert s[2].shape == (1, 544, 544)  # planes, height, width
            assert s[2].dtype == torch.float32
            assert s[3].shape == (1, 544, 544)  # planes, height, width
            assert s[3].dtype == torch.float32
            assert s[1].max() <= 1.0
            assert s[1].min() >= 0.0

    from ..configs.datasets.drive.default import dataset

    assert len(dataset) == 4
    _check_subset(dataset["__train__"], 20)
    _check_subset(dataset["__valid__"], 20)
    _check_subset(dataset["train"], 20)
    _check_subset(dataset["test"], 20)

    from ..configs.datasets.drive.second_annotator import dataset

    assert len(dataset) == 1
    _check_subset(dataset["test"], 20)
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_drive_mtest():
    """Check the DRIVE ``mtest`` configuration.

    The configuration must expose 10 subsets, its ``train`` and ``test``
    subsets must equal those of the ``default`` configuration, and every
    sample of every subset must follow the 544x544 layout.
    """

    from ..configs.datasets.drive.mtest import dataset

    assert len(dataset) == 10

    from ..configs.datasets.drive.default import dataset as baseline

    assert dataset["train"] == baseline["train"]
    assert dataset["test"] == baseline["test"]

    for subset in dataset:
        for sample in dataset[subset]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # ``==`` is required here: ``assert shape, (...)`` would only
            # test truthiness and always pass.
            assert sample[1].shape == (3, 544, 544)  # planes, height, width
            assert sample[1].dtype == torch.float32
            assert sample[2].shape == (1, 544, 544)
            assert sample[2].dtype == torch.float32
            assert sample[3].shape == (1, 544, 544)
            assert sample[3].dtype == torch.float32
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_drive_covd():
    """Check the DRIVE ``covd`` configuration.

    The configuration must expose 4 subsets, ``train`` must equal
    ``__valid__``, and ``test`` must equal the ``default`` test subset.
    Both ``__train__`` and ``train`` must hold 123 samples following the
    544x544 layout.
    """

    from ..configs.datasets.drive.covd import dataset

    assert len(dataset) == 4

    from ..configs.datasets.drive.default import dataset as baseline

    assert dataset["train"] == dataset["__valid__"]
    assert dataset["test"] == baseline["test"]

    # these are the only different sets from the baseline
    for key in ("__train__", "train"):
        assert len(dataset[key]) == 123
        # Iterate over the subset selected by ``key``, as the other covd
        # tests in this module do, instead of always re-reading "__train__".
        for sample in dataset[key]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # ``==`` is required here: ``assert shape, (...)`` would only
            # test truthiness and always pass.
            assert sample[1].shape == (3, 544, 544)  # planes, height, width
            assert sample[1].dtype == torch.float32
            assert sample[2].shape == (1, 544, 544)  # planes, height, width
            assert sample[2].dtype == torch.float32
            assert sample[3].shape == (1, 544, 544)
            assert sample[3].dtype == torch.float32
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_drive_ssl():
    """Check the DRIVE ``ssl`` configuration.

    All subsets except ``__train__`` must equal their ``covd``
    counterparts.  ``__train__`` must hold 123 samples of 6 elements each:
    a string, three ``float32`` tensors (3, 1 and 1 planes), a second string
    and a further 3-plane ``float32`` tensor, all 544x544.
    """

    from ..configs.datasets.drive.ssl import dataset

    assert len(dataset) == 4

    from ..configs.datasets.drive.covd import dataset as covd

    assert dataset["train"] == covd["train"]
    assert dataset["train"] == dataset["__valid__"]
    assert dataset["test"] == covd["test"]
    assert dataset["__valid__"] == covd["__valid__"]

    # these are the only different from the baseline
    assert len(dataset["__train__"]) == 123
    for sample in dataset["__train__"]:
        assert len(sample) == 6
        assert isinstance(sample[0], str)
        # ``==`` is required here: ``assert shape, (...)`` would only test
        # truthiness and always pass.
        assert sample[1].shape == (3, 544, 544)  # planes, height, width
        assert sample[1].dtype == torch.float32
        assert sample[2].shape == (1, 544, 544)  # planes, height, width
        assert sample[2].dtype == torch.float32
        assert sample[3].shape == (1, 544, 544)  # planes, height, width
        assert sample[3].dtype == torch.float32
        assert isinstance(sample[4], str)
        assert sample[5].shape == (3, 544, 544)  # planes, height, width
        assert sample[5].dtype == torch.float32
        assert sample[1].max() <= 1.0
        assert sample[1].min() >= 0.0
def test_stare_augmentation_manipulation():
    """Check transform counts of the subsets built for one STARE protocol.

    The ``__train__`` subset must carry exactly 4 more transforms than the
    ``test`` subset, while ``train`` must carry the same number as ``test``.
    """

    # hack to allow testing on the CI
    from ..configs.datasets.stare import _maker

    subsets = _maker("ah", stare_dataset)

    def _count_transforms(name):
        """Return the number of transforms attached to subset ``name``."""
        return len(subsets[name]._transforms.transforms)

    augmented = _count_transforms("__train__")
    plain = _count_transforms("test")
    assert augmented == (plain + 4)

    assert _count_transforms("train") == _count_transforms("test")
def test_stare():
    """Check the STARE configurations built for protocols ``ah`` and ``vk``.

    Each built configuration must expose 4 subsets.  The ``__train__``,
    ``train`` and ``test`` subsets must hold 10 samples each, every sample
    being a string followed by three ``float32`` tensors with 3, 1 and 1
    planes of 608x704 pixels.
    """

    def _check_subset(samples, size):
        """Validate the length of ``samples`` and the layout of each sample."""
        assert len(samples) == size
        for s in samples:
            assert len(s) == 4
            assert isinstance(s[0], str)
            # Shapes are compared with ``==``.  The form
            # ``assert shape, (...)`` only tests truthiness and can never
            # fail.
            assert s[1].shape == (3, 608, 704)  # planes, height, width
            assert s[1].dtype == torch.float32
            assert s[2].shape == (1, 608, 704)  # planes, height, width
            assert s[2].dtype == torch.float32
            assert s[3].shape == (1, 608, 704)  # planes, height, width
            assert s[3].dtype == torch.float32
            assert s[1].max() <= 1.0
            assert s[1].min() >= 0.0

    # hack to allow testing on the CI
    from ..configs.datasets.stare import _maker

    for protocol in "ah", "vk":
        dataset = _maker(protocol, stare_dataset)
        assert len(dataset) == 4
        _check_subset(dataset["__train__"], 10)
        _check_subset(dataset["train"], 10)
        _check_subset(dataset["test"], 10)
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_stare_mtest():
    """Check the STARE ``mtest`` configuration.

    The configuration must expose 10 subsets, its ``train`` and ``test``
    subsets must equal those of the ``ah`` configuration, and every sample
    of every subset must follow the 608x704 layout.
    """

    from ..configs.datasets.stare.mtest import dataset

    assert len(dataset) == 10

    from ..configs.datasets.stare.ah import dataset as baseline

    assert dataset["train"] == baseline["train"]
    assert dataset["test"] == baseline["test"]

    for subset in dataset:
        for sample in dataset[subset]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # ``==`` is required here: ``assert shape, (...)`` would only
            # test truthiness and always pass.
            assert sample[1].shape == (3, 608, 704)  # planes,height,width
            assert sample[1].dtype == torch.float32
            assert sample[2].shape == (1, 608, 704)  # planes,height,width
            assert sample[2].dtype == torch.float32
            assert sample[3].shape == (1, 608, 704)
            assert sample[3].dtype == torch.float32
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_stare_covd():
    """Check the STARE ``covd`` configuration.

    The configuration must expose 4 subsets, ``train`` must equal
    ``__valid__``, and ``test`` must equal the ``ah`` test subset.  Both
    ``__train__`` and ``train`` must hold 143 samples following the 608x704
    layout.
    """

    from ..configs.datasets.stare.covd import dataset

    assert len(dataset) == 4

    from ..configs.datasets.stare.ah import dataset as baseline

    assert dataset["train"] == dataset["__valid__"]
    assert dataset["test"] == baseline["test"]

    # these are the only different sets from the baseline
    for key in ("__train__", "train"):
        assert len(dataset[key]) == 143
        for sample in dataset[key]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # ``==`` is required here: ``assert shape, (...)`` would only
            # test truthiness and always pass.
            assert sample[1].shape == (3, 608, 704)  # planes, height, width
            assert sample[1].dtype == torch.float32
            assert sample[2].shape == (1, 608, 704)  # planes, height, width
            assert sample[2].dtype == torch.float32
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
            assert sample[3].shape == (1, 608, 704)
            assert sample[3].dtype == torch.float32
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
def test_chasedb1():
    """Check the CHASE-DB1 first- and second-annotator configurations.

    Each configuration must expose 4 subsets with 8 samples in
    ``__train__``, ``__valid__`` and ``train`` and 20 samples in ``test``.
    Every sample is a string followed by three ``float32`` tensors with 3,
    1 and 1 planes of 960x960 pixels.
    """

    def _check_subset(samples, size):
        """Validate the length of ``samples`` and the layout of each sample."""
        assert len(samples) == size
        for s in samples:
            assert len(s) == 4
            assert isinstance(s[0], str)
            # Shapes are compared with ``==``.  The form
            # ``assert shape, (...)`` only tests truthiness and can never
            # fail.
            assert s[1].shape == (3, 960, 960)  # planes, height, width
            assert s[1].dtype == torch.float32
            assert s[2].shape == (1, 960, 960)  # planes, height, width
            assert s[2].dtype == torch.float32
            assert s[3].shape == (1, 960, 960)  # planes, height, width
            assert s[3].dtype == torch.float32
            assert s[1].max() <= 1.0
            assert s[1].min() >= 0.0

    for m in ("first_annotator", "second_annotator"):
        # relative import, resolved against this module's dotted name
        d = importlib.import_module(
            f"...configs.datasets.chasedb1.{m}", package=__name__
        ).dataset
        assert len(d) == 4
        _check_subset(d["__train__"], 8)
        _check_subset(d["__valid__"], 8)
        _check_subset(d["train"], 8)
        _check_subset(d["test"], 20)
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_chasedb1_mtest():
    """Check the CHASE-DB1 ``mtest`` configuration.

    The configuration must expose 10 subsets, its ``train`` and ``test``
    subsets must equal those of the ``first_annotator`` configuration, and
    every sample of every subset must follow the 960x960 layout.
    """

    from ..configs.datasets.chasedb1.mtest import dataset

    assert len(dataset) == 10

    from ..configs.datasets.chasedb1.first_annotator import dataset as baseline

    assert dataset["train"] == baseline["train"]
    assert dataset["test"] == baseline["test"]

    for subset in dataset:
        for sample in dataset[subset]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # ``==`` is required here: ``assert shape, (...)`` would only
            # test truthiness and always pass.
            assert sample[1].shape == (3, 960, 960)  # planes,height,width
            assert sample[1].dtype == torch.float32
            assert sample[2].shape == (1, 960, 960)  # planes,height,width
            assert sample[2].dtype == torch.float32
            assert sample[3].shape == (1, 960, 960)
            assert sample[3].dtype == torch.float32
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_chasedb1_covd():
    """Check the CHASE-DB1 ``covd`` configuration.

    The configuration must expose 4 subsets, ``train`` must equal
    ``__valid__``, and ``test`` must equal the ``first_annotator`` test
    subset.  Both ``__train__`` and ``train`` must hold 135 samples
    following the 960x960 layout.
    """

    from ..configs.datasets.chasedb1.covd import dataset

    assert len(dataset) == 4

    from ..configs.datasets.chasedb1.first_annotator import dataset as baseline

    assert dataset["train"] == dataset["__valid__"]
    assert dataset["test"] == baseline["test"]

    # these are the only different sets from the baseline
    for key in ("__train__", "train"):
        assert len(dataset[key]) == 135
        for sample in dataset[key]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # ``==`` is required here: ``assert shape, (...)`` would only
            # test truthiness and always pass.
            assert sample[1].shape == (3, 960, 960)  # planes, height, width
            assert sample[1].dtype == torch.float32
            assert sample[2].shape == (1, 960, 960)  # planes, height, width
            assert sample[2].dtype == torch.float32
            assert sample[3].shape == (1, 960, 960)
            assert sample[3].dtype == torch.float32
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
def test_hrf():
    """Check the HRF ``default`` configuration.

    The configuration must expose 6 subsets.  ``__train__``, ``train`` and
    ``test`` hold samples of 1168x1648 pixels, while the two
    "(full resolution)" subsets hold samples of 2336x3296 pixels.  Every
    sample is a string followed by three ``float32`` tensors with 3, 1 and
    1 planes.
    """

    def _check_samples(samples, size, height, width):
        """Validate subset length and the ``height`` x ``width`` layout."""
        assert len(samples) == size
        for s in samples:
            assert len(s) == 4
            assert isinstance(s[0], str)
            # Shapes are compared with ``==``.  The form
            # ``assert shape, (...)`` only tests truthiness and can never
            # fail.
            assert s[1].shape == (3, height, width)  # planes, height, width
            assert s[1].dtype == torch.float32
            assert s[2].shape == (1, height, width)  # planes, height, width
            assert s[2].dtype == torch.float32
            assert s[3].shape == (1, height, width)  # planes, height, width
            assert s[3].dtype == torch.float32
            assert s[1].max() <= 1.0
            assert s[1].min() >= 0.0

    def _check_subset(samples, size):
        """Validate a subset at the reduced 1168x1648 resolution."""
        _check_samples(samples, size, 1168, 1648)

    def _check_subset_fullres(samples, size):
        """Validate a subset at the full 2336x3296 resolution."""
        _check_samples(samples, size, 2336, 3296)

    from ..configs.datasets.hrf.default import dataset

    assert len(dataset) == 6
    _check_subset(dataset["__train__"], 15)
    _check_subset(dataset["train"], 15)
    _check_subset(dataset["test"], 30)
    _check_subset_fullres(dataset["train (full resolution)"], 15)
    _check_subset_fullres(dataset["test (full resolution)"], 30)
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_hrf_mtest():
    """Check the HRF ``mtest`` configuration.

    The configuration must expose 12 subsets and its ``train`` and ``test``
    subsets must equal those of the ``default`` configuration.  Subsets
    whose name contains "full resolution" hold 2336x3296 samples; all
    others hold 1168x1648 samples.
    """

    from ..configs.datasets.hrf.mtest import dataset

    assert len(dataset) == 12

    from ..configs.datasets.hrf.default import dataset as baseline

    assert dataset["train"] == baseline["train"]
    assert dataset["test"] == baseline["test"]

    for subset in dataset:
        for sample in dataset[subset]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # ``==`` is required below: ``assert shape, (...)`` would only
            # test truthiness and always pass.
            if "full resolution" in subset:
                assert sample[1].shape == (3, 2336, 3296)
                assert sample[1].dtype == torch.float32
                assert sample[2].shape == (1, 2336, 3296)
                assert sample[2].dtype == torch.float32
                assert sample[3].shape == (1, 2336, 3296)
                assert sample[3].dtype == torch.float32
            else:
                assert sample[1].shape == (3, 1168, 1648)
                assert sample[1].dtype == torch.float32
                assert sample[2].shape == (1, 1168, 1648)
                assert sample[2].dtype == torch.float32
                assert sample[3].shape == (1, 1168, 1648)
                assert sample[3].dtype == torch.float32
            # the value range is checked for both resolutions
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_hrf_covd():
    """Check the HRF ``covd`` configuration.

    The configuration must expose 6 subsets, ``train`` must equal
    ``__valid__``, and ``test`` must equal the ``default`` test subset.
    Both ``__train__`` and ``train`` must hold 118 samples following the
    1168x1648 layout.
    """

    from ..configs.datasets.hrf.covd import dataset

    assert len(dataset) == 6

    from ..configs.datasets.hrf.default import dataset as baseline

    assert dataset["train"] == dataset["__valid__"]
    assert dataset["test"] == baseline["test"]

    # these are the only different sets from the baseline
    for key in ("__train__", "train"):
        assert len(dataset[key]) == 118
        for sample in dataset[key]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # ``==`` is required here: ``assert shape, (...)`` would only
            # test truthiness and always pass.
            assert sample[1].shape == (3, 1168, 1648)
            assert sample[1].dtype == torch.float32
            assert sample[2].shape == (1, 1168, 1648)
            assert sample[2].dtype == torch.float32
            assert sample[3].shape == (1, 1168, 1648)
            assert sample[3].dtype == torch.float32
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_iostar():
    """Check the IOSTAR ``vessel`` and ``optic_disc`` configurations.

    Each configuration must expose 4 subsets with 20 samples in
    ``__train__`` and ``train`` and 10 samples in ``test``.  Every sample is
    a string followed by three ``float32`` tensors with 3, 1 and 1 planes
    of 1024x1024 pixels.
    """

    def _check_subset(samples, size):
        """Validate the length of ``samples`` and the layout of each sample."""
        assert len(samples) == size
        for s in samples:
            assert len(s) == 4
            assert isinstance(s[0], str)
            # Shapes are compared with ``==``.  The form
            # ``assert shape, (...)`` only tests truthiness and can never
            # fail.
            assert s[1].shape == (3, 1024, 1024)  # planes, height, width
            assert s[1].dtype == torch.float32
            assert s[2].shape == (1, 1024, 1024)  # planes, height, width
            assert s[2].dtype == torch.float32
            assert s[3].shape == (1, 1024, 1024)  # planes, height, width
            assert s[3].dtype == torch.float32
            assert s[1].max() <= 1.0
            assert s[1].min() >= 0.0

    for m in ("vessel", "optic_disc"):
        # relative import, resolved against this module's dotted name
        d = importlib.import_module(
            f"...configs.datasets.iostar.{m}", package=__name__
        ).dataset
        assert len(d) == 4
        _check_subset(d["__train__"], 20)
        _check_subset(d["train"], 20)
        _check_subset(d["test"], 10)
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_iostar_mtest():
    """Check the IOSTAR ``vessel_mtest`` configuration.

    The configuration must expose 10 subsets, its ``train`` and ``test``
    subsets must equal those of the ``vessel`` configuration, and every
    sample of every subset must follow the 1024x1024 layout.
    """

    from ..configs.datasets.iostar.vessel_mtest import dataset

    assert len(dataset) == 10

    from ..configs.datasets.iostar.vessel import dataset as baseline

    assert dataset["train"] == baseline["train"]
    assert dataset["test"] == baseline["test"]

    for subset in dataset:
        for sample in dataset[subset]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # ``==`` is required here: ``assert shape, (...)`` would only
            # test truthiness and always pass.
            assert sample[1].shape == (3, 1024, 1024)  # planes,height,width
            assert sample[1].dtype == torch.float32
            assert sample[2].shape == (1, 1024, 1024)  # planes,height,width
            assert sample[2].dtype == torch.float32
            assert sample[3].shape == (1, 1024, 1024)
            assert sample[3].dtype == torch.float32
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drive.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.chasedb1.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.iostar.datadir")
def test_iostar_covd():
    """Check the IOSTAR ``covd`` configuration.

    The configuration must expose 4 subsets, ``train`` must equal
    ``__valid__``, and ``test`` must equal the ``vessel`` test subset.
    Both ``__train__`` and ``train`` must hold 133 samples following the
    1024x1024 layout.
    """

    from ..configs.datasets.iostar.covd import dataset

    assert len(dataset) == 4

    from ..configs.datasets.iostar.vessel import dataset as baseline

    assert dataset["train"] == dataset["__valid__"]
    assert dataset["test"] == baseline["test"]

    # these are the only different sets from the baseline
    for key in ("__train__", "train"):
        assert len(dataset[key]) == 133
        for sample in dataset[key]:
            assert len(sample) == 4
            assert isinstance(sample[0], str)
            # ``==`` is required here: ``assert shape, (...)`` would only
            # test truthiness and always pass.
            assert sample[1].shape == (3, 1024, 1024)
            assert sample[1].dtype == torch.float32
            assert sample[2].shape == (1, 1024, 1024)
            assert sample[2].dtype == torch.float32
            assert sample[3].shape == (1, 1024, 1024)
            assert sample[3].dtype == torch.float32
            assert sample[1].max() <= 1.0
            assert sample[1].min() >= 0.0
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.refuge.datadir")
def test_refuge():
    """Check the REFUGE ``disc`` and ``cup`` configurations.

    Each configuration must expose 5 subsets, with 400 samples in each of
    ``__train__``, ``train``, ``validation`` and ``test``.  Only the first
    ``N`` samples of a subset are inspected: each is a string followed by
    two ``float32`` tensors with 3 and 1 planes of 1632x1632 pixels.
    """

    def _check_subset(samples, size):
        """Validate subset length and the layout of its first N samples."""
        assert len(samples) == size
        for s in samples[:N]:
            assert len(s) == 3
            assert isinstance(s[0], str)
            # Shapes are compared with ``==``.  The form
            # ``assert shape, (...)`` only tests truthiness and can never
            # fail.
            assert s[1].shape == (3, 1632, 1632)  # planes, height, width
            assert s[1].dtype == torch.float32
            assert s[2].shape == (1, 1632, 1632)  # planes, height, width
            assert s[2].dtype == torch.float32
            assert s[1].max() <= 1.0
            assert s[1].min() >= 0.0

    for m in ("disc", "cup"):
        # relative import, resolved against this module's dotted name
        d = importlib.import_module(
            f"...configs.datasets.refuge.{m}", package=__name__
        ).dataset
        assert len(d) == 5
        _check_subset(d["__train__"], 400)
        _check_subset(d["train"], 400)
        _check_subset(d["validation"], 400)
        _check_subset(d["test"], 400)
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drishtigs1.datadir")
def test_drishtigs1():
    """Check the four Drishti-GS1 disc/cup configurations.

    Each configuration must expose 4 subsets with 50 samples in
    ``__train__`` and ``train`` and 51 samples in ``test``.  Only the first
    ``N`` samples of a subset are inspected: each is a string followed by
    two ``float32`` tensors with 3 and 1 planes of 1760x2048 pixels.
    """

    def _check_subset(samples, size):
        """Validate subset length and the layout of its first N samples."""
        assert len(samples) == size
        for s in samples[:N]:
            assert len(s) == 3
            assert isinstance(s[0], str)
            # Shapes are compared with ``==``.  The form
            # ``assert shape, (...)`` only tests truthiness and can never
            # fail.
            assert s[1].shape == (3, 1760, 2048)  # planes, height, width
            assert s[1].dtype == torch.float32
            assert s[2].shape == (1, 1760, 2048)  # planes, height, width
            assert s[2].dtype == torch.float32
            assert s[1].max() <= 1.0
            assert s[1].min() >= 0.0

    for m in ("disc_all", "cup_all", "disc_any", "cup_any"):
        # relative import, resolved against this module's dotted name
        d = importlib.import_module(
            f"...configs.datasets.drishtigs1.{m}", package=__name__
        ).dataset
        assert len(d) == 4
        _check_subset(d["__train__"], 50)
        _check_subset(d["train"], 50)
        _check_subset(d["test"], 51)
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.rimoner3.datadir")
def test_rimoner3():
    """Check the four RIM-ONE r3 disc/cup configurations.

    Each configuration must expose 4 subsets with 99 samples in
    ``__train__`` and ``train`` and 60 samples in ``test``.  Only the first
    ``N`` samples of a subset are inspected: each is a string followed by
    two ``float32`` tensors with 3 and 1 planes of 1440x1088 pixels.
    """

    def _check_subset(samples, size):
        """Validate subset length and the layout of its first N samples."""
        assert len(samples) == size
        for s in samples[:N]:
            assert len(s) == 3
            assert isinstance(s[0], str)
            # Shapes are compared with ``==``.  The form
            # ``assert shape, (...)`` only tests truthiness and can never
            # fail.
            assert s[1].shape == (3, 1440, 1088)  # planes, height, width
            assert s[1].dtype == torch.float32
            assert s[2].shape == (1, 1440, 1088)  # planes, height, width
            assert s[2].dtype == torch.float32
            assert s[1].max() <= 1.0
            assert s[1].min() >= 0.0

    for m in ("disc_exp1", "cup_exp1", "disc_exp2", "cup_exp2"):
        # relative import, resolved against this module's dotted name
        d = importlib.import_module(
            f"...configs.datasets.rimoner3.{m}", package=__name__
        ).dataset
        assert len(d) == 4
        _check_subset(d["__train__"], 99)
        _check_subset(d["train"], 99)
        _check_subset(d["test"], 60)
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drionsdb.datadir")
def test_drionsdb():
    """Check the DRIONS-DB ``expert1`` and ``expert2`` configurations.

    Each configuration must expose 4 subsets with 60 samples in
    ``__train__`` and ``train`` and 50 samples in ``test``.  Only the first
    ``N`` samples of a subset are inspected: each is a string followed by
    two ``float32`` tensors with 3 and 1 planes of 416x608 pixels.
    """

    def _check_subset(samples, size):
        """Validate subset length and the layout of its first N samples."""
        assert len(samples) == size
        for s in samples[:N]:
            assert len(s) == 3
            assert isinstance(s[0], str)
            # Shapes are compared with ``==``.  The form
            # ``assert shape, (...)`` only tests truthiness and can never
            # fail.
            assert s[1].shape == (3, 416, 608)  # planes, height, width
            assert s[1].dtype == torch.float32
            assert s[2].shape == (1, 416, 608)  # planes, height, width
            assert s[2].dtype == torch.float32
            assert s[1].max() <= 1.0
            assert s[1].min() >= 0.0

    for m in ("expert1", "expert2"):
        # relative import, resolved against this module's dotted name
        d = importlib.import_module(
            f"...configs.datasets.drionsdb.{m}", package=__name__
        ).dataset
        assert len(d) == 4
        _check_subset(d["__train__"], 60)
        _check_subset(d["train"], 60)
        _check_subset(d["test"], 50)