1#!/usr/bin/env python
2# coding=utf-8
3
4
5"""Common utilities"""
6
7import PIL.Image
8import PIL.ImageChops
9import PIL.ImageDraw
10import PIL.ImageOps
11import torch
12import torch.utils.data
13
14from .transforms import Compose, GetBoundingBox, ToTensor
15
16
def invert_mode1_image(img):
    """Returns the negative of a binary PIL image (mode ``"1"``)

    The image is round-tripped through RGB so that
    :py:func:`PIL.ImageOps.invert` can process it, and the result is then
    converted back to mode ``"1"``.


    Parameters
    ----------

    img : PIL.Image.Image
        The binary image to invert


    Returns
    -------

    image : PIL.Image.Image
        A new mode ``"1"`` image with black and white swapped

    """

    as_rgb = img.convert("RGB")
    flipped = PIL.ImageOps.invert(as_rgb)
    # NOTE(review): in Pillow, ``dither=None`` selects the default
    # (Floyd-Steinberg) dithering rather than disabling it.  The pixels here
    # are only pure black or white, so dithering should have no effect.
    return flipped.convert(mode="1", dither=None)
23
24
def subtract_mode1_images(img1, img2):
    """Computes the pixel-wise difference ``img1 - img2``

    Delegates to :py:func:`PIL.ImageChops.subtract`, which clips results
    instead of wrapping them around.


    Parameters
    ----------

    img1 : PIL.Image.Image
        The minuend image

    img2 : PIL.Image.Image
        The subtrahend image (same mode and size as ``img1``)


    Returns
    -------

    image : PIL.Image.Image
        A new image holding the difference

    """

    difference = PIL.ImageChops.subtract(img1, img2)
    return difference
29
30
def overlayed_image(
    img,
    label,
    mask=None,
    label_color=(0, 255, 0),
    mask_color=(0, 0, 255),
    alpha=0.4,
):
    """Creates an image showing existing labels and mask

    This function creates a new representation of the input image ``img``.
    It overlays a ``label_color`` tint (green by default) on labelled
    objects, and a ``mask_color`` tint (blue by default) on parts of the
    image that should be ignored (negative mask).  By looking at this
    representation, it shall be possible to verify if the dataset/loader is
    yielding images correctly.


    Parameters
    ----------

    img : PIL.Image.Image
        An RGB PIL image that represents the original image for analysis

    label : PIL.Image.Image
        A PIL image that represents the labelled elements in the image.
        For images in mode "1" or "L", white pixels represent the labelled
        object and darker pixels represent background.  These images are
        tinted with ``label_color``.  Images in any other mode are used as
        given, after conversion to RGB if they are not RGB already.

    mask : py:class:`PIL.Image.Image`, Optional
        A PIL image in mode "1" that represents the mask for the image.  White
        pixels indicate where content should be used, black pixels, content to
        be ignored.

    label_color : py:class:`tuple`, Optional
        A tuple with three integer entries indicating the RGB color to be used
        for labels.  Only used if ``label.mode`` is "1" or "L".

    mask_color : py:class:`tuple`, Optional
        A tuple with three integer entries indicating the RGB color to be used
        for the mask-negative (black parts in the original mask).

    alpha : py:class:`float`, Optional
        Weight of the overlay (label or mask tint) when blended with the
        original image.  ``0.0`` keeps only the original pixels, ``1.0`` only
        the overlay (see :py:func:`PIL.Image.blend`).


    Returns
    -------

    image : PIL.Image.Image
        A new image overlaying the original image, object labels (in
        ``label_color``) and what is to be considered parts to be
        **masked-out** (i.e. a representation of a negative of the mask, in
        ``mask_color``).

    """

    # creates a representation of labels, in RGB format, with the right color
    if label.mode in ("1", "L"):
        # black stays black, white becomes ``label_color``
        label_colored = PIL.ImageOps.colorize(
            label.convert("L"), (0, 0, 0), label_color
        )
    elif label.mode == "RGB":
        # user has already passed an RGB version of the labels, just compose
        label_colored = label
    else:
        # e.g. "P" or "RGBA" labels: PIL.Image.blend() requires both inputs
        # to share the same mode, so bring the labels to RGB first
        label_colored = label.convert("RGB")

    # blend image and label together - first blend to get labels drawn with a
    # slight "label_color" tone on top, then composite with the original image
    # (weighted by label intensity), to avoid losing brightness elsewhere.
    retval = PIL.Image.blend(img, label_colored, alpha)
    if label.mode == "1":
        composite_mask = invert_mode1_image(label)
    else:
        composite_mask = PIL.ImageOps.invert(label.convert("L"))
    # where composite_mask is white (background) keep the original ``img``;
    # where it is black (labelled) keep the blended version
    retval = PIL.Image.composite(img, retval, composite_mask)

    # creates a representation of the mask negative with the right color
    if mask is not None:
        # black (ignored) areas become ``mask_color``, white areas black
        antimask_colored = PIL.ImageOps.colorize(
            mask.convert("L"), mask_color, (0, 0, 0)
        )
        tmp = PIL.Image.blend(retval, antimask_colored, alpha)
        # keep ``retval`` where the mask is white, the tinted ``tmp`` elsewhere
        retval = PIL.Image.composite(retval, tmp, mask)

    return retval
115
116
def overlayed_bbox_image(
    img,
    box,
    box_color=(0, 255, 0),
    width=1,
):
    """Creates an image showing existing bounding boxes

    This function creates a new representation of the input image ``img``,
    overlaying a ``box_color`` (green by default) rectangle around the
    labelled object.  The input image is not modified.  By looking at this
    representation, it shall be possible to verify if the dataset/loader is
    yielding images correctly.


    Parameters
    ----------

    img : PIL.Image.Image
        An RGB PIL image that represents the original image for analysis

    box : list
        A sequence of four bounding box coordinates, ``[x1, y1, x2, y2]``
        (top-left and bottom-right corners).

    box_color : py:class:`tuple`, Optional
        A tuple with three integer entries indicating the RGB color to be used
        for bounding box.

    width : py:class:`int`, Optional
        An integer indicating the size of the rectangle line, in pixels.


    Returns
    -------

    image : PIL.Image.Image
        A copy of the original image with the bounding box drawn on it.

    """

    # PIL.ImageDraw draws in place: work on a copy so the caller's image is
    # left untouched, as documented (and as ``overlayed_image`` behaves)
    retval = img.copy()
    x1, y1, x2, y2 = box
    draw = PIL.ImageDraw.Draw(retval)
    draw.rectangle([(x1, y1), (x2, y2)], outline=box_color, width=width)

    return retval
163
164
class SampleListDataset(torch.utils.data.Dataset):
    """PyTorch dataset wrapper around Sample lists

    A transform object can be passed that will be applied to the image, ground
    truth and mask (if present).

    It supports indexing such that dataset[i] can be used to get the i-th
    sample.


    Parameters
    ----------

    samples : list
        A list of :py:class:`bob.ip.common.data.sample.Sample` objects

    transforms : :py:class:`list`, Optional
        a list of transformations to be applied to **both** image and
        ground-truth data.  Notice a last transform
        (:py:class:`bob.ip.common.data.transforms.ToTensor`) is always applied
        - you do not need to add that.  ``None`` (default) means no extra
        transforms.

    """

    def __init__(self, samples, transforms=None):

        self._samples = samples
        self.transforms = transforms

    @property
    def transforms(self):
        # the last entry is the ToTensor() appended by the setter; hide it
        return self._transforms.transforms[:-1]

    @transforms.setter
    def transforms(self, value):
        # accept any iterable (list, tuple) or None; never mutate the input
        value = [] if value is None else list(value)
        self._transforms = Compose(value + [ToTensor()])

    def copy(self, transforms=None):
        """Returns a shallow copy of itself, optionally resetting transforms

        The copy shares the same underlying sample list.

        Parameters
        ----------

        transforms : :py:class:`list`, Optional
            An optional list of transforms to set in the copy.  If not
            specified (``None``), use ``self.transforms``.  An empty list
            yields a copy without extra transforms.
        """

        # ``is None`` (not truthiness) so that an explicit ``[]`` is honoured
        if transforms is None:
            transforms = self.transforms
        return SampleListDataset(self._samples, transforms)

    def keys(self):
        """Generator producing all keys for all samples"""
        for k in self._samples:
            yield k.key

    def all_keys_match(self, other):
        """Compares all keys to ``other``, return ``True`` if all match"""
        return len(self) == len(other) and all(
            ks == ko for ks, ko in zip(self.keys(), other.keys())
        )

    def __len__(self):
        """

        Returns
        -------

        size : int
            size of the dataset

        """
        return len(self._samples)

    def __getitem__(self, key):
        """

        Parameters
        ----------

        key : int, slice


        Returns
        -------

        sample : list
            The sample data: ``[key, image[, gt[, mask]]]`` (after
            transforms).  For a slice key, a list of such samples.

        """

        if isinstance(key, slice):
            return [self[k] for k in range(*key.indices(len(self)))]

        # we try it as an int
        item = self._samples[key]
        data = item.data  # triggers data loading

        retval = [data["data"]]
        # optional entries, always in this order
        for name in ("label", "mask"):
            if name in data:
                retval.append(data[name])

        if self._transforms:
            retval = self._transforms(*retval)

        return [item.key] + retval
270
271
class SampleListDetectionDataset(torch.utils.data.Dataset):
    """PyTorch dataset wrapper around Sample lists, for detection tasks

    A transform object can be passed that will be applied to the image, ground
    truth and mask (if present).  A
    :py:class:`bob.ip.common.data.transforms.GetBoundingBox` transform is then
    applied to every sample.

    It supports indexing such that dataset[i] can be used to get the i-th
    sample.


    Parameters
    ----------

    samples : list
        A list of :py:class:`bob.ip.common.data.sample.Sample` objects

    transforms : :py:class:`list`, Optional
        a list of transformations to be applied to **both** image and
        ground-truth data.  Notice a last transform
        (:py:class:`bob.ip.common.data.transforms.ToTensor`) is always applied
        - you do not need to add that.  ``None`` (default) means no extra
        transforms.

    """

    def __init__(self, samples, transforms=None):

        self._samples = samples
        self.transforms = transforms

    @property
    def transforms(self):
        # the last entry is the ToTensor() appended by the setter; hide it
        return self._transforms.transforms[:-1]

    @transforms.setter
    def transforms(self, value):
        # accept any iterable (list, tuple) or None; never mutate the input
        value = [] if value is None else list(value)
        self._transforms = Compose(value + [ToTensor()])

    def copy(self, transforms=None):
        """Returns a shallow copy of itself, optionally resetting transforms

        The copy shares the same underlying sample list.

        Parameters
        ----------

        transforms : :py:class:`list`, Optional
            An optional list of transforms to set in the copy.  If not
            specified (``None``), use ``self.transforms``.  An empty list
            yields a copy without extra transforms.
        """

        # ``is None`` (not truthiness) so that an explicit ``[]`` is honoured
        if transforms is None:
            transforms = self.transforms
        return SampleListDetectionDataset(self._samples, transforms)

    def keys(self):
        """Generator producing all keys for all samples"""
        for k in self._samples:
            yield k.key

    def all_keys_match(self, other):
        """Compares all keys to ``other``, return ``True`` if all match"""
        return len(self) == len(other) and all(
            ks == ko for ks, ko in zip(self.keys(), other.keys())
        )

    def __len__(self):
        """

        Returns
        -------

        size : int
            size of the dataset

        """
        return len(self._samples)

    def __getitem__(self, key):
        """

        Parameters
        ----------

        key : int, slice


        Returns
        -------

        sample : list
            The sample data, ``[key, ...]``, after the configured transforms
            and :py:class:`GetBoundingBox` have been applied to
            ``[image[, gt[, mask]]]``.  For a slice key, a list of such
            samples.  NOTE(review): confirm against ``GetBoundingBox`` how it
            reshapes the ground-truth entry.

        """

        if isinstance(key, slice):
            return [self[k] for k in range(*key.indices(len(self)))]

        # we try it as an int
        item = self._samples[key]
        data = item.data  # triggers data loading

        retval = [data["data"]]
        # optional entries, always in this order
        for name in ("label", "mask"):
            if name in data:
                retval.append(data[name])

        if self._transforms:
            retval = self._transforms(*retval)

        # derive bounding boxes from the (transformed) ground-truth
        bounding_box_transform = GetBoundingBox()
        retval = bounding_box_transform(retval)

        return [item.key] + retval
382
383
class SSLDataset(torch.utils.data.Dataset):
    """PyTorch dataset wrapper around labelled and unlabelled sample lists

    Yields elements of the form:

    .. code-block:: text

       [key, image, ground-truth, [mask,] unlabelled-key, unlabelled-image]

    The size of the dataset is the same as the labelled dataset.

    Indexing works by selecting the right element on the labelled dataset, and
    randomly picking another one from the unlabelled dataset.

    Parameters
    ----------

    labelled : :py:class:`torch.utils.data.Dataset`
        Labelled dataset (**must** have "mask" and "label" entries for every
        sample).  Indexing it must return a list starting with the sample key.

    unlabelled : :py:class:`torch.utils.data.Dataset`
        Unlabelled dataset (**may** have "mask" and "label" entries for every
        sample, but are ignored).  Indexing it must return a list of the form
        ``[key, image, ...]``.

    """

    def __init__(self, labelled, unlabelled):
        self._labelled = labelled
        self._unlabelled = unlabelled

    @staticmethod
    def _dataset_keys(dataset):
        """Yields the keys of ``dataset``, in order"""
        if hasattr(dataset, "keys"):
            # e.g. SampleListDataset: keys available without loading data
            yield from dataset.keys()
        else:
            # fall back to loading each sample; its first element is the key
            for i in range(len(dataset)):
                yield dataset[i][0]

    def keys(self):
        """Generator producing all keys for all samples

        Keys of the labelled dataset come first, followed by those of the
        unlabelled dataset.
        """
        # NOTE: ``self._labelled + self._unlabelled`` would build a
        # ConcatDataset whose items are sample *lists*, which have no ``.key``
        for dataset in (self._labelled, self._unlabelled):
            yield from self._dataset_keys(dataset)

    def all_keys_match(self, other):
        """Compares all keys to ``other``, return ``True`` if all match"""
        return len(self) == len(other) and all(
            ks == ko for ks, ko in zip(self.keys(), other.keys())
        )

    def __len__(self):
        """

        Returns
        -------

        size : int
            size of the dataset (number of labelled samples)

        """

        return len(self._labelled)

    def __getitem__(self, index):
        """

        Parameters
        ----------
        index : int
            The index for the element to pick


        Returns
        -------

        sample : list
            The sample data: ``[key, image, gt, [mask, ]unlab-key, unlab-image]``

        """

        retval = self._labelled[index]
        # picks one unlabelled sample at random to follow the labelled one;
        # use a plain int so any indexable dataset/sequence accepts it
        unlab_index = torch.randint(len(self._unlabelled), ()).item()
        unlab = self._unlabelled[unlab_index]
        # only interested in key and data
        return retval + unlab[:2]