1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4"""Tests for transforms"""
5
6import os
7import pkg_resources
8
9import numpy
10import PIL.Image
11
12from ..data.transforms import (
13 RemoveBlackBorders,
14 ElasticDeformation,
15 SingleAutoLevel16to8,
16)
17from ..data.loader import load_pil
18
19
def _data_file(f):
    """Return the on-disk path of fixture *f* from this package's ``data`` folder."""
    relative_path = os.path.join("data", f)
    return pkg_resources.resource_filename(__name__, relative_path)
22
23
def test_remove_black_borders():
    """Stripping the black frame must reproduce the border-free reference.

    Both fixtures are opened in ``with`` blocks so their file handles are
    closed deterministically (``PIL.Image.open`` is lazy and otherwise keeps
    the file open until garbage collection). Pixel data is copied into NumPy
    arrays while each file is still open.
    """
    rbb = RemoveBlackBorders()

    # Sample that still carries a black border: remove it and materialize the
    # result as an array before the underlying file is closed.
    with_border_path = _data_file("raw_with_black_border.png")
    with PIL.Image.open(with_border_path) as raw_with_black_border:
        raw_rbb_removed = numpy.asarray(rbb(raw_with_black_border))

    # Reference: the same sample stored without a black border.
    without_border_path = _data_file("raw_without_black_border.png")
    with PIL.Image.open(without_border_path) as raw_without_black_border:
        expected = numpy.asarray(raw_without_black_border)

    # Argument order is (actual, desired), matching the other tests here.
    numpy.testing.assert_array_equal(raw_rbb_removed, expected)
42
43
def test_elastic_deformation():
    """Seeded elastic deformation must match the stored reference output.

    The transform is driven by ``numpy.random.RandomState(seed=100)``, the
    same seed the reference fixture was generated with. Fixture files are
    opened in ``with`` blocks so their handles are closed deterministically,
    and pixel data is copied into arrays while each file is still open.
    """
    ed = ElasticDeformation(random_state=numpy.random.RandomState(seed=100))

    # Undeformed sample: deform it and materialize the result as an array
    # before the underlying file is closed.
    source_path = _data_file("raw_without_elastic_deformation.png")
    with PIL.Image.open(source_path) as raw_without_deformation:
        raw_deformed = numpy.asarray(ed(raw_without_deformation))

    # Reference: the same sample already deformed with seed=100.
    reference_path = _data_file("raw_with_elastic_deformation.png")
    with PIL.Image.open(reference_path) as raw_reference:
        expected = numpy.asarray(raw_reference)

    numpy.testing.assert_array_equal(raw_deformed, expected)
62
63
def test_load_pil_16bit():
    """Check ``SingleAutoLevel16to8`` on a 16-bit and an 8-bit input.

    1. On the 16-bit fixture, the levelled output must not look clipped:
       fewer than half of its non-zero pixels may share the maximum value.
    2. On an image that is already 8-bit, the transform must leave the
       pixel data unchanged.
    """
    level_16_to_8 = SingleAutoLevel16to8()

    # 16-bit input: if the ratio is higher than 0.5, the image is probably
    # clipped.
    image = numpy.array(level_16_to_8(load_pil(_data_file("16bits.png"))))

    count_pixels = numpy.count_nonzero(image)
    # Guard the ratio below: both counts are Python ints, so an all-zero
    # output would raise ZeroDivisionError instead of a readable failure.
    assert count_pixels > 0, "levelled 16-bit image contains only zeros"
    count_max_value = numpy.count_nonzero(image == image.max())
    ratio = count_max_value / count_pixels
    assert ratio < 0.5, (
        "levelled image looks clipped: {:.2%} of non-zero pixels are at "
        "the maximum value".format(ratio)
    )

    # 8-bit input: the transform should not do anything. The original pixels
    # are captured before the transform runs.
    img_loaded = load_pil(_data_file("raw_without_black_border.png"))
    original_8bits = numpy.array(img_loaded)
    leveled_8bits = numpy.array(level_16_to_8(img_loaded))

    # Argument order is (actual, desired), matching the other tests here.
    numpy.testing.assert_array_equal(leveled_8bits, original_8bits)