Coverage for src/deepdraw/data/utils.py: 77%
73 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5"""Common utilities."""
7import PIL.Image
8import PIL.ImageChops
9import PIL.ImageDraw
10import PIL.ImageOps
11import torch
12import torch.utils.data
14from .transforms import Compose, ToTensor
def invert_mode1_image(img):
    """Inverts a binary PIL image (mode == ``"1"``).

    The input is round-tripped through RGB — presumably because
    :py:func:`PIL.ImageOps.invert` does not handle mode-"1" images
    directly — and converted back to binary without dithering.
    """

    as_rgb = img.convert("RGB")
    inverted = PIL.ImageOps.invert(as_rgb)
    return inverted.convert(mode="1", dither=None)
def subtract_mode1_images(img1, img2):
    """Returns a new image that represents ``img1 - img2``.

    Delegates the per-pixel subtraction to
    :py:func:`PIL.ImageChops.subtract`.
    """

    difference = PIL.ImageChops.subtract(img1, img2)
    return difference
def overlayed_image(
    img,
    label,
    mask=None,
    label_color=(0, 255, 0),
    mask_color=(0, 0, 255),
    alpha=0.4,
):
    """Creates an image showing existing labels and mask.

    Builds a new representation of ``img`` in which labelled objects are
    overlaid with ``label_color`` and the negative of the mask (regions to
    be ignored) with ``mask_color``.  Useful to visually verify that a
    dataset/loader is yielding images correctly.


    Parameters
    ----------

    img : PIL.Image.Image
        An RGB PIL image that represents the original image for analysis

    label : PIL.Image.Image
        A PIL image in any mode that represents the labelled elements in the
        image.  In case of images in mode "L" or "1", white pixels represent
        the labelled object.  Black-er pixels represent background.

    mask : py:class:`PIL.Image.Image`, Optional
        A PIL image in mode "1" that represents the mask for the image.  White
        pixels indicate where content should be used, black pixels, content to
        be ignored.

    label_color : py:class:`tuple`, Optional
        A tuple with three integer entries indicating the RGB color to be used
        for labels.  Only used if ``label.mode`` is "1" or "L".

    mask_color : py:class:`tuple`, Optional
        A tuple with three integer entries indicating the RGB color to be used
        for the mask-negative (black parts in the original mask).

    alpha : py:class:`float`, Optional
        A float that indicates how much of blending should be performed between
        the label, mask and the original image.


    Returns
    -------

    image : PIL.Image.Image
        A new image overlaying the original image, object labels (in green) and
        what is to be considered parts to be **masked-out** (i.e. a
        representation of a negative of the mask).
    """

    # turn the labels into an RGB overlay of the requested color; RGB labels
    # are taken as already colorized by the caller
    if label.mode in ("1", "L"):
        colorized_labels = PIL.ImageOps.colorize(
            label.convert("L"), (0, 0, 0), label_color
        )
    else:
        colorized_labels = label

    # blend first so labelled pixels get a slight "label_color" tone, then
    # composite back against the original image so unlabelled pixels do not
    # lose brightness
    blended = PIL.Image.blend(img, colorized_labels, alpha)
    if label.mode == "1":
        keep_original = invert_mode1_image(label)
    else:
        keep_original = PIL.ImageOps.invert(label.convert("L"))
    output = PIL.Image.composite(img, blended, keep_original)

    # paint the negative of the mask, if a mask was given, with "mask_color"
    if mask is not None:
        negative_colored = PIL.ImageOps.colorize(
            mask.convert("L"), mask_color, (0, 0, 0)
        )
        shaded = PIL.Image.blend(output, negative_colored, alpha)
        output = PIL.Image.composite(output, shaded, mask)

    return output
class SampleListDataset(torch.utils.data.Dataset):
    """PyTorch dataset wrapper around Sample lists.

    A transform object can be passed that will be applied to the image, ground
    truth and mask (if present).

    It supports indexing such that dataset[i] can be used to get the i-th
    sample.


    Parameters
    ----------

    samples : list
        A list of :py:class:`deepdraw.data.sample.Sample` objects

    transforms : :py:class:`list`, Optional
        a list of transformations to be applied to **both** image and
        ground-truth data.  Notice a last transform
        (:py:class:`deepdraw.data.transforms.ToTensor`) is always applied
        - you do not need to add that.
    """

    def __init__(self, samples, transforms=None):
        # use None as sentinel instead of a mutable default argument
        self._samples = samples
        self.transforms = [] if transforms is None else transforms

    @property
    def transforms(self):
        # the trailing ToTensor appended by the setter is an implementation
        # detail, so it is stripped from the value exposed to callers
        return self._transforms.transforms[:-1]

    @transforms.setter
    def transforms(self, value):
        # always append ToTensor so __getitem__ yields tensors
        self._transforms = Compose(list(value) + [ToTensor()])

    def copy(self, transforms=None):
        """Returns a copy of itself, optionally resetting transforms.

        Note the sample list is shared (not deep-copied) with the returned
        object; only the transform pipeline is rebuilt.

        Parameters
        ----------

        transforms : :py:class:`list`, Optional
            An optional list of transforms to set in the copy.  If not
            specified, use ``self.transforms``.  An explicit empty list
            resets the copy to have no (extra) transforms.
        """

        # explicit None check: ``transforms or self.transforms`` would
        # silently ignore a deliberately-passed empty list
        return SampleListDataset(
            self._samples,
            self.transforms if transforms is None else transforms,
        )

    def keys(self):
        """Generator producing all keys for all samples."""
        for sample in self._samples:
            yield sample.key

    def all_keys_match(self, other):
        """Compares all keys to ``other``, return ``True`` if all match."""
        # generator expression: lets ``all`` short-circuit on first mismatch
        return len(self) == len(other) and all(
            ks == ko for ks, ko in zip(self.keys(), other.keys())
        )

    def __len__(self):
        """
        Returns
        -------

        size : int
            size of the dataset
        """
        return len(self._samples)

    def __getitem__(self, key):
        """
        Parameters
        ----------

        key : int, slice

        Returns
        -------

        sample : list
            The sample data: ``[key, image[, gt[, mask]]]``
        """

        if isinstance(key, slice):
            return [self[k] for k in range(*key.indices(len(self)))]
        else:  # we try it as an int
            item = self._samples[key]
            data = item.data  # triggers data loading

            retval = [data["data"]]
            if "label" in data:
                retval.append(data["label"])
            if "mask" in data:
                retval.append(data["mask"])

            if self._transforms:
                retval = self._transforms(*retval)

            return [item.key] + retval
221# NEVER USED IN THE PACKAGE
222# Should it be kept?
class SSLDataset(torch.utils.data.Dataset):
    """PyTorch dataset wrapper around labelled and unlabelled sample lists.

    Yields elements of the form:

    .. code-block:: text

       [key, image, ground-truth, [mask,] unlabelled-key, unlabelled-image]

    The size of the dataset is the same as the labelled dataset.

    Indexing works by selecting the right element on the labelled dataset, and
    randomly picking another one from the unlabelled dataset


    Parameters
    ----------

    labelled : :py:class:`torch.utils.data.Dataset`
        Labelled dataset (**must** have "mask" and "label" entries for every
        sample)

    unlabelled : :py:class:`torch.utils.data.Dataset`
        Unlabelled dataset (**may** have "mask" and "label" entries for every
        sample, but are ignored)
    """

    def __init__(self, labelled, unlabelled):
        self._labelled = labelled
        self._unlabelled = unlabelled

    def keys(self):
        """Generator producing all keys for all samples (labelled first)."""
        # NOTE(review): the previous implementation iterated over
        # ``self._labelled + self._unlabelled``; on torch datasets ``+``
        # builds a ConcatDataset whose items are the loaded sample lists,
        # which carry no ``.key`` attribute.  Delegate to the wrapped
        # datasets' own ``keys()`` generators instead (as provided by
        # SampleListDataset).
        yield from self._labelled.keys()
        yield from self._unlabelled.keys()

    def all_keys_match(self, other):
        """Compares all keys to ``other``, return ``True`` if all match."""
        # generator expression: lets ``all`` short-circuit on first mismatch
        return len(self) == len(other) and all(
            ks == ko for ks, ko in zip(self.keys(), other.keys())
        )

    def __len__(self):
        """
        Returns
        -------

        size : int
            size of the dataset (the labelled part only)
        """
        return len(self._labelled)

    def __getitem__(self, index):
        """
        Parameters
        ----------

        index : int
            The index for the element to pick

        Returns
        -------

        sample : list
            The sample data: ``[key, image, gt, [mask, ]unlab-key, unlab-image]``
        """

        retval = self._labelled[index]
        # gets one unlabelled sample randomly to follow the labelled sample;
        # the 0-d tensor from torch.randint is usable as a list index
        unlab = self._unlabelled[torch.randint(len(self._unlabelled), ())]
        # only interested in key and data of the unlabelled sample
        return retval + unlab[:2]