1#!/usr/bin/env python
2# coding=utf-8
3
4
5"""Tests for HRF"""
6
7import os
8
9import numpy
10import pytest
11
12from ...binseg.data.hrf import dataset
13from .utils import count_bw
14
15
def test_protocol_consistency():
    """Verify the layout of the "default" protocol.

    The protocol must expose exactly two subsets: ``train`` with 15 samples
    and ``test`` with 30 samples.  Every sample key must start with the
    prefix expected for the subset it belongs to.
    """

    protocol = dataset.subsets("default")
    assert len(protocol) == 2

    # (subset name, expected number of samples, expected key prefix)
    expectations = (
        ("train", 15, os.path.join("images", "0")),
        ("test", 30, "images"),
    )

    for name, count, prefix in expectations:
        assert name in protocol
        samples = protocol[name]
        assert len(samples) == count
        for sample in samples:
            assert sample.key.startswith(prefix)
30
31
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
def test_loading():
    """Load every sample of the "default" protocol and sanity-check it.

    Each sample must provide three entries (``data``, ``label`` and
    ``mask``), all of the expected size and mode.  The white/black pixel
    ratio of the label must stay below a threshold, while the ratio of the
    mask must stay above one - a violation hints at a loading problem.
    """

    image_size = (3504, 2336)
    # black + white counts of a binary image must add up to this number
    total_pixels = numpy.prod(image_size)

    def _check_sample(s, bw_threshold_label, bw_threshold_mask):
        """Check a single sample and return its white/black pixel ratios.

        Parameters
        ----------
        s
            Sample to check.  Its ``data`` attribute must be a dictionary
            with the entries ``data``, ``label`` and ``mask``; its ``key``
            attribute is used in failure messages.
        bw_threshold_label
            Exclusive upper bound for the white/black ratio of the label.
        bw_threshold_mask
            Exclusive lower bound for the white/black ratio of the mask.

        Returns
        -------
        tuple
            ``(label_ratio, mask_ratio)``, the white/black pixel ratios of
            the label and of the mask, respectively.
        """

        data = s.data
        assert isinstance(data, dict)
        assert len(data) == 3

        assert "data" in data
        assert data["data"].size == image_size
        assert data["data"].mode == "RGB"

        assert "label" in data
        assert data["label"].size == image_size
        assert data["label"].mode == "1"
        b, w = count_bw(data["label"])
        assert (b + w) == total_pixels, (
            f"Counts of black + white ({b}+{w}) do not add up to total "
            f"image size ({total_pixels}) at '{s.key}':label"
        )
        assert (w / b) < bw_threshold_label, (
            f"The proportion between black and white pixels "
            f"({w}/{b}={w/b:.2f}) is larger than the allowed threshold "
            f"of {bw_threshold_label} at '{s.key}':label - this could "
            f"indicate a loading problem!"
        )

        assert "mask" in data
        assert data["mask"].size == image_size
        assert data["mask"].mode == "1"
        bm, wm = count_bw(data["mask"])
        assert (bm + wm) == total_pixels, (
            f"Counts of black + white ({bm}+{wm}) do not add up to total "
            f"image size ({total_pixels}) at '{s.key}':mask"
        )
        # NOTE: this message refers to the mask (it used to say ":label",
        # which pointed at the wrong image when the check failed)
        assert (wm / bm) > bw_threshold_mask, (
            f"The proportion between black and white pixels in masks "
            f"({wm}/{bm}={wm/bm:.2f}) is smaller than the allowed "
            f"threshold of {bw_threshold_mask} at '{s.key}':mask - "
            f"this could indicate a loading problem!"
        )

        # to visualize images, uncomment the following code
        # it should display an image with a faded background representing the
        # original data, blended with green labels and blue area indicating the
        # parts to be masked out.
        # from ..data.utils import overlayed_image
        # display = overlayed_image(data["data"], data["label"], data["mask"])
        # display.show()
        # import ipdb; ipdb.set_trace()

        return w / b, wm / bm

    limit = None  # use this to limit testing to first images only
    subset = dataset.subsets("default")
    proportions = [
        _check_sample(s, 0.12, 5.42) for s in subset["train"][:limit]
    ]
    # print(f"max label proportions = {max(k[0] for k in proportions)}")
    # print(f"min mask proportions = {min(k[1] for k in proportions)}")
    proportions = [_check_sample(s, 0.12, 5.41) for s in subset["test"][:limit]]
    # print(f"max label proportions = {max(k[0] for k in proportions)}")
    # print(f"min mask proportions = {min(k[1] for k in proportions)}")
    del proportions  # only to satisfy flake8
99
100
@pytest.mark.skip_if_rc_var_not_set("bob.ip.binseg.hrf.datadir")
def test_check():
    """Run the dataset's own ``check()`` and require that it returns 0."""

    outcome = dataset.check()
    assert outcome == 0