Coverage for src/deepdraw/data/utils.py: 77%

73 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> 

2# 

3# SPDX-License-Identifier: GPL-3.0-or-later 

4 

5"""Common utilities.""" 

6 

7import PIL.Image 

8import PIL.ImageChops 

9import PIL.ImageDraw 

10import PIL.ImageOps 

11import torch 

12import torch.utils.data 

13 

14from .transforms import Compose, ToTensor 

15 

16 

def invert_mode1_image(img):
    """Inverts a binary PIL image (mode == ``"1"``)."""

    # PIL.ImageOps.invert does not accept mode-"1" images directly, so we
    # round-trip through RGB and convert back to binary without dithering.
    inverted_rgb = PIL.ImageOps.invert(img.convert("RGB"))
    return inverted_rgb.convert(mode="1", dither=None)

23 

24 

def subtract_mode1_images(img1, img2):
    """Returns a new image that represents ``img1 - img2``."""

    difference = PIL.ImageChops.subtract(img1, img2)
    return difference

29 

30 

def overlayed_image(
    img,
    label,
    mask=None,
    label_color=(0, 255, 0),
    mask_color=(0, 0, 255),
    alpha=0.4,
):
    """Creates an image showing existing labels and mask.

    This function creates a new representation of the input image ``img``
    overlaying a green mask for labelled objects, and a red mask for parts of
    the image that should be ignored (negative mask).  By looking at this
    representation, it shall be possible to verify if the dataset/loader is
    yielding images correctly.


    Parameters
    ----------

    img : PIL.Image.Image
        An RGB PIL image that represents the original image for analysis

    label : PIL.Image.Image
        A PIL image in any mode that represents the labelled elements in the
        image.  In case of images in mode "L" or "1", white pixels represent
        the labelled object.  Black-er pixels represent background.

    mask : :py:class:`PIL.Image.Image`, Optional
        A PIL image in mode "1" that represents the mask for the image.  White
        pixels indicate where content should be used, black pixels, content to
        be ignored.

    label_color : :py:class:`tuple`, Optional
        A tuple with three integer entries indicating the RGB color to be used
        for labels.  Only used if ``label.mode`` is "1" or "L".

    mask_color : :py:class:`tuple`, Optional
        A tuple with three integer entries indicating the RGB color to be used
        for the mask-negative (black parts in the original mask).

    alpha : :py:class:`float`, Optional
        A float that indicates how much of blending should be performed between
        the label, mask and the original image.


    Returns
    -------

    image : PIL.Image.Image
        A new image overlaying the original image, object labels (in green) and
        what is to be considered parts to be **masked-out** (i.e. a
        representation of a negative of the mask).
    """

    # build an RGB rendition of the labels, painted with ``label_color``;
    # RGB labels passed by the caller are used verbatim
    if label.mode not in ("1", "L"):
        colored_labels = label
    else:
        colored_labels = PIL.ImageOps.colorize(
            label.convert("L"), (0, 0, 0), label_color
        )

    # blend image and label together - first blend to get vessels drawn with a
    # slight "label_color" tone on top, then composite with original image, to
    # avoid losing brightness.
    blended = PIL.Image.blend(img, colored_labels, alpha)
    if label.mode == "1":
        background_selector = invert_mode1_image(label)
    else:
        background_selector = PIL.ImageOps.invert(label.convert("L"))
    # restore the untouched original pixels wherever there is no label
    output = PIL.Image.composite(img, blended, background_selector)

    # paint the negative of the mask with ``mask_color``, then keep the
    # already-composed pixels wherever the mask is white (valid content)
    if mask is not None:
        negative_mask = PIL.ImageOps.colorize(
            mask.convert("L"), mask_color, (0, 0, 0)
        )
        shaded = PIL.Image.blend(output, negative_mask, alpha)
        output = PIL.Image.composite(output, shaded, mask)

    return output

114 

115 

class SampleListDataset(torch.utils.data.Dataset):
    """PyTorch dataset wrapper around Sample lists.

    A transform object can be passed that will be applied to the image, ground
    truth and mask (if present).

    It supports indexing such that dataset[i] can be used to get the i-th
    sample.


    Parameters
    ----------

    samples : list
        A list of :py:class:`deepdraw.data.sample.Sample` objects

    transforms : :py:class:`list`, Optional
        a list of transformations to be applied to **both** image and
        ground-truth data.  Notice a last transform
        (:py:class:`deepdraw.data.transforms.ToTensor`) is always applied
        - you do not need to add that.
    """

    def __init__(self, samples, transforms=None):
        # ``None`` (instead of a mutable ``[]`` default) avoids the shared
        # mutable-default-argument pitfall; the setter still receives a list.
        self._samples = samples
        self.transforms = [] if transforms is None else transforms

    @property
    def transforms(self):
        # Strip the trailing ToTensor that the setter always appends, so the
        # caller gets back exactly what it set.
        return self._transforms.transforms[:-1]

    @transforms.setter
    def transforms(self, value):
        # ToTensor is always applied last, after all user transforms
        self._transforms = Compose(value + [ToTensor()])

    def copy(self, transforms=None):
        """Returns a deep copy of itself, optionally resetting transforms.

        Parameters
        ----------

        transforms : :py:class:`list`, Optional
            An optional list of transforms to set in the copy.  If not
            specified, use ``self.transforms``.
        """

        # Compare against ``None`` explicitly so that an empty list is
        # honoured as "copy with no transforms" (a bare ``or`` would
        # silently fall back to ``self.transforms``).
        if transforms is None:
            transforms = self.transforms
        return SampleListDataset(self._samples, transforms)

    def keys(self):
        """Generator producing all keys for all samples."""
        for k in self._samples:
            yield k.key

    def all_keys_match(self, other):
        """Compares all keys to ``other``, return ``True`` if all match."""
        # Generator (not a materialized list) lets ``all`` short-circuit on
        # the first mismatch.
        return len(self) == len(other) and all(
            ks == ko for ks, ko in zip(self.keys(), other.keys())
        )

    def __len__(self):
        """

        Returns
        -------

        size : int
            size of the dataset

        """
        return len(self._samples)

    def __getitem__(self, key):
        """

        Parameters
        ----------

        key : int, slice

        Returns
        -------

        sample : list
            The sample data: ``[key, image[, gt[, mask]]]``

        """

        if isinstance(key, slice):
            return [self[k] for k in range(*key.indices(len(self)))]
        else:  # we try it as an int
            item = self._samples[key]
            data = item.data  # triggers data loading

            retval = [data["data"]]
            if "label" in data:
                retval.append(data["label"])
            if "mask" in data:
                retval.append(data["mask"])

            if self._transforms:
                retval = self._transforms(*retval)

            return [item.key] + retval

219 

220 

# NEVER USED IN THE PACKAGE
# Should it be kept?
class SSLDataset(torch.utils.data.Dataset):
    """PyTorch dataset wrapper around labelled and unlabelled sample lists.

    Yields elements of the form:

    .. code-block:: text

       [key, image, ground-truth, [mask,] unlabelled-key, unlabelled-image]

    The size of the dataset is the same as the labelled dataset.

    Indexing works by selecting the right element on the labelled dataset, and
    randomly picking another one from the unlabelled dataset

    Parameters
    ----------

    labelled : :py:class:`torch.utils.data.Dataset`
        Labelled dataset (**must** have "mask" and "label" entries for every
        sample)

    unlabelled : :py:class:`torch.utils.data.Dataset`
        Unlabelled dataset (**may** have "mask" and "label" entries for every
        sample, but are ignored)
    """

    def __init__(self, labelled, unlabelled):
        self._labelled = labelled
        self._unlabelled = unlabelled

    def keys(self):
        """Generator producing all keys for all samples."""
        # ``+`` chains the two datasets (for torch datasets, via ConcatDataset)
        combined = self._labelled + self._unlabelled
        for sample in combined:
            yield sample.key

    def all_keys_match(self, other):
        """Compares all keys to ``other``, return ``True`` if all match."""
        if len(self) != len(other):
            return False
        return all(ks == ko for ks, ko in zip(self.keys(), other.keys()))

    def __len__(self):
        """

        Returns
        -------

        size : int
            size of the dataset

        """

        return len(self._labelled)

    def __getitem__(self, index):
        """

        Parameters
        ----------
        index : int
            The index for the element to pick

        Returns
        -------

        sample : list
            The sample data: ``[key, image, gt, [mask, ]unlab-key, unlab-image]``

        """

        labelled_sample = self._labelled[index]
        # pair the labelled sample with one unlabelled sample drawn at random
        random_index = torch.randint(len(self._unlabelled), ())
        unlabelled_sample = self._unlabelled[random_index]
        # only the key and the data of the unlabelled sample are of interest
        return labelled_sample + unlabelled_sample[:2]

297 return retval + unlab[:2]