Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4"""Image transformations for our pipelines
6Differences between methods here and those from
7:py:mod:`torchvision.transforms` is that these support multiple simultaneous
8image inputs, which are required to feed segmentation networks (e.g. image and
9labels or masks). We also take care of data augmentations, in which random
10flipping and rotation needs to be applied across all input images, but color
11jittering, for example, only on the input image.
12"""
14import random
16import numpy
17import PIL.Image
18import PIL.ImageOps
19import torchvision.transforms
20import torchvision.transforms.functional
class TupleMixin:
    """Adds support to work with tuples of objects to torchvision transforms

    When mixed into a torchvision transform, calling the transform with
    several images applies the parent transform to each one and returns the
    results as a list.
    """

    def __call__(self, *args):
        outputs = []
        for sample in args:
            outputs.append(super().__call__(sample))
        return outputs
class CenterCrop(TupleMixin, torchvision.transforms.CenterCrop):
    """Center-crops all input images simultaneously (tuple-aware version)."""
class Pad(TupleMixin, torchvision.transforms.Pad):
    """Pads all input images simultaneously (tuple-aware version)."""
class Resize(TupleMixin, torchvision.transforms.Resize):
    """Resizes all input images simultaneously (tuple-aware version)."""
class ToTensor(TupleMixin, torchvision.transforms.ToTensor):
    """Converts all input images to tensors simultaneously (tuple-aware)."""
class Compose(torchvision.transforms.Compose):
    """Composes transforms that operate on tuples of images.

    Each configured transform receives all images at once and must return
    the (possibly modified) sequence, which is then fed to the next one.
    """

    def __call__(self, *args):
        for transform in self.transforms:
            args = transform(*args)
        return args
class SingleCrop:
    """
    Crops one image at the given coordinates.

    Attributes
    ----------
    i : int
        upper pixel coordinate.
    j : int
        left pixel coordinate.
    h : int
        height of the cropped image.
    w : int
        width of the cropped image.
    """

    def __init__(self, i, j, h, w):
        self.i, self.j = i, j
        self.h, self.w = h, w

    def __call__(self, img):
        # PIL box format: (left, upper, right, lower)
        left, upper = self.j, self.i
        box = (left, upper, left + self.w, upper + self.h)
        return img.crop(box)
class Crop(TupleMixin, SingleCrop):
    """
    Crops multiple images at the given coordinates.

    Attributes
    ----------
    i : int
        upper pixel coordinate.
    j : int
        left pixel coordinate.
    h : int
        height of the cropped image.
    w : int
        width of the cropped image.
    """
class SingleAutoLevel16to8:
    """Converts a 16-bit image to 8-bit representation using "auto-level"

    This transform assumes that the input image is gray-scaled.

    To auto-level, we calculate the maximum and the minimum of the image, and
    consider such a range should be mapped to the [0,255] range of the
    destination image.

    """

    def __call__(self, img):
        imin, imax = img.getextrema()
        # Guard against constant images: a zero range would divide by zero
        # and produce NaN/inf pixels; with range 1 such images map to all
        # zeros instead, which is backward-compatible for every other input.
        irange = (imax - imin) or 1
        return PIL.Image.fromarray(
            numpy.round(
                255.0 * (numpy.array(img).astype(float) - imin) / irange
            ).astype("uint8"),
        ).convert("L")
class AutoLevel16to8(TupleMixin, SingleAutoLevel16to8):
    """Converts multiple 16-bit images to 8-bit representations using "auto-level"

    This transform assumes that the input images are gray-scaled.

    To auto-level, we calculate the maximum and the minimum of the image, and
    consider such a range should be mapped to the [0,255] range of the
    destination image.
    """
class SingleToRGB:
    """Converts from any input format to RGB, using an ADAPTIVE conversion.

    This transform takes the input image and converts it to RGB using
    py:method:`PIL.Image.Image.convert`, with `mode='RGB'` and using all other
    defaults.  This may be aggressive if applied to 16-bit images without
    further considerations.
    """

    def __call__(self, img):
        rgb_version = img.convert(mode="RGB")
        return rgb_version
class ToRGB(TupleMixin, SingleToRGB):
    """Converts from any input format to RGB, using an ADAPTIVE conversion.

    This transform takes the input image and converts it to RGB using
    py:method:`PIL.Image.Image.convert`, with `mode='RGB'` and using all other
    defaults.  This may be aggressive if applied to 16-bit images without
    further considerations.
    """
class RandomHorizontalFlip(torchvision.transforms.RandomHorizontalFlip):
    """Randomly flips all input images horizontally

    With probability ``p``, every input image receives the same horizontal
    flip; otherwise all images pass through untouched.
    """

    def __call__(self, *args):
        if random.random() >= self.p:
            return args
        hflip = torchvision.transforms.functional.hflip
        return [hflip(img) for img in args]
class RandomVerticalFlip(torchvision.transforms.RandomVerticalFlip):
    """Randomly flips all input images vertically

    With probability ``p``, every input image receives the same vertical
    flip; otherwise all images pass through untouched.
    """

    def __call__(self, *args):
        if random.random() >= self.p:
            return args
        vflip = torchvision.transforms.functional.vflip
        return [vflip(img) for img in args]
class RandomRotation(torchvision.transforms.RandomRotation):
    """Randomly rotates all input images by the same amount

    Unlike the current torchvision implementation, we also accept a probability
    for applying the rotation.

    Parameters
    ----------

    p : :py:class:`float`, Optional
        probability at which the operation is applied

    **kwargs : dict
        passed to parent.  Notice that, if not set, we use the following
        defaults here for the underlying transform from torchvision:

        * ``degrees``: 15
        * ``interpolation``: ``torchvision.transforms.functional.InterpolationMode.BILINEAR``

    """

    def __init__(self, p=0.5, **kwargs):
        kwargs.setdefault("degrees", 15)
        kwargs.setdefault(
            "interpolation",
            torchvision.transforms.functional.InterpolationMode.BILINEAR,
        )
        super().__init__(**kwargs)
        self.p = p

    def __call__(self, *args):
        if random.random() >= self.p:
            return args
        # draw the angle once so data and ground-truth(s) all receive
        # **the same** rotation
        angle = self.get_params(self.degrees)
        rotate = torchvision.transforms.functional.rotate
        return [
            rotate(img, angle, self.interpolation, self.expand, self.center)
            for img in args
        ]

    def __repr__(self):
        base = super().__repr__()
        return base.replace("(", f"(p={self.p},", 1)
class ColorJitter(torchvision.transforms.ColorJitter):
    """Randomly applies a color jitter transformation on the **first** image

    Notice this transform extension, unlike others in this module, only affects
    the first image passed as input argument.  Unlike the current torchvision
    implementation, we also accept a probability for applying the jitter.

    Parameters
    ----------

    p : :py:class:`float`, Optional
        probability at which the operation is applied

    **kwargs : dict
        passed to parent.  Notice that, if not set, we use the following
        defaults here for the underlying transform from torchvision:

        * ``brightness``: 0.3
        * ``contrast``: 0.3
        * ``saturation``: 0.02
        * ``hue``: 0.02

    """

    def __init__(self, p=0.5, **kwargs):
        defaults = (
            ("brightness", 0.3),
            ("contrast", 0.3),
            ("saturation", 0.02),
            ("hue", 0.02),
        )
        for key, value in defaults:
            kwargs.setdefault(key, value)
        super().__init__(**kwargs)
        self.p = p

    def __call__(self, *args):
        if random.random() >= self.p:
            return args
        # jitters only the data image; ground-truth(s) pass through untouched
        return [super().__call__(args[0]), *args[1:]]

    def __repr__(self):
        base = super().__repr__()
        return base.replace("(", f"(p={self.p},", 1)
def _expand2square(pil_img, background_color):
    """
    Function that pad the minimum between the height and the width to fit a square

    Parameters
    ----------

    pil_img : PIL.Image.Image
        A PIL image that represents the image for analysis.

    background_color: py:class:`tuple`, Optional
        A tuple to represent the color of the background of the image in order
        to pad with the same color.  If the image is an RGB image
        background_color should be a tuple of size 3, if it's a grayscale
        image the variable can be represented with an integer.

    Returns
    -------

    image : PIL.Image.Image
        A new image with height equal to width.

    """
    width, height = pil_img.size
    if width == height:
        # already square: nothing to pad
        return pil_img
    side = max(width, height)
    canvas = PIL.Image.new(pil_img.mode, (side, side), background_color)
    if width > height:
        # center vertically
        offset = (0, (side - height) // 2)
    else:
        # center horizontally
        offset = ((side - width) // 2, 0)
    canvas.paste(pil_img, offset)
    return canvas
class ResizeCrop:
    """
    Crops all the images by removing the black pixels in the width and height
    until it finds a non-black pixel, then pads each result back to a square.
    """

    def __call__(self, *args):
        img, label, mask = args[0], args[1], args[2]
        mask_data = numpy.asarray(mask)
        col_sums = mask_data.sum(axis=0)
        row_sums = mask_data.sum(axis=1)

        # count all-zero columns/rows on each side of the mask; argmax on a
        # boolean array returns the index of the first True value
        crop_left = (col_sums != 0).argmax(axis=0)
        crop_right = (col_sums[::-1] != 0).argmax(axis=0)
        crop_up = (row_sums != 0).argmax(axis=0)
        crop_down = (row_sums[::-1] != 0).argmax(axis=0)

        # PIL.ImageOps.crop border format: (left, top, right, bottom)
        border = (crop_left, crop_up, crop_right, crop_down)
        cropped_img = PIL.ImageOps.crop(img, border)
        cropped_label = PIL.ImageOps.crop(label, border)
        cropped_mask = PIL.ImageOps.crop(mask, border)

        return (
            _expand2square(cropped_img, (0, 0, 0)),
            _expand2square(cropped_label, 0),
            _expand2square(cropped_mask, 0),
        )