Coverage for bob/med/tb/configs/datasets/__init__.py: 83% (70 statements)


#!/usr/bin/env python
# coding=utf-8

from torchvision.transforms import RandomRotation
import random
import torch
import numpy as np

"""Standard configurations for dataset setup"""

RANDOM_ROTATION = [RandomRotation(15)]
"""Shared data augmentation based on random rotation only"""

def make_subset(l, transforms=[], prefixes=[], suffixes=[]):
    """Creates a new data set, applying transforms

    .. note::

       This is a convenience function for our own dataset definitions inside
       this module, guaranteeing homogeneity between dataset definitions
       provided in this package.  It assumes certain strategies for data
       augmentation that may not be translatable to other applications.


    Parameters
    ----------

    l : list
        List of delayed samples

    transforms : list
        A list of transforms that need to be applied to all samples in the set

    prefixes : list
        A list of data augmentation operations that need to be applied
        **before** the transforms above

    suffixes : list
        A list of data augmentation operations that need to be applied
        **after** the transforms above


    Returns
    -------

    subset : :py:class:`bob.med.tb.data.utils.SampleListDataset`
        A pre-formatted dataset that can be fed to one of our engines

    """

    from ...data.utils import SampleListDataset as wrapper

    return wrapper(l, prefixes + transforms + suffixes)
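

# A minimal usage sketch (not part of the original module), assuming
# ``samples`` is a list of delayed samples as produced elsewhere in this
# package; the torchvision transforms below are illustrative placeholders,
# not the transforms this package actually ships.
def _example_make_subset(samples):
    from torchvision.transforms import CenterCrop, Resize

    return make_subset(
        samples,
        transforms=[Resize(256), CenterCrop(224)],  # applied to every sample
        prefixes=RANDOM_ROTATION,  # augmentation applied *before* the transforms
        suffixes=[],  # augmentation applied *after* the transforms
    )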


def make_dataset(subsets_groups, transforms=[], t_transforms=[],
                 post_transforms=[]):
    """Creates a new configuration dataset from a list of dictionaries
    and transforms

    This function takes as input a list of dictionaries like those that can be
    returned by :py:meth:`bob.med.tb.data.dataset.JSONDataset.subsets`,
    mapping protocol names (such as ``train``, ``dev`` and ``test``) to
    :py:class:`bob.med.tb.data.sample.DelayedSample` lists, together with a set
    of transforms, and returns a dictionary applying
    :py:class:`bob.med.tb.data.utils.SampleListDataset` to these
    lists, plus our standard data augmentation if a ``train`` set exists.

    For example, if ``subsets_groups`` contains two sets named ``train`` and
    ``test``, this function will yield a dictionary with the following entries:

    * ``__train__``: Wraps the ``train`` subset and includes data augmentation
      (note: datasets with names starting with ``_`` (underscore) are excluded
      from prediction and evaluation by default, as they contain data
      augmentation transformations)
    * ``train``: Wraps the ``train`` subset, **without** data augmentation
    * ``test``: Wraps the ``test`` subset, **without** data augmentation

    .. note::

       This is a convenience function for our own dataset definitions inside
       this module, guaranteeing homogeneity between dataset definitions
       provided in this package.  It assumes certain strategies for data
       augmentation that may not be translatable to other applications.


    Parameters
    ----------

    subsets_groups : list
        A list of dictionaries that contain the delayed sample lists
        for a number of named subsets.  The subsets will be aggregated into
        one final subset.  If one of the keys is ``train``, our standard
        dataset augmentation transforms are appended to the definition of
        that subset.  All other subsets remain un-augmented.

    transforms : list
        A list of transforms that need to be applied to all samples in the set

    t_transforms : list
        A list of transforms that need to be applied to the train samples only

    post_transforms : list
        A list of transforms that need to be applied to all samples in the
        set, after all the other transforms


    Returns
    -------

    dataset : dict
        A pre-formatted dataset that can be fed to one of our engines.  It
        maps string names to
        :py:class:`bob.med.tb.data.utils.SampleListDataset`'s.

    """

    retval = {}

    if len(subsets_groups) == 1:
        subsets = subsets_groups[0]
    else:
        # If multiple subset groups are given, aggregate them into one
        aggregated_subsets = {}
        for subsets in subsets_groups:
            for key in subsets.keys():
                if key in aggregated_subsets:
                    aggregated_subsets[key] += subsets[key]
                    # Shuffle if data comes from multiple datasets
                    random.shuffle(aggregated_subsets[key])
                else:
                    aggregated_subsets[key] = subsets[key]
        subsets = aggregated_subsets

    # Add post_transforms after t_transforms for the train set (rebind instead
    # of using ``+=``, which would mutate the default argument list)
    t_transforms = t_transforms + post_transforms

    for key in subsets.keys():

        retval[key] = make_subset(subsets[key], transforms=transforms,
                                  suffixes=post_transforms)
        if key == "train":
            retval["__train__"] = make_subset(
                subsets[key],
                transforms=transforms,
                suffixes=t_transforms,
            )
        if key == "validation":
            # also use it for validation during training
            retval["__valid__"] = retval[key]

    if ("__train__" in retval) and ("train" in retval) \
            and ("__valid__" not in retval):
        # if the dataset does not have a validation set, we use the
        # unaugmented training set as validation set
        retval["__valid__"] = retval["train"]

    return retval
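

# A minimal usage sketch (not part of the original module): it assumes
# ``raw_json_dataset`` is a JSONDataset-like object from this package and
# that "default" is one of its protocol names; both are hypothetical
# placeholders for whatever a concrete configuration actually imports.
def _example_make_dataset(raw_json_dataset):
    subsets = raw_json_dataset.subsets("default")  # e.g. {"train": [...], "test": [...]}
    d = make_dataset(
        [subsets],                     # a single subsets group (no aggregation)
        transforms=[],                 # applied to every sample
        t_transforms=RANDOM_ROTATION,  # extra augmentation for "__train__" only
        post_transforms=[],            # applied last, to every sample
    )
    # d now maps "train", "test", the augmented "__train__", and "__valid__"
    # (which falls back to the unaugmented "train" when no validation set exists)
    return d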


def get_samples_weights(dataset):
    """Computes the weights of all the samples of the dataset, so it can be
    balanced by the sampler of the dataloader

    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
    and computes the weights to balance each class in the dataset, and the
    datasets themselves if a :py:class:`torch.utils.data.ConcatDataset` is
    given.


    Parameters
    ----------

    dataset : torch.utils.data.dataset.Dataset
        An instance of :py:class:`torch.utils.data.dataset.Dataset`;
        :py:class:`torch.utils.data.ConcatDataset` is also supported


    Returns
    -------

    samples_weights : :py:class:`torch.Tensor`
        the weights for all the samples in the dataset given as input

    """

    samples_weights = []

    if isinstance(dataset, torch.utils.data.ConcatDataset):
        for ds in dataset.datasets:

            # Weighting only for binary labels
            if isinstance(ds._samples[0].label, int):

                # Groundtruth
                targets = []
                for s in ds._samples:
                    targets.append(s.label)
                targets = torch.tensor(targets)

                # Count number of samples per class
                class_sample_count = torch.tensor(
                    [(targets == t).sum()
                     for t in torch.unique(targets, sorted=True)]
                )

                weight = 1. / class_sample_count.float()

                samples_weights.append(
                    torch.tensor([weight[t] for t in targets])
                )

            else:
                # We only weight to sample equally from each dataset
                samples_weights.append(torch.full((len(ds),), 1. / len(ds)))

        # Concatenate sample weights from all the datasets
        samples_weights = torch.cat(samples_weights)

    else:
        # Weighting only for binary labels
        if isinstance(dataset._samples[0].label, int):
            # Groundtruth
            targets = []
            for s in dataset._samples:
                targets.append(s.label)
            targets = torch.tensor(targets)

            # Count number of samples per class
            class_sample_count = torch.tensor(
                [(targets == t).sum()
                 for t in torch.unique(targets, sorted=True)]
            )

            weight = 1. / class_sample_count.float()

            samples_weights = torch.tensor([weight[t] for t in targets])

        else:
            # Equal weights for non-binary labels
            samples_weights = torch.ones(len(dataset._samples))

    return samples_weights
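

# A minimal usage sketch (not part of the original module): the docstring
# above states these weights are meant for the dataloader's sampler, so this
# hypothetical helper feeds them to a WeightedRandomSampler; batch size and
# replacement are illustrative choices.
def _example_balanced_loader(dataset):
    from torch.utils.data import DataLoader, WeightedRandomSampler

    weights = get_samples_weights(dataset)
    sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
    return DataLoader(dataset, batch_size=8, sampler=sampler)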


def get_positive_weights(dataset):
    """Computes the positive weight of each class of the dataset, to balance
    the :py:class:`torch.nn.BCEWithLogitsLoss` criterion

    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
    and computes the positive weight of each class, so they can be used to
    obtain a balanced loss.


    Parameters
    ----------

    dataset : torch.utils.data.dataset.Dataset
        An instance of :py:class:`torch.utils.data.dataset.Dataset`;
        :py:class:`torch.utils.data.ConcatDataset` is also supported


    Returns
    -------

    positive_weights : :py:class:`torch.Tensor`
        the positive weight of each class in the dataset given as input

    """
    targets = []

    if isinstance(dataset, torch.utils.data.ConcatDataset):

        for ds in dataset.datasets:
            for s in ds._samples:
                targets.append(s.label)

    else:
        for s in dataset._samples:
            targets.append(s.label)

    targets = torch.tensor(targets)

    # Binary labels
    if len(list(targets.shape)) == 1:
        class_sample_count = [
            float((targets == t).sum().item())
            for t in torch.unique(targets, sorted=True)
        ]

        # Divide negatives by positives
        positive_weights = torch.tensor(
            [class_sample_count[0] / class_sample_count[1]]
        ).reshape(-1)

    # Multiclass labels
    else:
        class_sample_count = torch.sum(targets, dim=0)
        negative_class_sample_count = (
            torch.full((targets.size()[1],), float(targets.size()[0]))
            - class_sample_count
        )

        positive_weights = negative_class_sample_count / (
            class_sample_count + negative_class_sample_count
        )

    return positive_weights
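

# A minimal usage sketch (not part of the original module): the docstring
# above states these weights balance BCEWithLogitsLoss, so this hypothetical
# helper simply passes them as the criterion's ``pos_weight`` argument.
def _example_balanced_criterion(dataset):
    pos_weight = get_positive_weights(dataset)
    return torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)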