1#!/usr/bin/env python
2# coding=utf-8
3
4
5"""Tests for RIM-ONE r3"""
6
7import numpy
8import pytest
9
10from ...binseg.data.rimoner3 import dataset
11from .utils import count_bw
12
13
def test_protocol_consistency():
    """Check that every RIM-ONE r3 protocol exposes the expected splits.

    Each protocol must define exactly two subsets: ``train`` with 99 samples
    and ``test`` with 60 samples.  Every sample key in both subsets must
    contain the substring ``"Stereo Images"``.
    """

    protocols = (
        "optic-disc-exp1",
        "optic-cup-exp1",
        "optic-disc-exp2",
        "optic-cup-exp2",
        "optic-disc-avg",
        "optic-cup-avg",
    )
    # (subset name, expected number of samples), checked in this order
    expected_splits = (("train", 99), ("test", 60))

    for name in protocols:
        splits = dataset.subsets(name)
        assert len(splits) == 2

        for split_name, expected_len in expected_splits:
            assert split_name in splits
            samples = splits[split_name]
            assert len(samples) == expected_len
            for sample in samples:
                assert "Stereo Images" in sample.key
37
38
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.rimoner3.datadir")
@pytest.mark.slow
def test_loading():
    """Load every sample of each RIM-ONE r3 protocol and sanity-check it.

    For each protocol and split, every sample must provide exactly three
    entries: ``data`` (mode ``"RGB"``), ``label`` and ``mask`` (both mode
    ``"1"``), all of size ``image_size``.  The ratio of white to black
    pixels in the label must also stay below a per-protocol, per-split
    threshold; exceeding it could indicate a loading problem.
    """

    image_size = (1072, 1424)
    total_pixels = numpy.prod(image_size)

    # Set to an integer N to check only the first N samples of each split
    # (useful while debugging); ``None`` checks every sample.
    limit = None

    # (protocol, train threshold, test threshold) on the label's white/black
    # pixel ratio.  Presumably set just above the maximum ratio observed on
    # each split -- TODO confirm if the dataset changes.
    thresholds = (
        ("optic-cup-exp1", 0.048, 0.042),
        ("optic-disc-exp1", 0.088, 0.061),
        ("optic-cup-exp2", 0.039, 0.038),
        ("optic-disc-exp2", 0.090, 0.065),
        ("optic-cup-avg", 0.042, 0.040),
        ("optic-disc-avg", 0.089, 0.063),
    )

    def _check_sample(s, bw_threshold_label):
        """Assert structural and label-proportion invariants on sample ``s``."""

        data = s.data
        assert isinstance(data, dict)
        assert len(data) == 3

        assert "data" in data
        assert data["data"].size == image_size
        assert data["data"].mode == "RGB"

        assert "label" in data
        assert data["label"].size == image_size
        assert data["label"].mode == "1"
        b, w = count_bw(data["label"])
        assert (b + w) == total_pixels, (
            f"Counts of black + white ({b}+{w}) do not add up to total "
            f"image size ({total_pixels}) at '{s.key}':label"
        )
        # Without this guard an all-white label would raise a bare
        # ZeroDivisionError below instead of a descriptive test failure.
        assert b > 0, (
            f"Label at '{s.key}' has no black pixels - this could "
            f"indicate a loading problem!"
        )
        assert (w / b) < bw_threshold_label, (
            f"The proportion between black and white pixels "
            f"({w}/{b}={w/b:.2f}) is larger than the allowed threshold "
            f"of {bw_threshold_label} at '{s.key}':label - this could "
            f"indicate a loading problem!"
        )

        assert "mask" in data
        assert data["mask"].size == image_size
        assert data["mask"].mode == "1"

        # to visualize images, uncomment the following code
        # it should display an image with a faded background representing the
        # original data, blended with green labels.
        # from ..data.utils import overlayed_image
        # display = overlayed_image(data["data"], data["label"])
        # display.show()
        # import ipdb; ipdb.set_trace()

    for protocol, train_threshold, test_threshold in thresholds:
        subset = dataset.subsets(protocol)
        for s in subset["train"][:limit]:
            _check_sample(s, train_threshold)
        for s in subset["test"][:limit]:
            _check_sample(s, test_threshold)
121
122
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.rimoner3.datadir")
@pytest.mark.slow
def test_check():
    """Run the dataset's own ``check()`` routine and require it to return 0.

    NOTE(review): the return value is presumably the number of problems
    found; confirm against ``dataset.check``.
    """
    status = dataset.check()
    assert status == 0