Coverage for src/deepdraw/data/transforms.py: 66%
167 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5"""Image transformations for our pipelines.
The difference between methods here and those from
:py:mod:`torchvision.transforms` is that these support multiple simultaneous
image inputs, which are required to feed segmentation networks (e.g. image and
10labels or masks). We also take care of data augmentations, in which random
11flipping and rotation needs to be applied across all input images, but color
12jittering, for example, only on the input image.
13"""
15import random
17import numpy
18import PIL.Image
19import PIL.ImageOps
20import torchvision.transforms
21import torchvision.transforms.functional
class TupleMixin:
    """Makes a torchvision transform operate on tuples of objects.

    Each positional argument is fed individually through the parent
    transform's ``__call__``; the transformed results are returned as a
    list, in the same order as the inputs.
    """

    def __call__(self, *args):
        parent_call = super().__call__
        return list(map(parent_call, args))
class CenterCrop(TupleMixin, torchvision.transforms.CenterCrop):
    """Center-crops all input images to the configured size (tuple-aware
    version of :py:class:`torchvision.transforms.CenterCrop`)."""

    pass
class Pad(TupleMixin, torchvision.transforms.Pad):
    """Pads all input images with the configured padding (tuple-aware
    version of :py:class:`torchvision.transforms.Pad`)."""

    pass
class Resize(TupleMixin, torchvision.transforms.Resize):
    """Resizes all input images to the configured size (tuple-aware version
    of :py:class:`torchvision.transforms.Resize`)."""

    pass
class ToTensor(TupleMixin, torchvision.transforms.ToTensor):
    """Converts all input images to tensors (tuple-aware version of
    :py:class:`torchvision.transforms.ToTensor`)."""

    pass
class Compose(torchvision.transforms.Compose):
    """Composes transforms that take (and return) multiple images at once.

    Each configured transform receives the current tuple of images as
    positional arguments and produces the inputs for the next one.
    """

    def __call__(self, *args):
        current = args
        for transform in self.transforms:
            current = transform(*current)
        return current
# NOTE: SingleCrop/Crop below are never used within this package --
# consider removing them in a future cleanup.
class SingleCrop:
    """Crops a single image to a fixed rectangular region.

    Attributes
    ----------
    i : int
        upper pixel coordinate.
    j : int
        left pixel coordinate.
    h : int
        height of the cropped image.
    w : int
        width of the cropped image.
    """

    def __init__(self, i, j, h, w):
        self.i = i
        self.j = j
        self.h = h
        self.w = w

    def __call__(self, img):
        # PIL crop boxes are (left, upper, right, lower)
        left = self.j
        upper = self.i
        return img.crop((left, upper, left + self.w, upper + self.h))
class Crop(TupleMixin, SingleCrop):
    """Crops multiple images at the given coordinates.

    All cropping behavior comes from :py:class:`SingleCrop`; it is applied
    to each input image through :py:class:`TupleMixin`.

    Attributes
    ----------
    i : int
        upper pixel coordinate.
    j : int
        left pixel coordinate.
    h : int
        height of the cropped image.
    w : int
        width of the cropped image.
    """

    pass
class SingleAutoLevel16to8:
    """Converts a 16-bit image to 8-bit representation using "auto-level".

    This transform assumes that the input image is gray-scaled.

    To auto-level, we calculate the maximum and the minimum of the image, and
    consider such a range should be mapped to the [0,255] range of the
    destination image.
    """

    def __call__(self, img):
        """Maps the dynamic range of ``img`` onto [0, 255] and returns an
        8-bit ("L" mode) PIL image."""
        imin, imax = img.getextrema()
        irange = imax - imin
        pixels = numpy.array(img).astype(float)
        if irange == 0:
            # Degenerate (constant) image: the auto-level formula would
            # divide by zero and produce NaN garbage; map all pixels to 0.
            scaled = numpy.zeros_like(pixels)
        else:
            scaled = 255.0 * (pixels - imin) / irange
        return PIL.Image.fromarray(
            numpy.round(scaled).astype("uint8"),
        ).convert("L")
class AutoLevel16to8(TupleMixin, SingleAutoLevel16to8):
    """Converts multiple 16-bit images to 8-bit representations using "auto-
    level".

    This transform assumes that the input images are gray-scaled.

    To auto-level, we calculate the maximum and the minimum of the image, and
    consider such a range should be mapped to the [0,255] range of the
    destination image.  The conversion itself is implemented by
    :py:class:`SingleAutoLevel16to8`, applied to each input through
    :py:class:`TupleMixin`.
    """

    pass
class SingleToRGB:
    """Converts a single image, from any input format, to RGB.

    The conversion is delegated to :py:meth:`PIL.Image.Image.convert` with
    ``mode='RGB'`` and all other defaults (an ADAPTIVE conversion).  This
    may be aggressive if applied to 16-bit images without further
    considerations.
    """

    def __call__(self, img):
        rgb_image = img.convert(mode="RGB")
        return rgb_image
class ToRGB(TupleMixin, SingleToRGB):
    """Converts from any input format to RGB, using an ADAPTIVE conversion.

    This transform takes each input image and converts it to RGB using
    :py:meth:`PIL.Image.Image.convert`, with ``mode='RGB'`` and using all
    other defaults. This may be aggressive if applied to 16-bit images
    without further considerations.  The per-image conversion is implemented
    by :py:class:`SingleToRGB`, applied through :py:class:`TupleMixin`.
    """

    pass
class RandomHorizontalFlip(torchvision.transforms.RandomHorizontalFlip):
    """Horizontally flips all inputs together, with probability ``p``.

    A single random draw decides whether every input image is flipped, so
    data and ground-truth always stay aligned.
    """

    def __call__(self, *args):
        if random.random() >= self.p:
            return args
        flip = torchvision.transforms.functional.hflip
        return [flip(img) for img in args]
class RandomVerticalFlip(torchvision.transforms.RandomVerticalFlip):
    """Vertically flips all inputs together, with probability ``p``.

    A single random draw decides whether every input image is flipped, so
    data and ground-truth always stay aligned.
    """

    def __call__(self, *args):
        if random.random() >= self.p:
            return args
        flip = torchvision.transforms.functional.vflip
        return [flip(img) for img in args]
class RandomRotation(torchvision.transforms.RandomRotation):
    """Rotates every input by one shared random angle.

    Unlike the current torchvision implementation, we also accept a
    probability ``p`` controlling whether the rotation is applied at all.

    Parameters
    ----------

    p : :py:class:`float`, Optional
        probability at which the operation is applied

    **kwargs : dict
        passed to parent. Notice that, if not set, we use the following
        defaults here for the underlying transform from torchvision:

        * ``degrees``: 15
        * ``interpolation``: ``torchvision.transforms.functional.InterpolationMode.BILINEAR``
    """

    def __init__(self, p=0.5, **kwargs):
        kwargs.setdefault("degrees", 15)
        kwargs.setdefault(
            "interpolation",
            torchvision.transforms.functional.InterpolationMode.BILINEAR,
        )
        super().__init__(**kwargs)
        self.p = p

    def __call__(self, *args):
        # one draw decides whether to rotate; one angle is shared by all
        # inputs so data and ground-truth stay aligned
        if random.random() >= self.p:
            return args
        angle = self.get_params(self.degrees)
        rotate = torchvision.transforms.functional.rotate
        return [
            rotate(img, angle, self.interpolation, self.expand, self.center)
            for img in args
        ]

    def __repr__(self):
        # expose the extra ``p`` parameter in the representation
        return super().__repr__().replace("(", f"(p={self.p},", 1)
class ColorJitter(torchvision.transforms.ColorJitter):
    """Randomly applies color jitter to the **first** input image only.

    Notice this transform extension, unlike others in this module, only
    affects the first image passed as input argument; any further inputs
    (typically ground-truth) are returned untouched.  Unlike the current
    torchvision implementation, we also accept a probability for applying
    the jitter.

    Parameters
    ----------

    p : :py:class:`float`, Optional
        probability at which the operation is applied

    **kwargs : dict
        passed to parent. Notice that, if not set, we use the following
        defaults here for the underlying transform from torchvision:

        * ``brightness``: 0.3
        * ``contrast``: 0.3
        * ``saturation``: 0.02
        * ``hue``: 0.02
    """

    def __init__(self, p=0.5, **kwargs):
        kwargs.setdefault("brightness", 0.3)
        kwargs.setdefault("contrast", 0.3)
        kwargs.setdefault("saturation", 0.02)
        kwargs.setdefault("hue", 0.02)
        super().__init__(**kwargs)
        self.p = p

    def __call__(self, *args):
        if random.random() >= self.p:
            return args
        # jitter only the input image, never the ground-truth
        jittered = super().__call__(args[0])
        return [jittered, *args[1:]]

    def __repr__(self):
        # expose the extra ``p`` parameter in the representation
        return super().__repr__().replace("(", f"(p={self.p},", 1)
275def _expand2square(pil_img, background_color):
276 """Function that pad the minimum between the height and the width to fit a
277 square.
279 Parameters
280 ----------
282 pil_img : PIL.Image.Image
283 A PIL image that represents the image for analysis.
285 background_color: py:class:`tuple`, Optional
286 A tuple to represent the color of the background of the image in order
287 to pad with the same color. If the image is an RGB image
288 background_color should be a tuple of size 3 , if it's a grayscale
289 image the variable can be represented with an integer.
291 Returns
292 -------
294 image : PIL.Image.Image
295 A new image with height equal to width.
296 """
298 width, height = pil_img.size
299 if width == height:
300 return pil_img
301 elif width > height:
302 result = PIL.Image.new(pil_img.mode, (width, width), background_color)
303 result.paste(pil_img, (0, (width - height) // 2))
304 return result
305 else:
306 result = PIL.Image.new(pil_img.mode, (height, height), background_color)
307 result.paste(pil_img, ((height - width) // 2, 0))
308 return result
class ShrinkIntoSquare:
    """Crops away black borders, then pads into a minimal square.

    The reference image (grayscale-converted) determines how many fully
    "black" rows and columns exist on each side; those borders are cropped
    from every input, and each result is padded back into a square of
    minimal size with a black background.

    Parameters
    ----------

    reference : :py:class:`int`, Optional
        Which reference part of the sample to use for cropping black borders.
        If not set, use the first object on the sample (typically, the image).

    threshold : :py:class:`int`, Optional
        Threshold to use for when considering what is a "black" border
    """

    def __init__(self, reference=0, threshold=0):
        self.reference = reference
        self.threshold = threshold

    def __call__(self, *args):
        gray = numpy.asarray(args[self.reference].convert("L"))
        # a column/row is "content" when its pixel sum exceeds the threshold
        col_has_content = numpy.sum(gray, axis=0) > self.threshold
        row_has_content = numpy.sum(gray, axis=1) > self.threshold

        # first content index from each side gives the border widths:
        # (left, top, right, bottom)
        border = (
            col_has_content.argmax(),
            row_has_content.argmax(),
            col_has_content[::-1].argmax(),
            row_has_content[::-1].argmax(),
        )

        cropped = [PIL.ImageOps.crop(img, border) for img in args]

        def _background(img):
            if img.mode == "RGB":
                return (0, 0, 0)
            return 0

        return [_expand2square(img, _background(img)) for img in cropped]
class GaussianBlur(torchvision.transforms.GaussianBlur):
    """Randomly applies a gaussian blur transformation on the **first** image.

    Notice this transform extension, unlike others in this module, only affects
    the first image passed as input argument. Unlike the current torchvision
    implementation, we also accept a probability for applying the blur.

    Parameters
    ----------

    p : :py:class:`float`, Optional
        probability at which the operation is applied

    **kwargs : dict
        passed to parent. Notice that, if not set, we use the following
        defaults here for the underlying transform from torchvision:

        * ``kernel_size``: (5, 5)
        * ``sigma``: (0.1, 5)
    """

    def __init__(self, p=0.5, **kwargs):
        kwargs.setdefault("kernel_size", (5, 5))
        kwargs.setdefault("sigma", (0.1, 5))

        super().__init__(**kwargs)
        self.p = p

    def __call__(self, *args):
        if random.random() < self.p:
            # applies gaussian blur only to the input image not ground-truth
            return [super().__call__(args[0]), *args[1:]]
        else:
            return args

    def __repr__(self):
        # consistency fix: ColorJitter and RandomRotation both expose the
        # extra ``p`` parameter in their representations; this class did not
        retval = super().__repr__()
        return retval.replace("(", f"(p={self.p},", 1)
class GroundTruthCrop:
    """Crop image in a square keeping only the area with the ground truth.

    This transform can crop all images given a ground-truth mask as reference.
    Notice that the crop will result in a square image at the end, which means
    that it will keep the bigger dimension and adjust the smaller one to fit
    into a square. There's an option to add extra area around the gt bounding
    box. If resulting dimensions are larger than the boundaries of the image,
    minimal padding will be done to keep the image in a square shape.

    Parameters
    ----------

    reference : :py:class:`int`, Optional
        Which reference part of the sample to use for getting coordinates.
        If not set, use the second object on the sample (typically, the mask).

    extra_area : :py:class:`float`, Optional
        Multiplier that will add the extra area around the ground-truth
        bounding box. Example: 0.1 will result in a crop with dimensions of
        the largest side increased by 10%. If not set, the default will be 0
        (only the ground-truth box).
    """

    def __init__(self, reference=1, extra_area=0.0):
        # index of the sample element used as the ground-truth reference
        self.reference = reference
        # fraction of the bounding box added as margin around it
        self.extra_area = extra_area

    def __call__(self, *args):
        ref = args[self.reference]

        max_w, max_h = ref.size

        # bounding box of all non-zero pixels in the reference
        # (axis 0 indexes rows/y, axis 1 indexes columns/x)
        where = numpy.where(ref)
        y0 = numpy.min(where[0])
        y1 = numpy.max(where[0])
        x0 = numpy.min(where[1])
        x1 = numpy.max(where[1])

        w = x1 - x0
        h = y1 - y0

        # half of the requested extra margin, split between the two sides
        extra_x = self.extra_area * w / 2
        extra_y = self.extra_area * h / 2

        # dimensions of the margin-expanded bounding box
        new_w = (1 + self.extra_area) * w
        new_h = (1 + self.extra_area) * h

        # per-side growth needed on the smaller dimension to make a square
        diff = abs(new_w - new_h) / 2

        if new_w == new_h:
            x0_new = x0.copy() - extra_x
            x1_new = x1.copy() + extra_x
            y0_new = y0.copy() - extra_y
            y1_new = y1.copy() + extra_y

        elif new_w > new_h:
            # width dominates: grow height by ``diff`` on both sides
            x0_new = x0.copy() - extra_x
            x1_new = x1.copy() + extra_x
            y0_new = y0.copy() - extra_y - diff
            y1_new = y1.copy() + extra_y + diff

        else:
            # height dominates: grow width by ``diff`` on both sides
            x0_new = x0.copy() - extra_x - diff
            x1_new = x1.copy() + extra_x + diff
            y0_new = y0.copy() - extra_y
            y1_new = y1.copy() + extra_y

        # per-side border to remove, relative to the original image size
        border = (x0_new, y0_new, max_w - x1_new, max_h - y1_new)

        def _expand_img(
            pil_img, background_color, x0_pad=0, x1_pad=0, y0_pad=0, y1_pad=0
        ):
            # pads ``pil_img`` by the given per-side amounts, pasting the
            # original at the offset so the padding surrounds it
            width = pil_img.size[0] + x0_pad + x1_pad
            height = pil_img.size[1] + y0_pad + y1_pad

            result = PIL.Image.new(
                pil_img.mode, (width, height), background_color
            )
            result.paste(pil_img, (x0_pad, y0_pad))
            return result

        def _black_background(i):
            # black in both RGB and single-channel modes
            return (0, 0, 0) if i.mode == "RGB" else 0

        # padding needed on each side when the square crop extends past the
        # image boundaries (zero when the square fits entirely inside)
        d_x0 = numpy.rint(max([0 - x0_new, 0])).astype(int)
        d_y0 = numpy.rint(max([0 - y0_new, 0])).astype(int)
        d_x1 = numpy.rint(max([x1_new - max_w, 0])).astype(int)
        d_y1 = numpy.rint(max([y1_new - max_h, 0])).astype(int)

        new_args = [
            _expand_img(
                k,
                _black_background(k),
                x0_pad=d_x0,
                x1_pad=d_x1,
                y0_pad=d_y0,
                y1_pad=d_y1,
            )
            for k in args
        ]

        # NOTE(review): ``border`` components may be fractional or negative
        # when the expanded square passes the image boundary -- the padding
        # step above appears to compensate for the out-of-bounds parts, but
        # this should be verified with a boundary test case.
        new_args = [PIL.ImageOps.crop(k, border) for k in new_args]

        return new_args