#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Image transformations for our pipelines

The difference between the methods here and those from
:py:mod:`torchvision.transforms` is that these support multiple simultaneous
image inputs, which are required to feed segmentation networks (e.g. image and
labels or masks).  We also take care of data augmentation, in which random
flipping and rotation need to be applied across all input images, but color
jittering, for example, only to the input image.
"""

import random

import numpy
import PIL.Image
import PIL.ImageOps
import torchvision.transforms
import torchvision.transforms.functional


class TupleMixin:
    """Adds support for working with tuples of objects to torchvision transforms"""

    def __call__(self, *args):
        return [super(TupleMixin, self).__call__(k) for k in args]


class CenterCrop(TupleMixin, torchvision.transforms.CenterCrop):
    pass


class Pad(TupleMixin, torchvision.transforms.Pad):
    pass


class Resize(TupleMixin, torchvision.transforms.Resize):
    pass


class ToTensor(TupleMixin, torchvision.transforms.ToTensor):
    pass


class Compose(torchvision.transforms.Compose):
    def __call__(self, *args):
        for t in self.transforms:
            args = t(*args)
        return args


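# Usage sketch (illustrative only, not part of the original module): the
# tuple-aware transforms above apply the same operation to every input, so an
# image and its ground-truth mask stay geometrically aligned.  Sizes and modes
# below are made up for the example.
def _example_compose():  # hypothetical helper, for documentation purposes
    img = PIL.Image.new("RGB", (64, 48))
    mask = PIL.Image.new("L", (64, 48))
    transform = Compose([Resize(32), CenterCrop(32), ToTensor()])
    # every transform receives and returns the full tuple of inputs
    img_tensor, mask_tensor = transform(img, mask)
    assert img_tensor.shape[-2:] == mask_tensor.shape[-2:] == (32, 32)
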

class SingleCrop:
    """
    Crops one image at the given coordinates.

    Attributes
    ----------
    i : int
        upper pixel coordinate.
    j : int
        left pixel coordinate.
    h : int
        height of the cropped image.
    w : int
        width of the cropped image.
    """

    def __init__(self, i, j, h, w):
        self.i = i
        self.j = j
        self.h = h
        self.w = w

    def __call__(self, img):
        return img.crop((self.j, self.i, self.j + self.w, self.i + self.h))


class Crop(TupleMixin, SingleCrop):
    """
    Crops multiple images at the given coordinates.

    Attributes
    ----------
    i : int
        upper pixel coordinate.
    j : int
        left pixel coordinate.
    h : int
        height of the cropped image.
    w : int
        width of the cropped image.
    """

    pass


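# Usage sketch (illustrative only, not part of the original module): crops the
# same fixed window out of an image and its mask, following the (i, j, h, w)
# convention documented above.  Sizes are made up for the example.
def _example_crop():  # hypothetical helper, for documentation purposes
    img = PIL.Image.new("RGB", (100, 80))
    mask = PIL.Image.new("L", (100, 80))
    cropped_img, cropped_mask = Crop(i=10, j=20, h=50, w=60)(img, mask)
    assert cropped_img.size == cropped_mask.size == (60, 50)
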

class SingleAutoLevel16to8:
    """Converts a 16-bit image to an 8-bit representation using "auto-level"

    This transform assumes that the input image is gray-scaled.

    To auto-level, we calculate the minimum and the maximum of the image and
    map that range onto the [0, 255] range of the destination image.
    """

    def __call__(self, img):
        # stretch the observed intensity range linearly onto [0, 255]
        imin, imax = img.getextrema()
        irange = imax - imin
        return PIL.Image.fromarray(
            numpy.round(
                255.0 * (numpy.array(img).astype(float) - imin) / irange
            ).astype("uint8"),
        ).convert("L")


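# Usage sketch (illustrative only, not part of the original module): a
# synthetic integer grayscale image (PIL mode "I", values within the 16-bit
# range) stands in for a real 16-bit input; the pixel values are made up.
def _example_autolevel():  # hypothetical helper, for documentation purposes
    data = numpy.arange(0, 1024, dtype="int32").reshape(32, 32) * 37 + 1000
    img16 = PIL.Image.fromarray(data)  # becomes a mode "I" image
    img8 = SingleAutoLevel16to8()(img16)
    assert img8.mode == "L"
    assert img8.getextrema() == (0, 255)
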

class AutoLevel16to8(TupleMixin, SingleAutoLevel16to8):
    """Converts multiple 16-bit images to 8-bit representations using "auto-level"

    This transform assumes that the input images are gray-scaled.

    To auto-level, we calculate the minimum and the maximum of each image and
    map that range onto the [0, 255] range of the destination image.
    """

    pass


class SingleToRGB:
    """Converts from any input format to RGB, using an ADAPTIVE conversion.

    This transform takes the input image and converts it to RGB using
    :py:meth:`PIL.Image.Image.convert`, with ``mode='RGB'`` and all other
    defaults.  This may be aggressive if applied to 16-bit images without
    further considerations.
    """

    def __call__(self, img):
        return img.convert(mode="RGB")


class ToRGB(TupleMixin, SingleToRGB):
    """Converts multiple images from any input format to RGB, using an
    ADAPTIVE conversion.

    This transform takes the input images and converts them to RGB using
    :py:meth:`PIL.Image.Image.convert`, with ``mode='RGB'`` and all other
    defaults.  This may be aggressive if applied to 16-bit images without
    further considerations.
    """

    pass


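# Usage sketch (illustrative only, not part of the original module): converts a
# grayscale image to RGB; with a single input, a one-element list is returned.
def _example_to_rgb():  # hypothetical helper, for documentation purposes
    gray = PIL.Image.new("L", (32, 32))
    (rgb,) = ToRGB()(gray)
    assert rgb.mode == "RGB"
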

class RandomHorizontalFlip(torchvision.transforms.RandomHorizontalFlip):
    """Randomly flips all input images horizontally"""

    def __call__(self, *args):
        if random.random() < self.p:
            return [
                torchvision.transforms.functional.hflip(img) for img in args
            ]
        else:
            return args


class RandomVerticalFlip(torchvision.transforms.RandomVerticalFlip):
    """Randomly flips all input images vertically"""

    def __call__(self, *args):
        if random.random() < self.p:
            return [
                torchvision.transforms.functional.vflip(img) for img in args
            ]
        else:
            return args


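# Usage sketch (illustrative only, not part of the original module): with p=1.0
# the flip always fires, and the same flip is applied to every input so that
# image and mask remain aligned.
def _example_flip():  # hypothetical helper, for documentation purposes
    img = PIL.Image.new("RGB", (64, 48))
    mask = PIL.Image.new("L", (64, 48))
    flipped_img, flipped_mask = RandomHorizontalFlip(p=1.0)(img, mask)
    assert flipped_img.size == flipped_mask.size == (64, 48)
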

class RandomRotation(torchvision.transforms.RandomRotation):
    """Randomly rotates all input images by the same amount

    Unlike the current torchvision implementation, we also accept a
    probability for applying the rotation.


    Parameters
    ----------

    p : :py:class:`float`, Optional
        probability at which the operation is applied

    **kwargs : dict
        passed to parent.  Notice that, if not set, we use the following
        defaults here for the underlying transform from torchvision:

        * ``degrees``: 15
        * ``interpolation``: ``torchvision.transforms.functional.InterpolationMode.BILINEAR``

    """

    def __init__(self, p=0.5, **kwargs):
        kwargs.setdefault("degrees", 15)
        kwargs.setdefault(
            "interpolation",
            torchvision.transforms.functional.InterpolationMode.BILINEAR,
        )
        super(RandomRotation, self).__init__(**kwargs)
        self.p = p

    def __call__(self, *args):
        # applies **the same** rotation to all inputs (data and ground-truth)
        if random.random() < self.p:
            angle = self.get_params(self.degrees)
            return [
                torchvision.transforms.functional.rotate(
                    img, angle, self.interpolation, self.expand, self.center
                )
                for img in args
            ]
        else:
            return args

    def __repr__(self):
        retval = super(RandomRotation, self).__repr__()
        return retval.replace("(", f"(p={self.p},", 1)


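# Usage sketch (illustrative only, not part of the original module): with p=1.0
# a single random angle is drawn and applied to both the image and its mask,
# keeping the two geometrically consistent.
def _example_rotation():  # hypothetical helper, for documentation purposes
    img = PIL.Image.new("RGB", (64, 64))
    mask = PIL.Image.new("L", (64, 64))
    rotated_img, rotated_mask = RandomRotation(p=1.0, degrees=10)(img, mask)
    assert rotated_img.size == rotated_mask.size == (64, 64)
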

class ColorJitter(torchvision.transforms.ColorJitter):
    """Randomly applies a color jitter transformation on the **first** image

    Notice this transform extension, unlike others in this module, only
    affects the first image passed as input argument.  Unlike the current
    torchvision implementation, we also accept a probability for applying the
    jitter.


    Parameters
    ----------

    p : :py:class:`float`, Optional
        probability at which the operation is applied

    **kwargs : dict
        passed to parent.  Notice that, if not set, we use the following
        defaults here for the underlying transform from torchvision:

        * ``brightness``: 0.3
        * ``contrast``: 0.3
        * ``saturation``: 0.02
        * ``hue``: 0.02

    """

    def __init__(self, p=0.5, **kwargs):
        kwargs.setdefault("brightness", 0.3)
        kwargs.setdefault("contrast", 0.3)
        kwargs.setdefault("saturation", 0.02)
        kwargs.setdefault("hue", 0.02)
        super(ColorJitter, self).__init__(**kwargs)
        self.p = p

    def __call__(self, *args):
        if random.random() < self.p:
            # applies color jitter only to the input image, not ground-truth
            return [super(ColorJitter, self).__call__(args[0]), *args[1:]]
        else:
            return args

    def __repr__(self):
        retval = super(ColorJitter, self).__repr__()
        return retval.replace("(", f"(p={self.p},", 1)


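# Usage sketch (illustrative only, not part of the original module): the jitter
# changes only the first input (the image); the ground-truth mask is returned
# untouched.
def _example_color_jitter():  # hypothetical helper, for documentation purposes
    img = PIL.Image.new("RGB", (32, 32), (128, 64, 32))
    mask = PIL.Image.new("L", (32, 32), 255)
    jittered_img, same_mask = ColorJitter(p=1.0)(img, mask)
    assert same_mask is mask  # ground-truth passes through unchanged
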

def _expand2square(pil_img, background_color):
    """
    Pads the smaller of height and width so the image becomes square

    Parameters
    ----------

    pil_img : PIL.Image.Image
        A PIL image that represents the image for analysis.

    background_color : :py:class:`tuple` or int
        The color of the background of the image, used to pad with the same
        color.  If the image is an RGB image, ``background_color`` should be a
        tuple of size 3; if it is a grayscale image, it can be a single
        integer.

    Returns
    -------

    image : PIL.Image.Image
        A new image with height equal to width.

    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = PIL.Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = PIL.Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


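# Usage sketch (illustrative only, not part of the original module): a 100x60
# RGB image is padded vertically with the given background color, producing a
# 100x100 square with the original content centered.
def _example_expand2square():  # hypothetical helper, for documentation purposes
    img = PIL.Image.new("RGB", (100, 60), (10, 20, 30))
    square = _expand2square(img, (0, 0, 0))
    assert square.size == (100, 100)
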

class ResizeCrop:
    """
    Crops all input images by removing the black border, i.e. the rows and
    columns on each side where the mask is entirely zero, and then pads the
    results back to a square.  Expects three inputs: image, label
    (ground-truth) and mask.
    """

    def __call__(self, *args):
        img = args[0]
        label = args[1]
        mask = args[2]
        mask_data = numpy.asarray(mask)
        # column/row sums are zero only where the mask is completely black
        wid = numpy.sum(mask_data, axis=0)
        heig = numpy.sum(mask_data, axis=1)

        # number of all-black columns/rows on each side of the mask
        crop_left, crop_right = (wid != 0).argmax(axis=0), (
            wid[::-1] != 0
        ).argmax(axis=0)
        crop_up, crop_down = (heig != 0).argmax(axis=0), (
            heig[::-1] != 0
        ).argmax(axis=0)

        border = (crop_left, crop_up, crop_right, crop_down)

        new_mask = PIL.ImageOps.crop(mask, border)
        new_img = PIL.ImageOps.crop(img, border)
        new_label = PIL.ImageOps.crop(label, border)

        # pad back to a square with a black background
        new_img = _expand2square(new_img, (0, 0, 0))
        new_label = _expand2square(new_label, 0)
        new_mask = _expand2square(new_mask, 0)

        args = (new_img, new_label, new_mask)

        return args


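# Usage sketch (illustrative only, not part of the original module): the black
# border indicated by the mask is removed from image, label and mask alike, and
# the results are padded back to a square.
def _example_resize_crop():  # hypothetical helper, for documentation purposes
    img = PIL.Image.new("RGB", (100, 80))
    label = PIL.Image.new("L", (100, 80))
    mask = PIL.Image.new("L", (100, 80))
    mask.paste(255, (10, 5, 90, 75))  # non-black region, away from the borders
    new_img, new_label, new_mask = ResizeCrop()(img, label, mask)
    assert new_img.size == new_label.size == new_mask.size
    assert new_img.size[0] == new_img.size[1]  # square output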