#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Image transformations for our pipelines

The difference between the methods here and those from
:py:mod:`torchvision.transforms` is that these support multiple simultaneous
image inputs, which are required to feed segmentation networks (e.g. image and
labels or masks). We also take care of data augmentation, in which random
flipping and rotation need to be applied across all input images, but color
jittering, for example, only to the input image.
"""

import random

import numpy
import PIL.Image
import PIL.ImageOps
import torchvision.transforms
import torchvision.transforms.functional


class TupleMixin:
    """Adds support for tuples of objects to torchvision transforms

    List this mixin **before** the torchvision transform in the class bases,
    so the wrapped transform is applied to every input.
    """

    def __call__(self, *args):
        # the two-argument form of super() is required here: the list
        # comprehension runs in its own implicit function scope, where the
        # zero-argument form does not work
        return [super(TupleMixin, self).__call__(k) for k in args]
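
# A minimal usage sketch for the tuple-aware transforms below (illustrative
# only; the sizes and modes are arbitrary):
#
#   crop = CenterCrop(128)
#   img = PIL.Image.new("RGB", (256, 256))
#   gt = PIL.Image.new("L", (256, 256))
#   img_c, gt_c = crop(img, gt)  # both outputs are 128x128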


class CenterCrop(TupleMixin, torchvision.transforms.CenterCrop):
    pass


class Pad(TupleMixin, torchvision.transforms.Pad):
    pass


class Resize(TupleMixin, torchvision.transforms.Resize):
    pass


class ToTensor(TupleMixin, torchvision.transforms.ToTensor):
    pass


class Compose(torchvision.transforms.Compose):
    """Composes transforms that accept and return multiple images"""

    def __call__(self, *args):
        for t in self.transforms:
            args = t(*args)
        return args
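
# A minimal sketch of a composed pipeline (illustrative; assumes image and
# ground-truth have the same size):
#
#   transforms = Compose([CenterCrop(128), ToTensor()])
#   img = PIL.Image.new("RGB", (256, 256))
#   gt = PIL.Image.new("L", (256, 256))
#   img_t, gt_t = transforms(img, gt)  # two torch tensors, 128x128 each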


class SingleCrop:
    """
    Crops one image at the given coordinates.

    Attributes
    ----------
    i : int
        upper pixel coordinate.
    j : int
        left pixel coordinate.
    h : int
        height of the cropped image.
    w : int
        width of the cropped image.
    """

    def __init__(self, i, j, h, w):
        self.i = i
        self.j = j
        self.h = h
        self.w = w

    def __call__(self, img):
        # PIL expects a (left, upper, right, lower) box
        return img.crop((self.j, self.i, self.j + self.w, self.i + self.h))
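
# A minimal usage sketch (illustrative values):
#
#   crop = SingleCrop(i=10, j=20, h=100, w=200)
#   patch = crop(PIL.Image.new("RGB", (300, 300)))
#   patch.size  # (200, 100): top-left corner at column 20, row 10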


class Crop(TupleMixin, SingleCrop):
    """
    Crops multiple images at the given coordinates.

    Attributes
    ----------
    i : int
        upper pixel coordinate.
    j : int
        left pixel coordinate.
    h : int
        height of the cropped image.
    w : int
        width of the cropped image.
    """


class SingleAutoLevel16to8:
    """Converts a 16-bit image to an 8-bit representation using "auto-level"

    This transform assumes that the input image is grayscale.

    To auto-level, we calculate the minimum and the maximum of the image, and
    map that range to the [0, 255] range of the destination image.
    """

    def __call__(self, img):
        imin, imax = img.getextrema()
        # note: assumes a non-constant image (imax > imin)
        irange = imax - imin
        return PIL.Image.fromarray(
            numpy.round(
                255.0 * (numpy.array(img).astype(float) - imin) / irange
            ).astype("uint8"),
        ).convert("L")
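
# A minimal usage sketch (illustrative; a synthetic 16-bit gradient):
#
#   img16 = PIL.Image.fromarray(
#       numpy.arange(4096, dtype=numpy.uint16).reshape(64, 64)
#   )
#   img8 = SingleAutoLevel16to8()(img16)  # mode "L", full [0, 255] range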


class AutoLevel16to8(TupleMixin, SingleAutoLevel16to8):
    """Converts multiple 16-bit images to 8-bit representations using "auto-level"

    This transform assumes that the input images are grayscale.

    To auto-level, we calculate the minimum and the maximum of each image, and
    map that range to the [0, 255] range of the destination image.
    """


class SingleToRGB:
    """Converts from any input format to RGB, using an ADAPTIVE conversion.

    This transform takes the input image and converts it to RGB using
    :py:meth:`PIL.Image.Image.convert`, with ``mode='RGB'`` and all other
    defaults. This may be aggressive if applied to 16-bit images without
    further considerations.
    """

    def __call__(self, img):
        return img.convert(mode="RGB")


class ToRGB(TupleMixin, SingleToRGB):
    """Converts from any input format to RGB, using an ADAPTIVE conversion.

    This transform takes the input images and converts them to RGB using
    :py:meth:`PIL.Image.Image.convert`, with ``mode='RGB'`` and all other
    defaults. This may be aggressive if applied to 16-bit images without
    further considerations.
    """
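
# A minimal usage sketch (illustrative):
#
#   rgb = SingleToRGB()(PIL.Image.new("L", (32, 32)))
#   rgb.mode  # "RGB"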


class RandomHorizontalFlip(torchvision.transforms.RandomHorizontalFlip):
    """Randomly flips all input images horizontally"""

    def __call__(self, *args):
        if random.random() < self.p:
            return [
                torchvision.transforms.functional.hflip(img) for img in args
            ]
        else:
            return args


class RandomVerticalFlip(torchvision.transforms.RandomVerticalFlip):
    """Randomly flips all input images vertically"""

    def __call__(self, *args):
        if random.random() < self.p:
            return [
                torchvision.transforms.functional.vflip(img) for img in args
            ]
        else:
            return args
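
# A minimal usage sketch (illustrative; p=1.0 forces the flip so the example
# is deterministic):
#
#   flip = RandomHorizontalFlip(p=1.0)
#   img = PIL.Image.new("RGB", (64, 64))
#   gt = PIL.Image.new("L", (64, 64))
#   img_f, gt_f = flip(img, gt)  # the same flip is applied to both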


class RandomRotation(torchvision.transforms.RandomRotation):
    """Randomly rotates all input images by the same amount

    Unlike the current torchvision implementation, we also accept a
    probability for applying the rotation.

    Parameters
    ----------

    p : :py:class:`float`, Optional
        probability at which the operation is applied

    **kwargs : dict
        passed to parent. Notice that, if not set, we use the following
        defaults here for the underlying transform from torchvision:

        * ``degrees``: 15
        * ``interpolation``: ``torchvision.transforms.functional.InterpolationMode.BILINEAR``

    """

    def __init__(self, p=0.5, **kwargs):
        kwargs.setdefault("degrees", 15)
        kwargs.setdefault(
            "interpolation",
            torchvision.transforms.functional.InterpolationMode.BILINEAR,
        )
        super(RandomRotation, self).__init__(**kwargs)
        self.p = p

    def __call__(self, *args):
        # applies **the same** rotation to all inputs (data and ground-truth)
        if random.random() < self.p:
            angle = self.get_params(self.degrees)
            return [
                torchvision.transforms.functional.rotate(
                    img, angle, self.interpolation, self.expand, self.center
                )
                for img in args
            ]
        else:
            return args

    def __repr__(self):
        retval = super(RandomRotation, self).__repr__()
        return retval.replace("(", f"(p={self.p},", 1)
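
# A minimal usage sketch (illustrative; p=1.0 guarantees a rotation; note
# that bilinear interpolation also applies to the ground-truth):
#
#   rot = RandomRotation(p=1.0, degrees=10)
#   img = PIL.Image.new("RGB", (64, 64))
#   gt = PIL.Image.new("L", (64, 64))
#   img_r, gt_r = rot(img, gt)  # the same random angle is applied to both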


class ColorJitter(torchvision.transforms.ColorJitter):
    """Randomly applies a color jitter transformation on the **first** image

    Notice this transform extension, unlike others in this module, only affects
    the first image passed as input argument. Unlike the current torchvision
    implementation, we also accept a probability for applying the jitter.

    Parameters
    ----------

    p : :py:class:`float`, Optional
        probability at which the operation is applied

    **kwargs : dict
        passed to parent. Notice that, if not set, we use the following
        defaults here for the underlying transform from torchvision:

        * ``brightness``: 0.3
        * ``contrast``: 0.3
        * ``saturation``: 0.02
        * ``hue``: 0.02

    """

    def __init__(self, p=0.5, **kwargs):
        kwargs.setdefault("brightness", 0.3)
        kwargs.setdefault("contrast", 0.3)
        kwargs.setdefault("saturation", 0.02)
        kwargs.setdefault("hue", 0.02)
        super(ColorJitter, self).__init__(**kwargs)
        self.p = p

    def __call__(self, *args):
        if random.random() < self.p:
            # applies color jitter only to the input image, not ground-truth
            return [super(ColorJitter, self).__call__(args[0]), *args[1:]]
        else:
            return args

    def __repr__(self):
        retval = super(ColorJitter, self).__repr__()
        return retval.replace("(", f"(p={self.p},", 1)
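
# A minimal usage sketch (illustrative; p=1.0 guarantees jittering):
#
#   jitter = ColorJitter(p=1.0)
#   img = PIL.Image.new("RGB", (64, 64))
#   gt = PIL.Image.new("L", (64, 64))
#   img_j, gt_same = jitter(img, gt)  # only the first image is jittered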


def _expand2square(pil_img, background_color):
    """
    Pads the smaller of the image's height and width so the output is square.

    Parameters
    ----------

    pil_img : PIL.Image.Image
        A PIL image that represents the image for analysis.

    background_color : :py:class:`tuple` or int
        The color to pad with, matching the background of the image. For an
        RGB image, ``background_color`` should be a tuple of size 3; for a
        grayscale image, an integer suffices.

    Returns
    -------

    image : PIL.Image.Image
        A new image with height equal to width.

    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        # pad top and bottom
        result = PIL.Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        # pad left and right
        result = PIL.Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result
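
# A minimal usage sketch (illustrative):
#
#   wide = PIL.Image.new("RGB", (200, 100))
#   square = _expand2square(wide, (0, 0, 0))
#   square.size  # (200, 200): the original is pasted vertically centered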


class ResizeCrop:
    """
    Crops the black borders from all input images, then pads them to squares.

    The bounding box of the non-zero region of the mask (the last input) is
    computed, every input is cropped to it, and the results are padded back
    to squares with :py:func:`_expand2square`.
    """

    def __call__(self, *args):
        # expects exactly three inputs: image, label (ground-truth) and mask
        img, label, mask = args
        mask_data = numpy.asarray(mask)
        wid = numpy.sum(mask_data, axis=0)
        heig = numpy.sum(mask_data, axis=1)

        # number of all-zero columns/rows on each side of the mask
        crop_left, crop_right = (wid != 0).argmax(axis=0), (
            wid[::-1] != 0
        ).argmax(axis=0)
        crop_up, crop_down = (heig != 0).argmax(axis=0), (
            heig[::-1] != 0
        ).argmax(axis=0)

        border = (crop_left, crop_up, crop_right, crop_down)

        new_mask = PIL.ImageOps.crop(mask, border)
        new_img = PIL.ImageOps.crop(img, border)
        new_label = PIL.ImageOps.crop(label, border)

        new_img = _expand2square(new_img, (0, 0, 0))
        new_label = _expand2square(new_label, 0)
        new_mask = _expand2square(new_mask, 0)

        return new_img, new_label, new_mask
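
# A minimal usage sketch (illustrative; the mask marks a rectangular region
# of interest):
#
#   img = PIL.Image.new("RGB", (64, 64))
#   label = PIL.Image.new("L", (64, 64))
#   mask = PIL.Image.new("L", (64, 64))
#   mask.paste(255, (16, 8, 48, 56))  # non-zero region: 32x48
#   img_c, label_c, mask_c = ResizeCrop()(img, label, mask)
#   img_c.size  # (48, 48): cropped to 32x48, then padded to a square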