1#!/usr/bin/env python
2# coding=utf-8
3
4
5"""Tests for DRIONS-DB"""
6
7import os
8
9import numpy
10import pytest
11
12from ...binseg.data.drionsdb import dataset
13from .utils import count_bw
14
15
def test_protocol_consistency():
    """Verify the subset layout exposed by both dataset protocols.

    For each of ``expert1`` and ``expert2``, the protocol must provide
    exactly two subsets: ``train`` holding 60 samples and ``test`` holding
    50 samples.  Every sample key must start with the path prefix expected
    for its subset.
    """

    # (subset name, expected number of samples, expected key prefix)
    expectations = (
        ("train", 60, os.path.join("images", "image_0")),
        ("test", 50, os.path.join("images", "image_")),
    )

    for protocol in ("expert1", "expert2"):
        splits = dataset.subsets(protocol)
        assert len(splits) == 2

        for name, count, prefix in expectations:
            assert name in splits
            assert len(splits[name]) == count
            for sample in splits[name]:
                assert sample.key.startswith(prefix)
32
33
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drionsdb.datadir")
@pytest.mark.slow
def test_loading():
    """Load samples from both protocols and sanity-check their contents.

    Every sample of the ``train`` and ``test`` subsets of ``expert1`` and
    ``expert2`` is loaded and verified by ``_check_sample``, each subset
    with its own upper bound on the white/black pixel ratio of the label.
    """

    image_size = (600, 400)

    # (dictionary key, expected image mode), in the order they are verified
    expected_entries = (("data", "RGB"), ("label", "1"), ("mask", "1"))

    def _check_sample(s, bw_threshold_label):
        """Verify one loaded sample and return its white/black label ratio.

        The sample data must be a dictionary with exactly three entries
        (``data``, ``label`` and ``mask``), each having ``size`` equal to
        ``image_size`` and the expected ``mode``.  The black and white
        counts of the label must add up to the number of pixels, and the
        white/black ratio must stay below ``bw_threshold_label``.
        """

        data = s.data
        assert isinstance(data, dict)
        assert len(data) == 3

        for name, mode in expected_entries:
            assert name in data
            entry = data[name]
            assert entry.size == image_size
            assert entry.mode == mode

        b, w = count_bw(data["label"])
        assert (b + w) == numpy.prod(image_size), (
            f"Counts of black + white ({b}+{w}) do not add up to total "
            f"image size ({numpy.prod(image_size)}) at '{s.key}':label"
        )

        ratio = w / b
        assert ratio < bw_threshold_label, (
            f"The proportion between black and white pixels "
            f"({w}/{b}={w/b:.3f}) is larger than the allowed threshold "
            f"of {bw_threshold_label} at '{s.key}':label - this could "
            f"indicate a loading problem!"
        )

        # Debugging hint: to inspect a sample visually, overlay
        # data["label"] on data["data"] at this point (the package's data
        # utilities are expected to offer an overlay helper - TODO confirm)
        # and call ``.show()`` on the result.

        return ratio

    limit = None  # set to an integer to restrict testing to the first images

    # (protocol, ((subset name, white/black ratio upper bound), ...))
    plan = (
        ("expert1", (("train", 0.046), ("test", 0.043))),
        ("expert2", (("train", 0.044), ("test", 0.045))),
    )

    for protocol, checks in plan:
        subset = dataset.subsets(protocol)
        for name, threshold in checks:
            # the returned ratio is only useful when tuning the thresholds
            for sample in subset[name][:limit]:
                _check_sample(sample, threshold)
93
94
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drionsdb.datadir")
@pytest.mark.slow
def test_check():
    """Run the dataset self-check and require that it returns zero."""

    outcome = dataset.check()
    assert outcome == 0