1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4"""Image transformations for our pipelines
5
The difference between the transforms here and those from
:py:mod:`torchvision.transforms` is that these support multiple simultaneous
image inputs, which are required to feed segmentation networks (e.g. image and
labels or masks). We also take care of data augmentation, in which random
flipping and rotation need to be applied across all input images, but color
jittering, for example, only to the input image.
12"""
13
14import random
15
16import numpy
17import PIL.Image
18import torchvision.transforms
19import torchvision.transforms.functional
20from scipy.ndimage.filters import gaussian_filter
21from scipy.ndimage.interpolation import map_coordinates
22
class SingleAutoLevel16to8:
    """Converts a 16-bit image to an 8-bit representation using "auto-level"

    This transform assumes that the input image is gray-scaled.

    To auto-level, we calculate the minimum and the maximum of the image and
    linearly map that range onto the [0, 255] range of the destination
    image.  A constant image (minimum equal to maximum) has no range to
    stretch; it is mapped to an all-zero image instead of dividing by zero.
    """

    def __call__(self, img):
        """Auto-levels ``img`` and returns it as an 8-bit (mode ``L``) image.

        Parameters
        ----------
        img : PIL.Image.Image
            Single-band image (presumably 16-bit, per the class name) whose
            ``getextrema()`` returns a ``(min, max)`` pair.

        Returns
        -------
        PIL.Image.Image
            The rescaled image, in mode ``L``.
        """
        imin, imax = img.getextrema()
        irange = imax - imin
        data = numpy.array(img).astype(float)
        if irange == 0:
            # Flat image: (data - imin) is identically zero, so the scaled
            # value is 0 everywhere.  Computing it literally would be 0/0,
            # i.e. NaN, whose cast to uint8 is undefined.
            scaled = numpy.zeros(data.shape, dtype="uint8")
        else:
            scaled = numpy.round(255.0 * (data - imin) / irange).astype("uint8")
        return PIL.Image.fromarray(scaled).convert("L")
42
43
class RemoveBlackBorders:
    """Remove black borders of CXR

    Crops the image to the bounding box of the rows and columns that contain
    at least one pixel strictly greater than ``threshold``.

    Parameters
    ----------
    threshold : int or float
        Pixels with values ``<= threshold`` are treated as (black) border.
    """

    def __init__(self, threshold=0):
        self.threshold = threshold

    def __call__(self, img):
        """Returns ``img`` cropped to its non-black content as a PIL image.

        Multi-channel inputs of shape ``(H, W, C)`` are supported: a pixel
        counts as content if any of its channels exceeds the threshold, so
        every channel receives the same crop.  If no pixel exceeds the
        threshold, the image is returned uncropped rather than being
        collapsed to an empty 0x0 image.
        """
        arr = numpy.asarray(img)
        mask = arr > self.threshold
        if mask.ndim == 3:
            # Reduce over channels so rows/cols below are 1-D, as ix_ requires.
            mask = mask.any(axis=2)
        rows = mask.any(axis=1)
        cols = mask.any(axis=0)
        if not rows.any():
            # Nothing above threshold: cropping would leave an empty image.
            return PIL.Image.fromarray(arr)
        return PIL.Image.fromarray(arr[numpy.ix_(rows, cols)])
55
class ElasticDeformation:
    """Elastic deformation of 2D image slightly adapted from [SIMARD-2003]_.

    A random displacement field is built from ``random_state.rand`` noise
    rescaled from [0, 1) to [-1, 1), smoothed with a Gaussian filter of
    standard deviation ``sigma`` and multiplied by ``alpha``.  The image is
    then resampled at the displaced coordinates.  For multi-channel images of
    shape ``(H, W, C)`` the same displacement field is applied to every
    channel.

    .. [SIMARD-2003] Simard, Steinkraus and Platt, "Best Practices for
       Convolutional Neural Networks applied to Visual Document Analysis", in
       Proc. of the International Conference on Document Analysis and
       Recognition, 2003.

    Source: https://gist.github.com/oeway/2e3b989e0343f0884388ed7ed82eb3b0

    Parameters
    ----------
    alpha : float
        Scale applied to the smoothed displacement field.
    sigma : float
        Standard deviation of the Gaussian smoothing of the displacement.
    spline_order : int
        Interpolation order passed to ``scipy.ndimage.map_coordinates``.
    mode : str
        Boundary mode passed to ``scipy.ndimage.map_coordinates``.
    random_state : object
        Source of displacement noise; must provide ``rand(*shape)`` (the
        default is the ``numpy.random`` module).
    p : float
        Probability of applying the deformation.
    """

    def __init__(self, alpha=1000, sigma=30, spline_order=1, mode='nearest',
                 random_state=numpy.random, p=1):
        self.alpha = alpha
        self.sigma = sigma
        self.spline_order = spline_order
        self.mode = mode
        self.random_state = random_state
        self.p = p

    def __call__(self, img):
        """Randomly deforms ``img``.

        Returns ``img`` unchanged with probability ``1 - p``; otherwise a new
        PIL image with the deformed content.

        Raises
        ------
        ValueError
            If the (deformation-selected) input is neither 2-D nor 3-D.
        """
        # NOTE: the apply/skip decision uses the stdlib ``random`` module
        # while the displacement field uses ``self.random_state``; seeding
        # only one of them does not make this transform fully reproducible.
        if not random.random() < self.p:
            return img

        arr = numpy.asarray(img)
        if arr.ndim not in (2, 3):
            raise ValueError(
                f"expected a 2-D (H, W) or 3-D (H, W, C) image, "
                f"got shape {arr.shape}"
            )

        height, width = arr.shape[:2]
        dx = gaussian_filter((self.random_state.rand(height, width) * 2 - 1),
                             self.sigma, mode="constant", cval=0) * self.alpha
        dy = gaussian_filter((self.random_state.rand(height, width) * 2 - 1),
                             self.sigma, mode="constant", cval=0) * self.alpha

        x, y = numpy.meshgrid(numpy.arange(height), numpy.arange(width),
                              indexing='ij')
        indices = [numpy.reshape(x + dx, (-1, 1)), numpy.reshape(y + dy, (-1, 1))]

        def _warp(plane):
            # map_coordinates keeps the input dtype when no output is given.
            return map_coordinates(plane, indices, order=self.spline_order,
                                   mode=self.mode).reshape(height, width)

        if arr.ndim == 2:
            result = _warp(arr)
        else:
            result = numpy.stack(
                [_warp(arr[..., c]) for c in range(arr.shape[2])], axis=-1
            )
        return PIL.Image.fromarray(result)