Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# coding=utf-8
5"""Common utilities"""
7import PIL.Image
8import PIL.ImageChops
9import PIL.ImageOps
10import torch
11import torch.utils.data
13from .transforms import Compose, ToTensor
def invert_mode1_image(img):
    """Inverts a binary PIL image (mode == ``"1"``)"""
    # mode "1" images cannot be inverted directly by PIL.ImageOps.invert -
    # round-trip through RGB, then convert back to binary without dithering
    as_rgb = img.convert("RGB")
    inverted = PIL.ImageOps.invert(as_rgb)
    return inverted.convert(mode="1", dither=None)
def subtract_mode1_images(img1, img2):
    """Returns a new image that represents ``img1 - img2``"""
    # PIL.ImageChops.subtract performs pixel-wise subtraction (negative
    # results are clipped at zero) and returns a new image
    difference = PIL.ImageChops.subtract(img1, img2)
    return difference
def overlayed_image(
    img,
    label,
    mask=None,
    label_color=(0, 255, 0),
    mask_color=(0, 0, 255),
    alpha=0.4,
):
    """Creates an image showing existing labels and mask

    This function creates a new representation of the input image ``img``
    overlaying a green mask for labelled objects, and a red mask for parts of
    the image that should be ignored (negative mask). By looking at this
    representation, it shall be possible to verify if the dataset/loader is
    yielding images correctly.

    Parameters
    ----------

    img : PIL.Image.Image
        An RGB PIL image that represents the original image for analysis

    label : PIL.Image.Image
        A PIL image in any mode that represents the labelled elements in the
        image.  In case of images in mode "L" or "1", white pixels represent
        the labelled object.  Black-er pixels represent background.

    mask : py:class:`PIL.Image.Image`, Optional
        A PIL image in mode "1" that represents the mask for the image.  White
        pixels indicate where content should be used, black pixels, content to
        be ignored.

    label_color : py:class:`tuple`, Optional
        A tuple with three integer entries indicating the RGB color to be used
        for labels.  Only used if ``label.mode`` is "1" or "L".

    mask_color : py:class:`tuple`, Optional
        A tuple with three integer entries indicating the RGB color to be used
        for the mask-negative (black parts in the original mask).

    alpha : py:class:`float`, Optional
        A float that indicates how much of blending should be performed between
        the label, mask and the original image.

    Returns
    -------

    image : PIL.Image.Image
        A new image overlaying the original image, object labels (in green) and
        what is to be considered parts to be **masked-out** (i.e. a
        representation of a negative of the mask).

    """

    # creates a representation of labels, in RGB format, with the right color
    if label.mode in ("1", "L"):
        label_colored = PIL.ImageOps.colorize(
            label.convert("L"), (0, 0, 0), label_color
        )
    else:
        # user has already passed an RGB version of the labels, just compose
        label_colored = label

    # blend image and label together - first blend to get vessels drawn with a
    # slight "label_color" tone on top, then composite with original image, to
    # avoid losing brightness.
    retval = PIL.Image.blend(img, label_colored, alpha)
    # the composite step restores full brightness outside the labelled area:
    # where the composite mask is white, the original image is used verbatim
    if label.mode == "1":
        composite_mask = invert_mode1_image(label)
    else:
        composite_mask = PIL.ImageOps.invert(label.convert("L"))
    retval = PIL.Image.composite(img, retval, composite_mask)

    # creates a representation of the mask negative with the right color
    if mask is not None:
        antimask_colored = PIL.ImageOps.colorize(
            mask.convert("L"), mask_color, (0, 0, 0)
        )
        # only the black (to-be-ignored) region of the mask receives the
        # "mask_color" tint; white (usable) regions keep the current overlay
        tmp = PIL.Image.blend(retval, antimask_colored, alpha)
        retval = PIL.Image.composite(retval, tmp, mask)

    return retval
class SampleListDataset(torch.utils.data.Dataset):
    """PyTorch dataset wrapper around Sample lists

    A transform object can be passed that will be applied to the image, ground
    truth and mask (if present).

    It supports indexing such that dataset[i] can be used to get the i-th
    sample.

    Parameters
    ----------

    samples : list
        A list of :py:class:`bob.ip.binseg.data.sample.Sample` objects

    transforms : :py:class:`list`, Optional
        a list of transformations to be applied to **both** image and
        ground-truth data.  Notice a last transform
        (:py:class:`bob.ip.binseg.data.transforms.ToTensor`) is always applied
        - you do not need to add that.

    """

    def __init__(self, samples, transforms=None):
        # fix: the previous signature used a mutable default (``transforms=[]``)
        # which is shared across every call - use ``None`` as sentinel instead
        self._samples = samples
        self.transforms = [] if transforms is None else transforms

    @property
    def transforms(self):
        # exposes only the user-provided transforms; the trailing ToTensor
        # appended by the setter is an implementation detail
        return self._transforms.transforms[:-1]

    @transforms.setter
    def transforms(self, value):
        # a ToTensor step is always appended so samples come out as tensors
        self._transforms = Compose(value + [ToTensor()])

    def copy(self, transforms=None):
        """Returns a copy of itself, optionally resetting transforms

        Note the underlying sample list is **shared** with the copy (only the
        wrapper and its transforms are new).

        Parameters
        ----------

        transforms : :py:class:`list`, Optional
            An optional list of transforms to set in the copy.  If not
            specified, use ``self.transforms``.

        """
        # NOTE(review): ``transforms or self.transforms`` makes an explicitly
        # passed empty list fall back to ``self.transforms`` as well - kept
        # for backward compatibility
        return SampleListDataset(self._samples, transforms or self.transforms)

    def keys(self):
        """Generator producing all keys for all samples"""
        for sample in self._samples:
            yield sample.key

    def all_keys_match(self, other):
        """Compares all keys to ``other``, return ``True`` if all match"""
        # generator expression (instead of a materialized list) lets ``all``
        # short-circuit on the first mismatch
        return len(self) == len(other) and all(
            ks == ko for ks, ko in zip(self.keys(), other.keys())
        )

    def __len__(self):
        """
        Returns
        -------

        size : int
            size of the dataset

        """
        return len(self._samples)

    def __getitem__(self, key):
        """
        Parameters
        ----------

        key : int, slice

        Returns
        -------

        sample : list
            The sample data: ``[key, image[, gt[, mask]]]``

        """
        if isinstance(key, slice):
            return [self[k] for k in range(*key.indices(len(self)))]
        else:  # we try it as an int
            item = self._samples[key]
            data = item.data  # triggers data loading

            retval = [data["data"]]
            if "label" in data:
                retval.append(data["label"])
            if "mask" in data:
                retval.append(data["mask"])

            if self._transforms:
                retval = self._transforms(*retval)

            return [item.key] + retval
class SSLDataset(torch.utils.data.Dataset):
    """PyTorch dataset pairing labelled samples with random unlabelled ones

    Each produced element has the form:

    .. code-block:: text

       [key, image, ground-truth, [mask,] unlabelled-key, unlabelled-image]

    The dataset length equals that of the labelled dataset.  Indexing picks
    the labelled sample at the given position and appends the key and image
    of a randomly drawn unlabelled sample.

    Parameters
    ----------

    labelled : :py:class:`torch.utils.data.Dataset`
        Labelled dataset (**must** have "mask" and "label" entries for every
        sample)

    unlabelled : :py:class:`torch.utils.data.Dataset`
        Unlabelled dataset (**may** have "mask" and "label" entries for every
        sample, but are ignored)

    """

    def __init__(self, labelled, unlabelled):
        self._labelled = labelled
        self._unlabelled = unlabelled

    def keys(self):
        """Generator producing all keys for all samples"""
        yield from (sample.key for sample in self._labelled + self._unlabelled)

    def all_keys_match(self, other):
        """Compares all keys to ``other``, return ``True`` if all match"""
        if len(self) != len(other):
            return False
        return all(ks == ko for ks, ko in zip(self.keys(), other.keys()))

    def __len__(self):
        """
        Returns
        -------

        size : int
            size of the dataset

        """
        return len(self._labelled)

    def __getitem__(self, index):
        """
        Parameters
        ----------

        index : int
            The index for the element to pick

        Returns
        -------

        sample : list
            The sample data: ``[key, image, gt, [mask, ]unlab-key, unlab-image]``

        """
        sample = self._labelled[index]
        # draw a random companion from the unlabelled pool for this sample
        pick = torch.randint(len(self._unlabelled), ())
        companion = self._unlabelled[pick]
        # only the key and the image of the unlabelled sample are kept
        return sample + companion[:2]