1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
"""Image transformations for our pipelines

The difference between the methods here and those from
:py:mod:`torchvision.transforms` is that these support multiple simultaneous
image inputs, which are required to feed segmentation networks (e.g. image and
labels or masks). We also take care of data augmentations, in which random
flipping and rotation need to be applied across all input images, but color
jittering, for example, only on the input image.
"""
13
14import random
15
16import numpy
17import PIL.Image
18import PIL.ImageOps
19import torch
20import torchvision.transforms
21import torchvision.transforms.functional
22
23
class TupleMixin:
    """Mixin that makes a torchvision transform accept several images at once.

    The wrapped transform is applied independently to every positional
    argument; results are returned as a list in the same order.
    """

    def __call__(self, *args):
        transform = super().__call__
        return [transform(image) for image in args]
29
30
class CenterCrop(TupleMixin, torchvision.transforms.CenterCrop):
    """Center-crops every input image simultaneously (tuple-aware variant)."""

    pass
33
34
class Pad(TupleMixin, torchvision.transforms.Pad):
    """Pads every input image simultaneously (tuple-aware variant)."""

    pass
37
38
class Resize(TupleMixin, torchvision.transforms.Resize):
    """Resizes every input image simultaneously (tuple-aware variant)."""

    pass
41
42
class ToTensor(TupleMixin, torchvision.transforms.ToTensor):
    """Converts every input image to a tensor simultaneously (tuple-aware)."""

    pass
45
46
class Compose(torchvision.transforms.Compose):
    """Composes transforms while threading multiple images through the chain.

    Each transform receives all current images as positional arguments and
    returns the (possibly modified) sequence for the next transform.
    """

    def __call__(self, *args):
        current = args
        for transform in self.transforms:
            current = transform(*current)
        return current
52
53
class SingleCrop:
    """
    Crops one image at the given coordinates.

    Attributes
    ----------
    i : int
        upper pixel coordinate.
    j : int
        left pixel coordinate.
    h : int
        height of the cropped image.
    w : int
        width of the cropped image.
    """

    def __init__(self, i, j, h, w):
        self.i = i
        self.j = j
        self.h = h
        self.w = w

    def __call__(self, img):
        # PIL crop boxes are (left, upper, right, lower)
        left = self.j
        upper = self.i
        right = left + self.w
        lower = upper + self.h
        return img.crop((left, upper, right, lower))
78
79
class Crop(TupleMixin, SingleCrop):
    """
    Crops multiple images at the given coordinates.

    Applies :py:class:`SingleCrop` independently to every input image,
    returning a list of cropped images in the same order.

    Attributes
    ----------
    i : int
        upper pixel coordinate.
    j : int
        left pixel coordinate.
    h : int
        height of the cropped image.
    w : int
        width of the cropped image.
    """

    pass
97
98
class SingleAutoLevel16to8:
    """Converts a 16-bit image to 8-bit representation using "auto-level"

    This transform assumes that the input image is gray-scaled.

    To auto-level, we calculate the maximum and the minimum of the image, and
    consider such a range should be mapped to the [0,255] range of the
    destination image.  Constant images (minimum equals maximum) are mapped to
    all zeros instead of producing NaNs via a division by zero.

    """

    def __call__(self, img):
        imin, imax = img.getextrema()
        irange = imax - imin
        data = numpy.array(img).astype(float)
        if irange == 0:
            # constant image: the naive formula would divide by zero and
            # produce NaNs; map every pixel to 0 instead
            scaled = numpy.zeros_like(data)
        else:
            scaled = 255.0 * (data - imin) / irange
        return PIL.Image.fromarray(
            numpy.round(scaled).astype("uint8"),
        ).convert("L")
118
119
class AutoLevel16to8(TupleMixin, SingleAutoLevel16to8):
    """Converts multiple 16-bit images to 8-bit representations using "auto-level"

    Applies :py:class:`SingleAutoLevel16to8` independently to every input
    image, returning a list in the same order.

    This transform assumes that the input images are gray-scaled.

    To auto-level, we calculate the maximum and the minimum of the image, and
    consider such a range should be mapped to the [0,255] range of the
    destination image.
    """

    pass
131
132
class SingleToRGB:
    """Converts from any input format to RGB, using an ADAPTIVE conversion.

    Uses :py:meth:`PIL.Image.Image.convert` with ``mode='RGB'`` and all other
    defaults.  Note this may be aggressive if applied to 16-bit images
    without further considerations.
    """

    def __call__(self, img):
        rgb = img.convert(mode="RGB")
        return rgb
144
145
class ToRGB(TupleMixin, SingleToRGB):
    """Converts from any input format to RGB, using an ADAPTIVE conversion.

    Applies :py:class:`SingleToRGB` independently to every input image,
    returning a list in the same order.

    This transform takes the input image and converts it to RGB using
    py:method:`PIL.Image.Image.convert`, with `mode='RGB'` and using all other
    defaults. This may be aggressive if applied to 16-bit images without
    further considerations.
    """

    pass
156
157
class RandomHorizontalFlip(torchvision.transforms.RandomHorizontalFlip):
    """Randomly flips all input images horizontally"""

    def __call__(self, *args):
        # either flip every input or none, keeping data and ground-truth
        # spatially aligned
        if random.random() >= self.p:
            return args
        return [
            torchvision.transforms.functional.hflip(img) for img in args
        ]
168
169
class RandomVerticalFlip(torchvision.transforms.RandomVerticalFlip):
    """Randomly flips all input images vertically"""

    def __call__(self, *args):
        # either flip every input or none, keeping data and ground-truth
        # spatially aligned
        if random.random() >= self.p:
            return args
        return [
            torchvision.transforms.functional.vflip(img) for img in args
        ]
180
181
class RandomRotation(torchvision.transforms.RandomRotation):
    """Randomly rotates all input images by the same amount

    Unlike the current torchvision implementation, we also accept a probability
    for applying the rotation.


    Parameters
    ----------

    p : :py:class:`float`, Optional
        probability at which the operation is applied

    **kwargs : dict
        passed to parent. Notice that, if not set, we use the following
        defaults here for the underlying transform from torchvision:

        * ``degrees``: 15
        * ``interpolation``: ``torchvision.transforms.functional.InterpolationMode.BILINEAR``

    """

    def __init__(self, p=0.5, **kwargs):
        kwargs.setdefault("degrees", 15)
        kwargs.setdefault(
            "interpolation",
            torchvision.transforms.functional.InterpolationMode.BILINEAR,
        )
        super().__init__(**kwargs)
        self.p = p

    def __call__(self, *args):
        if random.random() >= self.p:
            return args
        # draw a single angle and apply it to every input, so image and
        # ground-truth stay aligned
        angle = self.get_params(self.degrees)
        rotate = torchvision.transforms.functional.rotate
        return [
            rotate(img, angle, self.interpolation, self.expand, self.center)
            for img in args
        ]

    def __repr__(self):
        # inject our extra ``p`` parameter into the parent representation
        return super().__repr__().replace("(", f"(p={self.p},", 1)
229
230
class ColorJitter(torchvision.transforms.ColorJitter):
    """Randomly applies a color jitter transformation on the **first** image

    Notice this transform extension, unlike others in this module, only affects
    the first image passed as input argument. Unlike the current torchvision
    implementation, we also accept a probability for applying the jitter.


    Parameters
    ----------

    p : :py:class:`float`, Optional
        probability at which the operation is applied

    **kwargs : dict
        passed to parent. Notice that, if not set, we use the following
        defaults here for the underlying transform from torchvision:

        * ``brightness``: 0.3
        * ``contrast``: 0.3
        * ``saturation``: 0.02
        * ``hue``: 0.02

    """

    def __init__(self, p=0.5, **kwargs):
        defaults = {
            "brightness": 0.3,
            "contrast": 0.3,
            "saturation": 0.02,
            "hue": 0.02,
        }
        for key, value in defaults.items():
            kwargs.setdefault(key, value)
        super().__init__(**kwargs)
        self.p = p

    def __call__(self, *args):
        if random.random() >= self.p:
            return args
        # jitter only the first input (the image); ground-truth objects keep
        # their original values
        jittered = super().__call__(args[0])
        return [jittered, *args[1:]]

    def __repr__(self):
        # inject our extra ``p`` parameter into the parent representation
        return super().__repr__().replace("(", f"(p={self.p},", 1)
274
275
276def _expand2square(pil_img, background_color):
277 """
278 Function that pad the minimum between the height and the width to fit a square
279
280 Parameters
281 ----------
282
283 pil_img : PIL.Image.Image
284 A PIL image that represents the image for analysis.
285
286 background_color: py:class:`tuple`, Optional
287 A tuple to represent the color of the background of the image in order
288 to pad with the same color. If the image is an RGB image
289 background_color should be a tuple of size 3 , if it's a grayscale
290 image the variable can be represented with an integer.
291
292 Returns
293 -------
294
295 image : PIL.Image.Image
296 A new image with height equal to width.
297
298 """
299
300 width, height = pil_img.size
301 if width == height:
302 return pil_img
303 elif width > height:
304 result = PIL.Image.new(pil_img.mode, (width, width), background_color)
305 result.paste(pil_img, (0, (width - height) // 2))
306 return result
307 else:
308 result = PIL.Image.new(pil_img.mode, (height, height), background_color)
309 result.paste(pil_img, ((height - width) // 2, 0))
310 return result
311
312
class ShrinkIntoSquare:
    """Crops black borders and then resize to a square with minimal padding

    This transform can crop all the images by removing the black pixels in the
    width and height until it finds a non-black pixel. Then, expands the image
    back until it makes a square with minimal size.


    Parameters
    ----------

    reference : :py:class:`int`, Optional
        Which reference part of the sample to use for cropping black borders.
        If not set, use the first object on the sample (typically, the image).

    threshold : :py:class:`int`, Optional
        Threshold to use for when considering what is a "black" border

    """

    def __init__(self, reference=0, threshold=0):
        self.reference = reference
        self.threshold = threshold

    def __call__(self, *args):

        # locate non-black rows/columns on a gray-scale view of the reference
        gray = numpy.asarray(args[self.reference].convert("L"))
        cols = numpy.sum(gray, axis=0) > self.threshold
        rows = numpy.sum(gray, axis=1) > self.threshold

        # distances (left, top, right, bottom) from each edge to the first
        # column/row whose pixel sum exceeds the threshold
        border = (
            cols.argmax(),
            rows.argmax(),
            cols[::-1].argmax(),
            rows[::-1].argmax(),
        )

        cropped = [PIL.ImageOps.crop(k, border) for k in args]

        def _fill_color(img):
            # "black" expressed in the image's own mode
            return (0, 0, 0) if img.mode == "RGB" else 0

        return [_expand2square(img, _fill_color(img)) for img in cropped]
356
357
class GaussianBlur(torchvision.transforms.GaussianBlur):
    """Randomly applies a gaussian blur transformation on the **first** image

    Notice this transform extension, unlike others in this module, only affects
    the first image passed as input argument. Unlike the current torchvision
    implementation, we also accept a probability for applying the blur.


    Parameters
    ----------

    p : :py:class:`float`, Optional
        probability at which the operation is applied

    **kwargs : dict
        passed to parent. Notice that, if not set, we use the following
        defaults here for the underlying transform from torchvision:

        * ``kernel_size``: (5, 5)
        * ``sigma``: (0.1, 5)
    """

    def __init__(self, p=0.5, **kwargs):
        kwargs.setdefault("kernel_size", (5, 5))
        kwargs.setdefault("sigma", (0.1, 5))
        super().__init__(**kwargs)
        self.p = p

    def __call__(self, *args):
        if random.random() >= self.p:
            return args
        # blur only the first input (the image); ground-truth objects are
        # passed through unchanged
        blurred = super().__call__(args[0])
        return [blurred, *args[1:]]
393
394
class GetBoundingBox:
    """Returns image tensor and its corresponding target dict given a mask.

    Note: unlike other transforms in this module, this one takes a single
    sequence argument (the sample), not ``*args``.

    Parameters
    ----------
    image : :py:class:`int`, Optional
        Which reference part of the sample is the image.

    reference : :py:class:`int`, Optional
        Which reference part of the sample to use for getting bbox.
        If not set, use the second object on the sample (typically, the mask).
    """

    def __init__(self, image=0, reference=1):
        self.image = image
        self.reference = reference

    def __call__(self, args):

        # first plane of the mask tensor; assumes channel-first layout
        # (C, H, W) -- TODO confirm against callers
        ref = args[self.reference][0, :, :]

        # distinct object ids; drops the first (smallest) id, assuming it is
        # the background -- TODO confirm background id is always present
        obj_ids = ref.unique()
        obj_ids = obj_ids[1:]

        # one boolean mask per remaining object id
        masks = ref == obj_ids[:, None, None]

        num_objs = len(obj_ids)
        boxes = []
        for i in range(num_objs):
            pos = torch.where(masks[i])
            xmin = pos[1].min().item()
            xmax = pos[1].max().item()
            ymin = pos[0].min().item()
            ymax = pos[0].max().item()
            boxes.append([xmin, ymin, xmax, ymax])

        # reshape so an all-background mask yields an empty (0, 4) tensor
        # instead of shape (0,), which detection heads expect; a no-op when
        # boxes is non-empty
        boxes = torch.as_tensor(boxes, dtype=torch.int64).reshape(-1, 4)
        # a single foreground class: every object gets label 1
        labels = torch.ones((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels

        return [args[self.image], target]