# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Image transformations for our pipelines.

Unlike the transforms in :py:mod:`torchvision.transforms`, the ones in this
module support multiple simultaneous image inputs, which are required to feed
segmentation networks (e.g. an image together with its labels or masks).  We
also take care of data augmentation: random flipping and rotation must be
applied identically to all input images, while color jittering, for example,
is applied only to the input image.
"""

import random

import numpy
import PIL.Image
import PIL.ImageOps
import torchvision.transforms
import torchvision.transforms.functional

class TupleMixin:
    """Adds support for tuples of objects to torchvision transforms."""

    def __call__(self, *args):
        # the explicit two-argument form of super() is required here: the
        # zero-argument form does not work inside list comprehensions
        # (before Python 3.12, comprehensions compile to a nested scope)
        return [super(TupleMixin, self).__call__(k) for k in args]
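# Usage sketch (illustrative only; ``img`` and ``mask`` are assumed to be
# PIL images of the same size, and are not defined in this module):
#
#   crop = CenterCrop(512)           # tuple-aware subclass defined below
#   img_c, mask_c = crop(img, mask)  # the same crop is applied to each input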

class CenterCrop(TupleMixin, torchvision.transforms.CenterCrop):
    pass


class Pad(TupleMixin, torchvision.transforms.Pad):
    pass


class Resize(TupleMixin, torchvision.transforms.Resize):
    pass


class ToTensor(TupleMixin, torchvision.transforms.ToTensor):
    pass


class Compose(torchvision.transforms.Compose):
    def __call__(self, *args):
        for t in self.transforms:
            args = t(*args)
        return args
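# Usage sketch (illustrative only; ``img`` and ``mask`` are hypothetical PIL
# images of the same size): every transform in the pipeline receives and
# returns the full tuple of inputs:
#
#   transforms = Compose([RandomHorizontalFlip(p=0.5), Resize(512), ToTensor()])
#   img_t, mask_t = transforms(img, mask)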

# NEVER USED IN THE PACKAGE
# Should it be kept?
class SingleCrop:
    """Crops one image at the given coordinates.

    Attributes
    ----------
    i : int
        upper pixel coordinate.
    j : int
        left pixel coordinate.
    h : int
        height of the cropped image.
    w : int
        width of the cropped image.
    """

    def __init__(self, i, j, h, w):
        self.i = i
        self.j = j
        self.h = h
        self.w = w

    def __call__(self, img):
        return img.crop((self.j, self.i, self.j + self.w, self.i + self.h))

class Crop(TupleMixin, SingleCrop):
    """Crops multiple images at the given coordinates.

    Attributes
    ----------
    i : int
        upper pixel coordinate.
    j : int
        left pixel coordinate.
    h : int
        height of the cropped image.
    w : int
        width of the cropped image.
    """

    pass
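# Usage sketch (illustrative only): crop a 100x200 (h x w) window whose
# top-left corner sits at row 10, column 20, from every input at once:
#
#   crop = Crop(i=10, j=20, h=100, w=200)
#   img_c, mask_c = crop(img, mask)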

class SingleAutoLevel16to8:
    """Converts a 16-bit image to an 8-bit representation using "auto-level".

    This transform assumes that the input image is gray-scaled.

    To auto-level, we calculate the maximum and the minimum of the image, and
    map that range onto the [0, 255] range of the destination image.
    """

    def __call__(self, img):
        # linearly maps [imin, imax] to [0, 255] (assumes imax > imin)
        imin, imax = img.getextrema()
        irange = imax - imin
        return PIL.Image.fromarray(
            numpy.round(
                255.0 * (numpy.array(img).astype(float) - imin) / irange
            ).astype("uint8"),
        ).convert("L")

class AutoLevel16to8(TupleMixin, SingleAutoLevel16to8):
    """Converts multiple 16-bit images to 8-bit representations using
    "auto-level".

    This transform assumes that the input images are gray-scaled.

    To auto-level, we calculate the maximum and the minimum of each image,
    and map that range onto the [0, 255] range of the destination image.
    """

    pass
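# Worked example (illustrative only): for a 16-bit image whose extrema are
# (1000, 5000), a pixel with value 3000 maps to
# 255 * (3000 - 1000) / (5000 - 1000) = 127.5, which rounds to 128.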

class SingleToRGB:
    """Converts from any input format to RGB, using an ADAPTIVE conversion.

    This transform takes the input image and converts it to RGB using
    :py:meth:`PIL.Image.Image.convert`, with ``mode='RGB'`` and all other
    defaults.  This may be aggressive if applied to 16-bit images without
    further considerations.
    """

    def __call__(self, img):
        return img.convert(mode="RGB")


class ToRGB(TupleMixin, SingleToRGB):
    """Converts from any input format to RGB, using an ADAPTIVE conversion.

    This transform takes the input images and converts them to RGB using
    :py:meth:`PIL.Image.Image.convert`, with ``mode='RGB'`` and all other
    defaults.  This may be aggressive if applied to 16-bit images without
    further considerations.
    """

    pass

class RandomHorizontalFlip(torchvision.transforms.RandomHorizontalFlip):
    """Randomly flips all input images horizontally."""

    def __call__(self, *args):
        if random.random() < self.p:
            return [
                torchvision.transforms.functional.hflip(img) for img in args
            ]
        else:
            return args


class RandomVerticalFlip(torchvision.transforms.RandomVerticalFlip):
    """Randomly flips all input images vertically."""

    def __call__(self, *args):
        if random.random() < self.p:
            return [
                torchvision.transforms.functional.vflip(img) for img in args
            ]
        else:
            return args

class RandomRotation(torchvision.transforms.RandomRotation):
    """Randomly rotates all input images by the same amount.

    Unlike the current torchvision implementation, we also accept a
    probability for applying the rotation.


    Parameters
    ----------

    p : :py:class:`float`, Optional
        probability at which the operation is applied

    **kwargs : dict
        passed to parent.  Notice that, if not set, we use the following
        defaults here for the underlying transform from torchvision:

        * ``degrees``: 15
        * ``interpolation``: ``torchvision.transforms.functional.InterpolationMode.BILINEAR``
    """

    def __init__(self, p=0.5, **kwargs):
        kwargs.setdefault("degrees", 15)
        kwargs.setdefault(
            "interpolation",
            torchvision.transforms.functional.InterpolationMode.BILINEAR,
        )
        super().__init__(**kwargs)
        self.p = p

    def __call__(self, *args):
        # applies **the same** rotation to all inputs (data and ground-truth)
        if random.random() < self.p:
            angle = self.get_params(self.degrees)
            return [
                torchvision.transforms.functional.rotate(
                    img, angle, self.interpolation, self.expand, self.center
                )
                for img in args
            ]
        else:
            return args

    def __repr__(self):
        retval = super().__repr__()
        return retval.replace("(", f"(p={self.p},", 1)
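# Usage sketch (illustrative only): with probability 0.8, rotate both inputs
# by the same randomly-drawn angle in [-10, +10] degrees:
#
#   rotate = RandomRotation(p=0.8, degrees=10)
#   img_r, mask_r = rotate(img, mask)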

class ColorJitter(torchvision.transforms.ColorJitter):
    """Randomly applies a color jitter transformation on the **first** image.

    Notice this transform extension, unlike others in this module, only
    affects the first image passed as input argument.  Unlike the current
    torchvision implementation, we also accept a probability for applying the
    jitter.


    Parameters
    ----------

    p : :py:class:`float`, Optional
        probability at which the operation is applied

    **kwargs : dict
        passed to parent.  Notice that, if not set, we use the following
        defaults here for the underlying transform from torchvision:

        * ``brightness``: 0.3
        * ``contrast``: 0.3
        * ``saturation``: 0.02
        * ``hue``: 0.02
    """

    def __init__(self, p=0.5, **kwargs):
        kwargs.setdefault("brightness", 0.3)
        kwargs.setdefault("contrast", 0.3)
        kwargs.setdefault("saturation", 0.02)
        kwargs.setdefault("hue", 0.02)
        super().__init__(**kwargs)
        self.p = p

    def __call__(self, *args):
        if random.random() < self.p:
            # applies color jitter only to the input image, not to the
            # ground-truth
            return [super().__call__(args[0]), *args[1:]]
        else:
            return args

    def __repr__(self):
        retval = super().__repr__()
        return retval.replace("(", f"(p={self.p},", 1)
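# Usage sketch (illustrative only): only ``img`` is jittered (with
# probability 0.5 here); the mask passes through unchanged:
#
#   jitter = ColorJitter(p=0.5, brightness=0.3)
#   img_j, mask_j = jitter(img, mask)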

def _expand2square(pil_img, background_color):
    """Pads the smaller of the height and the width so the result is square.

    Parameters
    ----------

    pil_img : PIL.Image.Image
        A PIL image that represents the image for analysis.

    background_color : :py:class:`tuple`, Optional
        A tuple representing the color of the background of the image, used
        to pad with the same color.  If the image is an RGB image,
        ``background_color`` should be a tuple of size 3; if it is a
        grayscale image, it can be an integer.

    Returns
    -------

    image : PIL.Image.Image
        A new image with height equal to width.
    """

    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = PIL.Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = PIL.Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result
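# Worked example (illustrative only): a 200x80 (w x h) RGB image becomes a
# 200x200 image with a black background, with the original pasted centered
# vertically at (0, (200 - 80) // 2) == (0, 60):
#
#   square = _expand2square(img, (0, 0, 0))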

class ShrinkIntoSquare:
    """Crops away black borders and then pads the result into a square.

    This transform crops all input images by removing black rows and columns
    from the borders until it finds a non-black pixel.  It then expands each
    image back into a square of minimal size.


    Parameters
    ----------

    reference : :py:class:`int`, Optional
        Which reference part of the sample to use for cropping black borders.
        If not set, use the first object on the sample (typically, the image).

    threshold : :py:class:`int`, Optional
        Threshold to use when deciding what counts as a "black" border.
    """

    def __init__(self, reference=0, threshold=0):
        self.reference = reference
        self.threshold = threshold

    def __call__(self, *args):
        ref = numpy.asarray(args[self.reference].convert("L"))
        # rows/columns whose pixel sum exceeds the threshold are kept
        width = numpy.sum(ref, axis=0) > self.threshold
        height = numpy.sum(ref, axis=1) > self.threshold

        border = (
            width.argmax(),
            height.argmax(),
            width[::-1].argmax(),
            height[::-1].argmax(),
        )

        new_args = [PIL.ImageOps.crop(k, border) for k in args]

        def _black_background(i):
            return (0, 0, 0) if i.mode == "RGB" else 0

        return [_expand2square(k, _black_background(k)) for k in new_args]
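# Usage sketch (illustrative only): crop black borders based on the first
# input, then pad every input back into a square:
#
#   shrink = ShrinkIntoSquare(reference=0, threshold=0)
#   img_s, mask_s = shrink(img, mask)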

class GaussianBlur(torchvision.transforms.GaussianBlur):
    """Randomly applies a Gaussian blur transformation on the **first** image.

    Notice this transform extension, unlike others in this module, only
    affects the first image passed as input argument.  Unlike the current
    torchvision implementation, we also accept a probability for applying the
    blur.


    Parameters
    ----------

    p : :py:class:`float`, Optional
        probability at which the operation is applied

    **kwargs : dict
        passed to parent.  Notice that, if not set, we use the following
        defaults here for the underlying transform from torchvision:

        * ``kernel_size``: (5, 5)
        * ``sigma``: (0.1, 5)
    """

    def __init__(self, p=0.5, **kwargs):
        kwargs.setdefault("kernel_size", (5, 5))
        kwargs.setdefault("sigma", (0.1, 5))

        super().__init__(**kwargs)
        self.p = p

    def __call__(self, *args):
        if random.random() < self.p:
            # applies the Gaussian blur only to the input image, not to the
            # ground-truth
            return [super().__call__(args[0]), *args[1:]]
        else:
            return args
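# Usage sketch (illustrative only): blur only ``img``, leaving ``mask``
# untouched:
#
#   blur = GaussianBlur(p=0.5, kernel_size=(5, 5))
#   img_b, mask_b = blur(img, mask)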

class GroundTruthCrop:
    """Crops all images to a square containing only the ground-truth area.

    This transform crops all input images given a ground-truth mask as
    reference.  The crop results in a square image: the larger dimension of
    the ground-truth bounding box is kept, and the smaller one is enlarged to
    match it.  There is an option to add extra area around the ground-truth
    bounding box.  If the resulting dimensions extend beyond the boundaries
    of the image, minimal padding is applied to keep the image square.

    Parameters
    ----------

    reference : :py:class:`int`, Optional
        Which reference part of the sample to use for getting coordinates.
        If not set, use the second object on the sample (typically, the
        mask).

    extra_area : :py:class:`float`, Optional
        Multiplier that adds extra area around the ground-truth bounding box.
        Example: 0.1 results in a crop with the dimensions of the largest
        side increased by 10%.  If not set, the default is 0 (only the
        ground-truth box).
    """

    def __init__(self, reference=1, extra_area=0.0):
        self.reference = reference
        self.extra_area = extra_area

    def __call__(self, *args):
        ref = args[self.reference]

        max_w, max_h = ref.size

        # bounding box of the non-zero (ground-truth) pixels
        where = numpy.where(ref)
        y0 = numpy.min(where[0])
        y1 = numpy.max(where[0])
        x0 = numpy.min(where[1])
        x1 = numpy.max(where[1])

        w = x1 - x0
        h = y1 - y0

        extra_x = self.extra_area * w / 2
        extra_y = self.extra_area * h / 2

        new_w = (1 + self.extra_area) * w
        new_h = (1 + self.extra_area) * h

        # amount to add on each side of the smaller dimension so the final
        # crop is square
        diff = abs(new_w - new_h) / 2

        if new_w == new_h:
            x0_new = x0 - extra_x
            x1_new = x1 + extra_x
            y0_new = y0 - extra_y
            y1_new = y1 + extra_y

        elif new_w > new_h:
            x0_new = x0 - extra_x
            x1_new = x1 + extra_x
            y0_new = y0 - extra_y - diff
            y1_new = y1 + extra_y + diff

        else:
            x0_new = x0 - extra_x - diff
            x1_new = x1 + extra_x + diff
            y0_new = y0 - extra_y
            y1_new = y1 + extra_y

        border = (x0_new, y0_new, max_w - x1_new, max_h - y1_new)

        def _expand_img(
            pil_img, background_color, x0_pad=0, x1_pad=0, y0_pad=0, y1_pad=0
        ):
            width = pil_img.size[0] + x0_pad + x1_pad
            height = pil_img.size[1] + y0_pad + y1_pad

            result = PIL.Image.new(
                pil_img.mode, (width, height), background_color
            )
            result.paste(pil_img, (x0_pad, y0_pad))
            return result

        def _black_background(i):
            return (0, 0, 0) if i.mode == "RGB" else 0

        # padding needed wherever the (possibly expanded) crop box falls
        # outside the image boundaries
        d_x0 = numpy.rint(max([0 - x0_new, 0])).astype(int)
        d_y0 = numpy.rint(max([0 - y0_new, 0])).astype(int)
        d_x1 = numpy.rint(max([x1_new - max_w, 0])).astype(int)
        d_y1 = numpy.rint(max([y1_new - max_h, 0])).astype(int)

        new_args = [
            _expand_img(
                k,
                _black_background(k),
                x0_pad=d_x0,
                x1_pad=d_x1,
                y0_pad=d_y0,
                y1_pad=d_y1,
            )
            for k in args
        ]

        new_args = [PIL.ImageOps.crop(k, border) for k in new_args]

        return new_args
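# Usage sketch (illustrative only): crop ``img`` and ``mask`` to the square
# around the ground-truth in ``mask``, adding 10% of extra area:
#
#   gt_crop = GroundTruthCrop(reference=1, extra_area=0.1)
#   img_c, mask_c = gt_crop(img, mask)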