1#!/usr/bin/env python
2# coding=utf-8
3
4from torchvision.transforms import RandomRotation
5import random
6import torch
7import numpy as np
8
9"""Standard configurations for dataset setup"""
10
11RANDOM_ROTATION = [RandomRotation(15)]
12"""Shared data augmentation based on random rotation only"""
13
def make_subset(l, transforms=None, prefixes=None, suffixes=None):
    """Creates a new data set, applying transforms

    .. note::

       This is a convenience function for our own dataset definitions inside
       this module, guaranteeing homogeneity between dataset definitions
       provided in this package.  It assumes certain strategies for data
       augmentation that may not be translatable to other applications.


    Parameters
    ----------

    l : list
        List of delayed samples

    transforms : list
        A list of transforms that needs to be applied to all samples in the
        set.  Defaults to an empty list.

    prefixes : list
        A list of data augmentation operations that needs to be applied
        **before** the transforms above.  Defaults to an empty list.

    suffixes : list
        A list of data augmentation operations that needs to be applied
        **after** the transforms above.  Defaults to an empty list.


    Returns
    -------

    subset : :py:class:`bob.med.tb.data.utils.SampleListDataset`
        A pre-formatted dataset that can be fed to one of our engines

    """

    from ...data.utils import SampleListDataset as wrapper

    # ``None`` sentinels instead of mutable ``[]`` defaults: default lists
    # are created once at definition time and shared across all calls, which
    # is a well-known Python pitfall.
    transforms = transforms if transforms is not None else []
    prefixes = prefixes if prefixes is not None else []
    suffixes = suffixes if suffixes is not None else []

    return wrapper(l, prefixes + transforms + suffixes)
54
55
def make_dataset(subsets_groups, transforms=None, t_transforms=None,
        post_transforms=None):
    """Creates a new configuration dataset from a list of dictionaries
    and transforms

    This function takes as input a list of dictionaries as those that can be
    returned by :py:meth:`bob.med.tb.data.dataset.JSONDataset.subsets`
    mapping protocol names (such as ``train``, ``dev`` and ``test``) to
    :py:class:`bob.med.tb.data.sample.DelayedSample` lists, and a set of
    transforms, and returns a dictionary applying
    :py:class:`bob.med.tb.data.utils.SampleListDataset` to these
    lists, and our standard data augmentation if a ``train`` set exists.

    For example, if ``subsets_groups`` is composed of two sets named
    ``train`` and ``test``, this function will yield a dictionary with the
    following entries:

    * ``__train__``: Wraps the ``train`` subset, includes data augmentation
      (note: datasets with names starting with ``_`` (underscore) are excluded
      from prediction and evaluation by default, as they contain data
      augmentation transformations.)
    * ``train``: Wraps the ``train`` subset, **without** data augmentation
    * ``test``: Wraps the ``test`` subset, **without** data augmentation

    .. note::

       This is a convenience function for our own dataset definitions inside
       this module, guaranteeing homogeneity between dataset definitions
       provided in this package.  It assumes certain strategies for data
       augmentation that may not be translatable to other applications.


    Parameters
    ----------

    subsets_groups : list
        A list of dictionaries that contains the delayed sample lists
        for a number of named lists.  The subsets will be aggregated in one
        final subset.  If one of the keys is ``train``, our standard dataset
        augmentation transforms are appended to the definition of that subset.
        All other subsets remain un-augmented.

    transforms : list
        A list of transforms that needs to be applied to all samples in the
        set.  Defaults to an empty list.

    t_transforms : list
        A list of transforms that needs to be applied to the train samples.
        Defaults to an empty list.

    post_transforms : list
        A list of transforms that needs to be applied to all samples in the
        set after all the other transforms.  Defaults to an empty list.


    Returns
    -------

    dataset : dict
        A pre-formatted dataset that can be fed to one of our engines.  It
        maps string names to
        :py:class:`bob.med.tb.data.utils.SampleListDataset`'s.

    """

    # ``None`` sentinels instead of mutable ``[]`` defaults (shared across
    # calls); see also the in-place mutation fix below.
    transforms = transforms if transforms is not None else []
    t_transforms = t_transforms if t_transforms is not None else []
    post_transforms = post_transforms if post_transforms is not None else []

    retval = {}

    if len(subsets_groups) == 1:
        subsets = subsets_groups[0]
    else:
        # If multiple subsets groups: aggregation
        aggregated_subsets = {}
        for subsets in subsets_groups:
            for key in subsets.keys():
                if key in aggregated_subsets:
                    aggregated_subsets[key] += subsets[key]
                    # Shuffle if data comes from multiple datasets
                    random.shuffle(aggregated_subsets[key])
                else:
                    # Copy the incoming list: the in-place ``+=`` and shuffle
                    # above would otherwise mutate the caller's input lists
                    aggregated_subsets[key] = list(subsets[key])
        subsets = aggregated_subsets

    # Train augmentations: post_transforms go after t_transforms.  Build a
    # NEW list here; the previous ``t_transforms += post_transforms`` mutated
    # the caller's list (and the shared default), so repeated calls would
    # accumulate post_transforms.
    train_suffixes = t_transforms + post_transforms

    for key in subsets.keys():

        retval[key] = make_subset(subsets[key], transforms=transforms,
                suffixes=post_transforms)
        if key == "train":
            # augmented variant, excluded from prediction/evaluation by the
            # leading-underscore naming convention
            retval["__train__"] = make_subset(subsets[key],
                    transforms=transforms,
                    suffixes=train_suffixes,
                    )
        if key == "validation":
            # also use it for validation during training
            retval["__valid__"] = retval[key]

    if ("__train__" in retval) and ("train" in retval) \
            and ("__valid__" not in retval):
        # if the dataset does not have a validation set, we use the
        # unaugmented training set as validation set
        retval["__valid__"] = retval["train"]

    return retval
158
159
def _label_balancing_weights(samples):
    """Computes one balancing weight per sample from integer class labels

    Each sample receives the inverse of its class' frequency, so that
    under-represented classes are sampled more often.

    .. note::

       ``weight`` is indexed directly by the label values, which assumes the
       labels are contiguous integers starting at 0 (e.g. ``0``/``1`` for
       binary problems) -- TODO confirm this holds for every dataset that
       reaches this code path.


    Parameters
    ----------

    samples : list
        A list of samples, each exposing an integer ``label`` attribute


    Returns
    -------

    weights : :py:class:`torch.Tensor`
        One weight per sample (same order as ``samples``)

    """

    # Groundtruth
    targets = torch.tensor([s.label for s in samples])

    # Count number of samples per class
    class_sample_count = torch.tensor(
        [(targets == t).sum() for t in torch.unique(targets, sorted=True)])

    weight = 1. / class_sample_count.float()

    # Tensor indexing picks each sample's class weight in one shot (the
    # previous list-of-0-d-tensors round-trip produced the same values)
    return weight[targets]


def get_samples_weights(dataset):
    """Compute the weights of all the samples of the dataset to balance it
    using the sampler of the dataloader

    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
    and computes the weights to balance each class in the dataset and the
    datasets themselves if we have a ConcatDataset.


    Parameters
    ----------

    dataset : torch.utils.data.dataset.Dataset
        An instance of torch.utils.data.dataset.Dataset
        ConcatDataset are supported


    Returns
    -------

    samples_weights : :py:class:`torch.Tensor`
        the weights for all the samples in the dataset given as input

    """

    if isinstance(dataset, torch.utils.data.ConcatDataset):

        samples_weights = []
        for ds in dataset.datasets:
            # Weighting only for integer (e.g. binary) labels
            if isinstance(ds._samples[0].label, int):
                samples_weights.append(_label_balancing_weights(ds._samples))
            else:
                # We only weight to sample equally from each dataset
                samples_weights.append(torch.full((len(ds),), 1. / len(ds)))

        # Concatenate sample weights from all the datasets
        return torch.cat(samples_weights)

    # Weighting only for integer (e.g. binary) labels
    if isinstance(dataset._samples[0].label, int):
        return _label_balancing_weights(dataset._samples)

    # Equal weights for non-integer labels
    return torch.ones(len(dataset._samples))
236
237
def get_positive_weights(dataset):
    """Compute the positive weights of each class of the dataset to balance
    the BCEWithLogitsLoss criterion

    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
    and computes the positive weights of each class to use them to have
    a balanced loss.


    Parameters
    ----------

    dataset : torch.utils.data.dataset.Dataset
        An instance of torch.utils.data.dataset.Dataset
        ConcatDataset are supported


    Returns
    -------

    positive_weights : :py:class:`torch.Tensor`
        the positive weight of each class in the dataset given as input

    """

    # Gather every groundtruth label, flattening a ConcatDataset if needed
    if isinstance(dataset, torch.utils.data.ConcatDataset):
        labels = [s.label for ds in dataset.datasets for s in ds._samples]
    else:
        labels = [s.label for s in dataset._samples]

    targets = torch.tensor(labels)

    if targets.ndim == 1:
        # Binary labels: a single column of scalar labels
        counts = [
            float((targets == value).sum().item())
            for value in torch.unique(targets, sorted=True)
        ]

        # Divide negatives by positives
        positive_weights = torch.tensor([counts[0] / counts[1]]).reshape(-1)
    else:
        # Multiclass labels: one indicator column per class
        positives = torch.sum(targets, dim=0)
        negatives = (
            torch.full((targets.size()[1],), float(targets.size()[0]))
            - positives
        )

        positive_weights = negatives / (positives + negatives)

    return positive_weights