1#!/usr/bin/env python
2# coding=utf-8
3
4
5"""Tests for Drishti-GS1"""
6
7import os
8
9import numpy
10import pytest
11
12from ...binseg.data.drishtigs1 import dataset
13from .utils import count_bw
14
15
def test_protocol_consistency():
    """Verify the subset layout of every Drishti-GS1 protocol.

    For each protocol, exactly two subsets must exist ("train" and "test"),
    holding 50 and 51 samples respectively, and every sample key must start
    with the image path prefix matching its subset.
    """

    protocol_names = (
        "optic-disc-all",
        "optic-cup-all",
        "optic-disc-any",
        "optic-cup-any",
    )

    # (subset name, expected number of samples, directory holding the images)
    expectations = (
        ("train", 50, "Training"),
        ("test", 51, "Test"),
    )

    for protocol_name in protocol_names:

        subsets = dataset.subsets(protocol_name)
        assert len(subsets) == 2

        for subset_name, expected_count, directory in expectations:
            key_prefix = os.path.join(
                "Drishti-GS1_files", directory, "Images", "drishtiGS_"
            )

            assert subset_name in subsets
            assert len(subsets[subset_name]) == expected_count
            for sample in subsets[subset_name]:
                assert sample.key.startswith(key_prefix)
45
46
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drishtigs1.datadir")
@pytest.mark.slow
def test_loading():
    """Load the samples of every protocol and sanity-check their contents.

    Each sample must provide a dictionary with exactly three entries
    ("data", "label" and "mask").  The image must be RGB and larger than
    2040x1740 pixels; label and mask must have the same size as the image
    and be of mode "1".  The white-to-black pixel ratio of each label is
    compared against a per-protocol, per-subset threshold so that grossly
    wrong annotations (e.g. inverted or mis-loaded files) are caught.
    """

    def _check_sample(s, bw_threshold_label):
        """Check one sample and return the white/black ratio of its label.

        Parameters
        ----------
        s
            Sample to check; must expose ``key`` and ``data`` attributes.
        bw_threshold_label : float
            Exclusive upper bound for the white/black pixel ratio of the
            label.

        Returns
        -------
        float
            The white/black pixel ratio measured on the label.
        """

        data = s.data
        assert isinstance(data, dict)
        assert len(data) == 3

        assert "data" in data
        # NOTE(review): ``size`` is presumably a PIL-style (width, height)
        # tuple - confirm against the dataset loader
        assert data["data"].size[0] > 2040, (
            f"Width ({data['data'].size[0]}) for {s.key} is smaller "
            f"than 2040 pixels"
        )
        assert data["data"].size[1] > 1740, (
            f"Height ({data['data'].size[1]}) for {s.key} is smaller "
            f"than 1740 pixels"
        )
        assert data["data"].mode == "RGB"

        assert "label" in data
        assert data["data"].size == data["label"].size
        assert data["label"].mode == "1"
        b, w = count_bw(data["label"])
        total = numpy.prod(data["data"].size)
        assert (b + w) == total, (
            f"Counts of black + white ({b}+{w}) do not add up to total "
            f"image size ({total}) at '{s.key}':label"
        )
        # a label without any black pixel would otherwise surface as an
        # uninformative ZeroDivisionError on the ratio computed below
        assert b > 0, (
            f"Label at '{s.key}' contains no black pixels - this could "
            f"indicate a loading problem!"
        )
        ratio = w / b
        assert ratio < bw_threshold_label, (
            f"The proportion between black and white pixels "
            f"({w}/{b}={ratio:.3f}) is larger than the allowed threshold "
            f"of {bw_threshold_label} at '{s.key}':label - this could "
            f"indicate a loading problem!"
        )

        assert "mask" in data
        assert data["data"].size == data["mask"].size
        assert data["mask"].mode == "1"

        # to visualize images, uncomment the following code
        # it should display an image with a faded background representing the
        # original data, blended with green labels.
        # from ..data.utils import overlayed_image
        # display = overlayed_image(data["data"], data["label"])
        # display.show()

        return ratio

    # set to a small integer to check only the first samples of each subset
    # (handy while debugging); ``None`` checks all of them
    limit = None

    # (protocol, threshold for "train" labels, threshold for "test" labels)
    thresholds = (
        ("optic-cup-all", 0.027, 0.035),
        ("optic-disc-all", 0.045, 0.055),
        ("optic-cup-any", 0.034, 0.047),
        ("optic-disc-any", 0.052, 0.060),
    )

    for protocol, train_threshold, test_threshold in thresholds:
        subset = dataset.subsets(protocol)
        for sample in subset["train"][:limit]:
            _check_sample(sample, train_threshold)
        for sample in subset["test"][:limit]:
            _check_sample(sample, test_threshold)
122
123
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.drishtigs1.datadir")
@pytest.mark.slow
def test_check():
    """Run the dataset self-check, which must return 0."""
    outcome = dataset.check()
    assert outcome == 0