Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4 

5"""Common utilities""" 

6 

7import PIL.Image 

8import PIL.ImageChops 

9import PIL.ImageOps 

10import torch 

11import torch.utils.data 

12 

13from .transforms import Compose, ToTensor 

14 

15 

def invert_mode1_image(img):
    """Returns the negative of a binary PIL image (mode ``"1"``)

    The input is round-tripped through "RGB" because
    :py:func:`PIL.ImageOps.invert` does not operate on mode "1" images
    directly; the result is converted back to mode "1" without dithering.
    """

    as_rgb = img.convert("RGB")
    inverted = PIL.ImageOps.invert(as_rgb)
    return inverted.convert(mode="1", dither=None)

22 

23 

def subtract_mode1_images(img1, img2):
    """Returns a new binary image representing ``img1 - img2``

    Pixel-wise subtraction is delegated to
    :py:func:`PIL.ImageChops.subtract`; negative results clip to black.
    """

    difference = PIL.ImageChops.subtract(img1, img2)
    return difference

28 

29 

def overlayed_image(
    img,
    label,
    mask=None,
    label_color=(0, 255, 0),
    mask_color=(0, 0, 255),
    alpha=0.4,
):
    """Creates an image showing existing labels and mask

    This function creates a new representation of the input image ``img``
    overlaying a green mask for labelled objects, and a red mask for parts of
    the image that should be ignored (negative mask). By looking at this
    representation, it shall be possible to verify if the dataset/loader is
    yielding images correctly.


    Parameters
    ----------

    img : PIL.Image.Image
        An RGB PIL image that represents the original image for analysis

    label : PIL.Image.Image
        A PIL image in any mode that represents the labelled elements in the
        image.  In case of images in mode "L" or "1", white pixels represent
        the labelled object.  Black-er pixels represent background.

    mask : :py:class:`PIL.Image.Image`, Optional
        A PIL image in mode "1" that represents the mask for the image.  White
        pixels indicate where content should be used, black pixels, content to
        be ignored.

    label_color : :py:class:`tuple`, Optional
        A tuple with three integer entries indicating the RGB color to be used
        for labels.  Only used if ``label.mode`` is "1" or "L".

    mask_color : :py:class:`tuple`, Optional
        A tuple with three integer entries indicating the RGB color to be used
        for the mask-negative (black parts in the original mask).

    alpha : :py:class:`float`, Optional
        A float that indicates how much of blending should be performed between
        the label, mask and the original image.


    Returns
    -------

    image : PIL.Image.Image
        A new image overlaying the original image, object labels (in green) and
        what is to be considered parts to be **masked-out** (i.e. a
        representation of a negative of the mask).

    """

    # creates a representation of labels, in RGB format, with the right color
    if label.mode in ("1", "L"):
        # colorize() maps black input pixels to (0, 0, 0) and white input
        # pixels to ``label_color``, so only labelled areas receive the tint
        label_colored = PIL.ImageOps.colorize(
            label.convert("L"), (0, 0, 0), label_color
        )
    else:
        # user has already passed an RGB version of the labels, just compose
        label_colored = label

    # blend image and label together - first blend to get vessels drawn with a
    # slight "label_color" tone on top, then composite with original image, to
    # avoid losing brightness.
    retval = PIL.Image.blend(img, label_colored, alpha)
    # the composite mask is the *inverse* of the label: white where there is
    # no label, so those pixels are restored from the original image and only
    # labelled areas keep the blended tint
    if label.mode == "1":
        composite_mask = invert_mode1_image(label)
    else:
        composite_mask = PIL.ImageOps.invert(label.convert("L"))
    retval = PIL.Image.composite(img, retval, composite_mask)

    # creates a representation of the mask negative with the right color
    if mask is not None:
        # black mask pixels (to-be-ignored regions) become ``mask_color``;
        # white (usable) pixels become black
        antimask_colored = PIL.ImageOps.colorize(
            mask.convert("L"), mask_color, (0, 0, 0)
        )
        tmp = PIL.Image.blend(retval, antimask_colored, alpha)
        # keep ``retval`` where the mask is white; show the tinted version
        # where the mask is black (the ignored regions)
        retval = PIL.Image.composite(retval, tmp, mask)

    return retval

114 

115 

class SampleListDataset(torch.utils.data.Dataset):
    """PyTorch dataset wrapper around Sample lists

    A transform object can be passed that will be applied to the image, ground
    truth and mask (if present).

    It supports indexing such that dataset[i] can be used to get the i-th
    sample.


    Parameters
    ----------

    samples : list
        A list of :py:class:`bob.ip.binseg.data.sample.Sample` objects

    transforms : :py:class:`list`, Optional
        a list of transformations to be applied to **both** image and
        ground-truth data.  Notice a last transform
        (:py:class:`bob.ip.binseg.data.transforms.ToTensor`) is always applied
        - you do not need to add that.

    """

    def __init__(self, samples, transforms=None):
        # ``None`` instead of a mutable ``[]`` default avoids the shared
        # mutable-default-argument pitfall
        self._samples = samples
        self.transforms = [] if transforms is None else transforms

    @property
    def transforms(self):
        # strip the trailing ToTensor added by the setter, so users get back
        # exactly what they passed in
        return self._transforms.transforms[:-1]

    @transforms.setter
    def transforms(self, value):
        # a ToTensor conversion is always appended as the final step
        self._transforms = Compose(value + [ToTensor()])

    def copy(self, transforms=None):
        """Returns a copy of itself, optionally resetting transforms

        The underlying sample list is shared (not copied) between the
        original and the copy.

        Parameters
        ----------

        transforms : :py:class:`list`, Optional
            An optional list of transforms to set in the copy.  If not
            specified, use ``self.transforms``.
        """

        # compare against None explicitly so that passing an empty list
        # effectively *resets* transforms (a falsy check would silently
        # fall back to the current transforms)
        if transforms is None:
            transforms = self.transforms
        return SampleListDataset(self._samples, transforms)

    def keys(self):
        """Generator producing all keys for all samples"""
        for k in self._samples:
            yield k.key

    def all_keys_match(self, other):
        """Compares all keys to ``other``, return ``True`` if all match"""
        # generator expression short-circuits on the first mismatch
        return len(self) == len(other) and all(
            ks == ko for ks, ko in zip(self.keys(), other.keys())
        )

    def __len__(self):
        """

        Returns
        -------

        size : int
            size of the dataset

        """
        return len(self._samples)

    def __getitem__(self, key):
        """

        Parameters
        ----------

        key : int, slice

        Returns
        -------

        sample : list
            The sample data: ``[key, image[, gt[, mask]]]``

        """

        if isinstance(key, slice):
            return [self[k] for k in range(*key.indices(len(self)))]
        else:  # we try it as an int
            item = self._samples[key]
            data = item.data  # triggers data loading

            retval = [data["data"]]
            if "label" in data:
                retval.append(data["label"])
            if "mask" in data:
                retval.append(data["mask"])

            if self._transforms:
                retval = self._transforms(*retval)

            return [item.key] + retval

221 

222 

class SSLDataset(torch.utils.data.Dataset):
    """PyTorch dataset wrapper around labelled and unlabelled sample lists

    Yields elements of the form:

    .. code-block:: text

       [key, image, ground-truth, [mask,] unlabelled-key, unlabelled-image]

    The size of the dataset is the same as the labelled dataset.

    Indexing works by selecting the right element on the labelled dataset, and
    randomly picking another one from the unlabelled dataset


    Parameters
    ----------

    labelled : :py:class:`torch.utils.data.Dataset`
        Labelled dataset (**must** have "mask" and "label" entries for every
        sample)

    unlabelled : :py:class:`torch.utils.data.Dataset`
        Unlabelled dataset (**may** have "mask" and "label" entries for every
        sample, but are ignored)

    """

    def __init__(self, labelled, unlabelled):
        self._labelled = labelled
        self._unlabelled = unlabelled

    def keys(self):
        """Generator producing all keys for all samples"""
        # NOTE(review): ``+`` on torch Datasets builds a ConcatDataset whose
        # elements come from __getitem__ (lists), which have no ``.key`` --
        # presumably both operands are plain sequences here; verify callers
        for sample in self._labelled + self._unlabelled:
            yield sample.key

    def all_keys_match(self, other):
        """Compares all keys to ``other``, return ``True`` if all match"""
        if len(self) != len(other):
            return False
        return all(
            mine == theirs for mine, theirs in zip(self.keys(), other.keys())
        )

    def __len__(self):
        """

        Returns
        -------

        size : int
            size of the dataset

        """

        return len(self._labelled)

    def __getitem__(self, index):
        """

        Parameters
        ----------
        index : int
            The index for the element to pick

        Returns
        -------

        sample : list
            The sample data: ``[key, image, gt, [mask, ]unlab-key, unlab-image]``

        """

        labelled_sample = self._labelled[index]
        # pair the labelled sample with a randomly drawn unlabelled one
        pick = torch.randint(len(self._unlabelled), ())
        unlabelled_sample = self._unlabelled[pick]
        # from the unlabelled sample, only key and data are of interest
        return labelled_sample + unlabelled_sample[:2]