Coverage for bob/ip/common/data/utils.py: 80% (113 statements)
coverage.py v7.0.5, created at 2023-01-17 15:03 +0000

#!/usr/bin/env python
# coding=utf-8

"""Common utilities"""

import PIL.Image
import PIL.ImageChops
import PIL.ImageDraw
import PIL.ImageOps
import torch
import torch.utils.data

from .transforms import Compose, GetBoundingBox, ToTensor


def invert_mode1_image(img):
    """Inverts a binary PIL image (mode == ``"1"``)"""

    return PIL.ImageOps.invert(img.convert("RGB")).convert(
        mode="1", dither=None
    )


def subtract_mode1_images(img1, img2):
    """Returns a new image that represents ``img1 - img2``"""

    return PIL.ImageChops.subtract(img1, img2)
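

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): exercises
# invert_mode1_image() and subtract_mode1_images() on tiny synthetic binary
# images.  The helper name ``_demo_mode1_ops`` is hypothetical.
def _demo_mode1_ops():
    # 4x4 binary image with a white 2x2 block in the top-left corner
    a = PIL.Image.new("1", (4, 4), 0)
    a.paste(1, (0, 0, 2, 2))
    # 4x4 binary image with a white 1x1 block, a subset of ``a``
    b = PIL.Image.new("1", (4, 4), 0)
    b.paste(1, (0, 0, 1, 1))

    inverted = invert_mode1_image(a)  # the white block becomes black
    diff = subtract_mode1_images(a, b)  # white where ``a`` is white, ``b`` is not
    assert inverted.mode == "1"
    assert diff.getpixel((1, 1)) != 0  # in ``a`` only -> survives subtraction
    assert diff.getpixel((0, 0)) == 0  # in both -> subtracted away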


def overlayed_image(
    img,
    label,
    mask=None,
    label_color=(0, 255, 0),
    mask_color=(0, 0, 255),
    alpha=0.4,
):
    """Creates an image showing existing labels and masks

    This function creates a new representation of the input image ``img``
    overlaying a green mask for labelled objects, and a red mask for parts of
    the image that should be ignored (negative mask).  By looking at this
    representation, it shall be possible to verify if the dataset/loader is
    yielding images correctly.


    Parameters
    ----------

    img : PIL.Image.Image
        An RGB PIL image that represents the original image for analysis

    label : PIL.Image.Image
        A PIL image in any mode that represents the labelled elements in the
        image.  In case of images in mode "L" or "1", white pixels represent
        the labelled object.  Darker pixels represent the background.

    mask : :py:class:`PIL.Image.Image`, Optional
        A PIL image in mode "1" that represents the mask for the image.  White
        pixels indicate where content should be used; black pixels, content to
        be ignored.

    label_color : :py:class:`tuple`, Optional
        A tuple with three integer entries indicating the RGB color to be used
        for labels.  Only used if ``label.mode`` is "1" or "L".

    mask_color : :py:class:`tuple`, Optional
        A tuple with three integer entries indicating the RGB color to be used
        for the mask-negative (black parts in the original mask).

    alpha : :py:class:`float`, Optional
        A float that indicates how much blending should be performed between
        the label, the mask and the original image.


    Returns
    -------

    image : PIL.Image.Image
        A new image overlaying the original image, object labels (in green)
        and what is to be considered parts to be **masked-out** (i.e. a
        representation of a negative of the mask).

    """

    # creates a representation of labels, in RGB format, with the right color
    if label.mode in ("1", "L"):
        label_colored = PIL.ImageOps.colorize(
            label.convert("L"), (0, 0, 0), label_color
        )
    else:
        # user has already passed an RGB version of the labels, just compose
        label_colored = label

    # blend image and label together - first blend to get vessels drawn with a
    # slight "label_color" tone on top, then composite with original image, to
    # avoid losing brightness.
    retval = PIL.Image.blend(img, label_colored, alpha)
    if label.mode == "1":
        composite_mask = invert_mode1_image(label)
    else:
        composite_mask = PIL.ImageOps.invert(label.convert("L"))
    retval = PIL.Image.composite(img, retval, composite_mask)

    # creates a representation of the mask negative with the right color
    if mask is not None:
        antimask_colored = PIL.ImageOps.colorize(
            mask.convert("L"), mask_color, (0, 0, 0)
        )
        tmp = PIL.Image.blend(retval, antimask_colored, alpha)
        retval = PIL.Image.composite(retval, tmp, mask)

    return retval
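

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): builds a synthetic
# RGB image, a binary label and a binary mask, then overlays them.  The
# helper name ``_demo_overlayed_image`` is hypothetical.
def _demo_overlayed_image():
    img = PIL.Image.new("RGB", (32, 32), (128, 128, 128))
    label = PIL.Image.new("1", (32, 32), 0)
    label.paste(1, (8, 8, 24, 24))  # the "object" sits in the centre
    mask = PIL.Image.new("1", (32, 32), 1)
    mask.paste(0, (0, 0, 32, 4))  # top band should be ignored (drawn red)

    overlay = overlayed_image(img, label, mask=mask, alpha=0.4)
    assert overlay.mode == "RGB" and overlay.size == img.size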


def overlayed_bbox_image(
    img,
    box,
    box_color=(0, 255, 0),
    width=1,
):
    """Creates an image showing existing bounding boxes

    This function creates a new representation of the input image ``img``
    overlaying a green bounding box for labelled objects.  By looking at this
    representation, it shall be possible to verify if the dataset/loader is
    yielding images correctly.


    Parameters
    ----------

    img : PIL.Image.Image
        An RGB PIL image that represents the original image for analysis

    box : list
        A list of bounding box coordinates (``[x1, y1, x2, y2]``)

    box_color : :py:class:`tuple`, Optional
        A tuple with three integer entries indicating the RGB color to be used
        for the bounding box

    width : :py:class:`int`, Optional
        An integer indicating the width of the rectangle line, in pixels

    Returns
    -------

    image : PIL.Image.Image
        A new image overlaying the original image and the bounding box (in
        green).  Notice the rectangle is drawn on ``img`` itself, which is
        also the returned object.

    """

    # draws the rectangle directly on the input image, with the right color
    draw = PIL.ImageDraw.Draw(img)
    x1, y1, x2, y2 = box
    shape = [(x1, y1), (x2, y2)]
    draw.rectangle(shape, outline=box_color, width=width)

    return img
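

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): draws a single
# bounding box on a synthetic image.  Note that overlayed_bbox_image() draws
# on ``img`` in place and returns the same object.  ``_demo_overlayed_bbox``
# is a hypothetical name.
def _demo_overlayed_bbox():
    img = PIL.Image.new("RGB", (32, 32), (0, 0, 0))
    out = overlayed_bbox_image(img, [4, 4, 20, 20], width=2)
    assert out is img  # the input image itself was modified
    assert out.getpixel((4, 4)) == (0, 255, 0)  # outline pixel is green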


class SampleListDataset(torch.utils.data.Dataset):
    """PyTorch dataset wrapper around Sample lists

    A transform object can be passed that will be applied to the image, ground
    truth and mask (if present).

    It supports indexing, such that ``dataset[i]`` can be used to get the
    i-th sample.


    Parameters
    ----------

    samples : list
        A list of :py:class:`bob.ip.common.data.sample.Sample` objects

    transforms : :py:class:`list`, Optional
        A list of transformations to be applied to **both** image and
        ground-truth data.  Notice a last transform
        (:py:class:`bob.ip.common.data.transforms.ToTensor`) is always
        applied - you do not need to add that.

    """

    def __init__(self, samples, transforms=None):
        self._samples = samples
        # avoid a mutable default argument; an empty list means "only apply
        # the (mandatory) final ToTensor transform"
        self.transforms = transforms or []

    @property
    def transforms(self):
        return self._transforms.transforms[:-1]

    @transforms.setter
    def transforms(self, value):
        self._transforms = Compose(value + [ToTensor()])

    def copy(self, transforms=None):
        """Returns a copy of itself, optionally resetting transforms

        The sample list itself is shared (not copied) between the original
        and the copy.

        Parameters
        ----------

        transforms : :py:class:`list`, Optional
            An optional list of transforms to set in the copy.  If not
            specified, use ``self.transforms``.
        """

        return SampleListDataset(self._samples, transforms or self.transforms)

    def keys(self):
        """Generator producing all keys for all samples"""
        for k in self._samples:
            yield k.key

    def all_keys_match(self, other):
        """Compares all keys to ``other``, returns ``True`` if all match"""
        return len(self) == len(other) and all(
            ks == ko for ks, ko in zip(self.keys(), other.keys())
        )

    def __len__(self):
        """

        Returns
        -------

        size : int
            size of the dataset

        """
        return len(self._samples)

    def __getitem__(self, key):
        """

        Parameters
        ----------

        key : int, slice

        Returns
        -------

        sample : list
            The sample data: ``[key, image[, gt[, mask]]]``

        """

        if isinstance(key, slice):
            return [self[k] for k in range(*key.indices(len(self)))]
        else:  # we try it as an int
            item = self._samples[key]
            data = item.data  # triggers data loading

            retval = [data["data"]]
            if "label" in data:
                retval.append(data["label"])
            if "mask" in data:
                retval.append(data["mask"])

            if self._transforms:
                retval = self._transforms(*retval)

            return [item.key] + retval
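

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): wraps a tiny
# in-memory sample list in a SampleListDataset.  ``_FakeSample`` mimics the
# ``.key``/``.data`` interface expected from
# :py:class:`bob.ip.common.data.sample.Sample` and is purely illustrative; it
# assumes the final ToTensor transform accepts PIL images and that Compose
# returns a list of transformed inputs.
class _FakeSample:
    def __init__(self, key, image):
        self.key = key
        self._image = image

    @property
    def data(self):
        # mimics delayed loading: the payload is assembled upon access
        return {"data": self._image}


def _demo_sample_list_dataset():
    samples = [
        _FakeSample(f"sample-{i}", PIL.Image.new("RGB", (16, 16)))
        for i in range(3)
    ]
    dataset = SampleListDataset(samples)
    key, image = dataset[0]  # ToTensor is always applied last
    assert key == "sample-0" and isinstance(image, torch.Tensor)
    assert len(dataset[0:2]) == 2  # slicing is supported as well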


class SampleListDetectionDataset(torch.utils.data.Dataset):
    """PyTorch dataset wrapper around Sample lists

    A transform object can be passed that will be applied to the image, ground
    truth and mask (if present).

    It supports indexing, such that ``dataset[i]`` can be used to get the
    i-th sample.  Compared to :py:class:`SampleListDataset`, indexing
    additionally applies a
    :py:class:`bob.ip.common.data.transforms.GetBoundingBox` transform to
    each sample.


    Parameters
    ----------

    samples : list
        A list of :py:class:`bob.ip.common.data.sample.Sample` objects

    transforms : :py:class:`list`, Optional
        A list of transformations to be applied to **both** image and
        ground-truth data.  Notice a last transform
        (:py:class:`bob.ip.common.data.transforms.ToTensor`) is always
        applied - you do not need to add that.

    """

    def __init__(self, samples, transforms=None):
        self._samples = samples
        # avoid a mutable default argument; an empty list means "only apply
        # the (mandatory) final ToTensor transform"
        self.transforms = transforms or []

    @property
    def transforms(self):
        return self._transforms.transforms[:-1]

    @transforms.setter
    def transforms(self, value):
        self._transforms = Compose(value + [ToTensor()])

    def copy(self, transforms=None):
        """Returns a copy of itself, optionally resetting transforms

        The sample list itself is shared (not copied) between the original
        and the copy.

        Parameters
        ----------

        transforms : :py:class:`list`, Optional
            An optional list of transforms to set in the copy.  If not
            specified, use ``self.transforms``.
        """

        return SampleListDetectionDataset(
            self._samples, transforms or self.transforms
        )

    def keys(self):
        """Generator producing all keys for all samples"""
        for k in self._samples:
            yield k.key

    def all_keys_match(self, other):
        """Compares all keys to ``other``, returns ``True`` if all match"""
        return len(self) == len(other) and all(
            ks == ko for ks, ko in zip(self.keys(), other.keys())
        )

    def __len__(self):
        """

        Returns
        -------

        size : int
            size of the dataset

        """
        return len(self._samples)

    def __getitem__(self, key):
        """

        Parameters
        ----------

        key : int, slice

        Returns
        -------

        sample : list
            The sample data: ``[key, image[, gt[, mask]]]``

        """

        if isinstance(key, slice):
            return [self[k] for k in range(*key.indices(len(self)))]
        else:  # we try it as an int
            item = self._samples[key]
            data = item.data  # triggers data loading

            retval = [data["data"]]
            if "label" in data:
                retval.append(data["label"])
            if "mask" in data:
                retval.append(data["mask"])

            if self._transforms:
                retval = self._transforms(*retval)

            # derives bounding boxes from the (transformed) sample data
            bounding_box_transform = GetBoundingBox()
            retval = bounding_box_transform(retval)

            return [item.key] + retval
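

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): demonstrates copy()
# and key comparison, reusing the illustrative ``_FakeSample`` class from the
# previous sketch.  Indexing is not exercised here, since that would require
# knowing what GetBoundingBox expects from the sample data.
def _demo_detection_dataset():
    samples = [_FakeSample("s0", PIL.Image.new("RGB", (16, 16)))]
    dataset = SampleListDetectionDataset(samples)
    clone = dataset.copy()  # shares the samples, re-applies the transforms
    assert dataset.all_keys_match(clone)
    assert list(clone.keys()) == ["s0"]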


class SSLDataset(torch.utils.data.Dataset):
    """PyTorch dataset wrapper around labelled and unlabelled sample lists

    Yields elements of the form:

    .. code-block:: text

       [key, image, ground-truth, [mask,] unlabelled-key, unlabelled-image]

    The size of the dataset is the same as the labelled dataset.

    Indexing works by selecting the right element on the labelled dataset,
    and randomly picking another one from the unlabelled dataset.

    Parameters
    ----------

    labelled : :py:class:`torch.utils.data.Dataset`
        Labelled dataset (**must** have "mask" and "label" entries for every
        sample)

    unlabelled : :py:class:`torch.utils.data.Dataset`
        Unlabelled dataset (**may** have "mask" and "label" entries for every
        sample, but these are ignored)

    """

    def __init__(self, labelled, unlabelled):
        self._labelled = labelled
        self._unlabelled = unlabelled

    def keys(self):
        """Generator producing all keys for all samples"""
        # delegates to the wrapped datasets, which know how to produce keys
        # for their own samples (concatenating the datasets with ``+`` would
        # iterate over transformed samples, which carry no ``key`` attribute)
        yield from self._labelled.keys()
        yield from self._unlabelled.keys()

    def all_keys_match(self, other):
        """Compares all keys to ``other``, returns ``True`` if all match"""
        return len(self) == len(other) and all(
            ks == ko for ks, ko in zip(self.keys(), other.keys())
        )

    def __len__(self):
        """

        Returns
        -------

        size : int
            size of the dataset

        """

        return len(self._labelled)

    def __getitem__(self, index):
        """

        Parameters
        ----------

        index : int
            The index for the element to pick

        Returns
        -------

        sample : list
            The sample data: ``[key, image, gt, [mask, ]unlab-key, unlab-image]``

        """

        retval = self._labelled[index]
        # randomly picks an unlabelled sample to follow the labelled sample
        unlab = self._unlabelled[torch.randint(len(self._unlabelled), ())]
        # from the unlabelled sample, we only keep the key and the image
        return retval + unlab[:2]
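

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): pairs a labelled
# dataset with an unlabelled one.  Real use requires labelled samples with
# "label" (and "mask") entries; this sketch only exercises the pairing logic
# with the image-only ``_FakeSample`` objects from the sketches above.
def _demo_ssl_dataset():
    labelled = SampleListDataset(
        [_FakeSample("lab-0", PIL.Image.new("RGB", (16, 16)))]
    )
    unlabelled = SampleListDataset(
        [_FakeSample("unlab-0", PIL.Image.new("RGB", (16, 16)))]
    )
    ssl = SSLDataset(labelled, unlabelled)
    sample = ssl[0]  # here: [key, image, unlab-key, unlab-image]
    assert sample[0] == "lab-0" and sample[2] == "unlab-0"
    assert len(ssl) == len(labelled)
    assert list(ssl.keys()) == ["lab-0", "unlab-0"]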