Coverage for src/deepdraw/configs/datasets/__init__.py: 100%

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Standard configurations for dataset setup."""


from ...data.transforms import ColorJitter as _jitter
from ...data.transforms import RandomHorizontalFlip as _hflip
from ...data.transforms import RandomRotation as _rotation
from ...data.transforms import RandomVerticalFlip as _vflip

RANDOM_ROTATION = [_rotation()]
"""Shared data augmentation based on random rotation only."""


RANDOM_FLIP_JITTER = [_hflip(), _vflip(), _jitter()]
"""Shared data augmentation transforms without random rotation."""


def make_subset(samples, transforms, prefixes=[], suffixes=[]):
    """Creates a new dataset, applying transforms.

    .. note::

        This is a convenience function for our own dataset definitions inside
        this module, guaranteeing homogeneity between dataset definitions
        provided in this package.  It assumes certain strategies for data
        augmentation that may not be translatable to other applications.


    Parameters
    ----------

    samples : list
        List of delayed samples

    transforms : list
        A list of transforms that need to be applied to all samples in the set

    prefixes : list
        A list of data augmentation operations that need to be applied
        **before** the transforms above

    suffixes : list
        A list of data augmentation operations that need to be applied
        **after** the transforms above


    Returns
    -------

    subset : :py:class:`deepdraw.data.utils.SampleListDataset`
        A pre-formatted dataset that can be fed to one of our engines
    """

    from ...data.utils import SampleListDataset as wrapper

    return wrapper(samples, prefixes + transforms + suffixes)
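
# Illustrative sketch only (not part of the original module): a hedged example of
# how ``make_subset`` might be called by a dataset definition in this package.
# ``delayed_samples`` and ``_resize`` are hypothetical placeholders; real
# definitions pass their own sample lists and pre-processing transforms, and may
# reuse the shared augmentation lists defined above as ``suffixes``:
#
#     subset = make_subset(
#         delayed_samples,              # list of delayed samples (hypothetical)
#         transforms=[_resize],         # dataset-specific pre-processing (hypothetical)
#         suffixes=RANDOM_ROTATION + RANDOM_FLIP_JITTER,
#     )
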

def augment_subset(s, rotation_before=False):
    """Creates a new subset, **with data augmentation**.

    Typically, the transforms are chained to a default set of data augmentation
    operations (random rotation, horizontal and vertical flips, and color
    jitter), but a flag allows prefixing the rotation specially (useful for
    some COVD training sets).

    .. note::

        This is a convenience function for our own dataset definitions inside
        this module, guaranteeing homogeneity between dataset definitions
        provided in this package.  It assumes certain strategies for data
        augmentation that may not be translatable to other applications.


    Parameters
    ----------

    s : deepdraw.data.utils.SampleListDataset
        A dataset that will be augmented

    rotation_before : :py:class:`bool`, Optional
        An optional flag that applies the random-rotation augmentation
        **before** this dataset's own sequence of transforms, instead of after
        them.


    Returns
    -------

    subset : :py:class:`deepdraw.data.utils.SampleListDataset`
        A pre-formatted dataset that can be fed to one of our engines
    """

    if rotation_before:
        return s.copy(RANDOM_ROTATION + s.transforms + RANDOM_FLIP_JITTER)

    return s.copy(s.transforms + RANDOM_ROTATION + RANDOM_FLIP_JITTER)
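
# Illustrative sketch only (not part of the original module): a hedged example of
# how ``augment_subset`` might be used to derive an augmented copy of a wrapped
# subset.  ``base_subset`` is a hypothetical ``SampleListDataset`` (e.g. the
# output of ``make_subset``); with ``rotation_before=True`` the random rotation
# goes in front of the subset's existing transforms instead of after them:
#
#     augmented = augment_subset(base_subset, rotation_before=True)
#     # equivalent to:
#     # base_subset.copy(RANDOM_ROTATION + base_subset.transforms + RANDOM_FLIP_JITTER)
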

def make_dataset(subsets, transforms):
    """Creates a new configuration dataset from dictionary and transforms.

    This function takes as input a dictionary such as those returned by
    :py:meth:`deepdraw.data.dataset.JSONDataset.subsets` or
    :py:meth:`deepdraw.data.dataset.CSVDataset.subsets`, mapping protocol
    names (such as ``train``, ``dev`` and ``test``) to
    :py:class:`deepdraw.data.sample.DelayedSample` lists, together with a set
    of transforms, and returns a dictionary that applies
    :py:class:`deepdraw.data.utils.SampleListDataset` to these lists, plus our
    standard data augmentation if a ``train`` set exists.

    For example, if ``subsets`` is composed of two sets named ``train`` and
    ``test``, this function will yield a dictionary with the following entries:

    * ``__train__``: Wraps the ``train`` subset, includes data augmentation
      (note: datasets with names starting with ``_`` (underscore) are excluded
      from prediction and evaluation by default, as they contain data
      augmentation transformations.)
    * ``train``: Wraps the ``train`` subset, **without** data augmentation
    * ``test``: Wraps the ``test`` subset, **without** data augmentation

    .. note::

        This is a convenience function for our own dataset definitions inside
        this module, guaranteeing homogeneity between dataset definitions
        provided in this package.  It assumes certain strategies for data
        augmentation that may not be translatable to other applications.


    Parameters
    ----------

    subsets : dict
        A dictionary that contains the delayed sample lists for a number of
        named lists.  If one of the keys is ``train``, our standard dataset
        augmentation transforms are appended to the definition of that subset.
        All other subsets remain un-augmented.  If one of the keys is
        ``validation``, then this dataset will also be copied to the
        ``__valid__`` hidden dataset and will be used for validation during
        training.  Otherwise, if no ``validation`` subset is available, we set
        ``__valid__`` to be the same as the unaugmented ``train`` subset, if
        one is available.

    transforms : list
        A list of transforms that need to be applied to all samples in the set


    Returns
    -------

    dataset : dict
        A pre-formatted dataset that can be fed to one of our engines.  It maps
        string names to
        :py:class:`deepdraw.data.utils.SampleListDataset`'s.
    """

    retval = {}

    for key in subsets.keys():
        retval[key] = make_subset(subsets[key], transforms=transforms)
        if key == "train":
            retval["__train__"] = make_subset(
                subsets[key],
                transforms=transforms,
                suffixes=(RANDOM_ROTATION + RANDOM_FLIP_JITTER),
            )
        if key == "validation":
            # also use it for validation during training
            retval["__valid__"] = retval[key]

    if (
        ("__train__" in retval)
        and ("train" in retval)
        and ("__valid__" not in retval)
    ):
        # if the dataset does not have a validation set, we use the unaugmented
        # training set as validation set
        retval["__valid__"] = retval["train"]

    return retval
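
# Illustrative sketch only (not part of the original module): a hedged example of
# what ``make_dataset`` might produce for a dataset declaring ``train`` and
# ``test`` protocols.  ``json_dataset``, the ``"default"`` protocol name and
# ``_resize`` are hypothetical placeholders standing in for a real
# ``JSONDataset`` instance and a dataset-specific pre-processing transform:
#
#     subsets = json_dataset.subsets("default")   # {"train": [...], "test": [...]}
#     dataset = make_dataset(subsets, transforms=[_resize])
#     sorted(dataset.keys())
#     # -> ["__train__", "__valid__", "test", "train"]
#     # "__train__" carries the augmentation suffixes; "__valid__" falls back to
#     # the unaugmented "train" subset because no "validation" key was present.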