Coverage for src/deepdraw/data/dataset.py: 79%

81 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> 

2# 

3# SPDX-License-Identifier: GPL-3.0-or-later 

4 

5import csv 

6import json 

7import logging 

8import os 

9import pathlib 

10 

11logger = logging.getLogger(__name__) 

12 

13 

class JSONDataset:
    """Generic multi-protocol/subset filelist dataset that yields samples.

    To create a new dataset, you need to provide one or more JSON formatted
    filelists (one per protocol) with the following contents:

    .. code-block:: json

       {
           "subset1": [
               [
                   "value1",
                   "value2",
                   "value3"
               ],
               [
                   "value4",
                   "value5",
                   "value6"
               ]
           ],
           "subset2": [
           ]
       }

    Your dataset may contain any number of subsets, but all sample entries
    must contain the same number of fields.


    Parameters
    ----------

    protocols : list, dict
        Paths to one or more JSON formatted files containing the various
        protocols to be recognized by this dataset, or a dictionary, mapping
        protocol names to paths (or opened file objects) of JSON files.
        Internally, we save a dictionary where keys default to the basename of
        paths (list input).

    fieldnames : list, tuple
        An iterable over the field names (strings) to assign to each entry in
        the JSON file.  It should have as many items as fields in each entry of
        the JSON file.

    loader : object
        A function that receives as input, a context dictionary (with at least
        a "protocol" and "subset" keys indicating which protocol and subset are
        being served), and a dictionary with ``{fieldname: value}`` entries,
        and returns an object with at least 2 attributes:

        * ``key``: which must be a unique string for every sample across
          subsets in a protocol, and
        * ``data``: which contains the data associated with this sample
    """

    def __init__(self, protocols, fieldnames, loader):
        if isinstance(protocols, dict):
            self._protocols = protocols
        else:
            # list input: derive protocol names from the file basenames,
            # without the extension (e.g. "path/to/default.json" -> "default")
            self._protocols = {
                os.path.splitext(os.path.basename(k))[0]: k for k in protocols
            }
        self.fieldnames = fieldnames
        self._loader = loader

    def check(self, limit=0):
        """For each protocol, check if all data can be correctly accessed.

        This function assumes each sample has a ``data`` and a ``key``
        attribute.  The ``key`` attribute should be a string, or representable
        as such.


        Parameters
        ----------

        limit : int
            Maximum number of samples to check (in each protocol/subset
            combination) in this dataset.  If set to zero, then check
            everything.


        Returns
        -------

        errors : int
            Number of errors found
        """

        logger.info("Checking dataset...")
        errors = 0
        for proto in self._protocols:
            logger.info(f"Checking protocol '{proto}'...")
            for name, samples in self.subsets(proto).items():
                logger.info(f"Checking subset '{name}'...")
                if limit:
                    logger.info(f"Checking at most first '{limit}' samples...")
                    samples = samples[:limit]
                for pos, sample in enumerate(samples):
                    try:
                        sample.data  # may trigger data loading
                        logger.info(f"{sample.key}: OK")
                    except Exception as e:
                        # NOTE: the original code carried a second, unreachable
                        # ``except Exception`` clause here; it has been removed
                        logger.error(
                            f"Found error loading entry {pos} in subset {name} "
                            f"of protocol {proto} from file "
                            f"'{self._protocols[proto]}': {e}"
                        )
                        errors += 1
        return errors

    def subsets(self, protocol):
        """Returns all subsets in a protocol.

        This method will load JSON information for a given protocol and return
        all subsets of the given protocol after converting each entry through
        the loader function.

        Parameters
        ----------

        protocol : str
            Name of the protocol data to load


        Returns
        -------

        subsets : dict
            A dictionary mapping subset names to lists of objects (respecting
            the ``key``, ``data`` interface).
        """

        fileobj = self._protocols[protocol]
        if isinstance(fileobj, (str, bytes, pathlib.Path)):
            with open(fileobj) as f:
                data = json.load(f)
        else:
            # already-opened file object: read from it directly (this used to
            # erroneously reference the undefined name ``f``), then rewind so
            # subsequent calls can re-read the same contents
            data = json.load(fileobj)
            fileobj.seek(0)

        retval = {}
        for subset, samples in data.items():
            retval[subset] = [
                self._loader(
                    dict(protocol=protocol, subset=subset, order=n),
                    dict(zip(self.fieldnames, k)),
                )
                for n, k in enumerate(samples)
            ]

        return retval

169 

170 

class CSVDataset:
    """Generic multi-subset filelist dataset that yields samples.

    A dataset of this type is driven by one or more CSV formatted filelists
    (any separator works: comma, space, semi-colon, ...), each describing one
    subset:

    .. code-block:: text

       value1,value2,value3
       value4,value5,value6
       ...

    Every row must carry the same number of entries.

    Parameters
    ----------

    subsets : list, dict
        Paths to one or more CSV formatted files containing the various subsets
        to be recognized by this dataset, or a dictionary, mapping subset names
        to paths (or opened file objects) of CSV files.  Internally, we save a
        dictionary where keys default to the basename of paths (list input).

    fieldnames : list, tuple
        An iterable over the field names (strings) to assign to each column in
        the CSV file.  It should have as many items as fields in each row of
        the CSV file(s).

    loader : object
        A function that receives as input, a context dictionary (with, at
        least, a "subset" key indicating which subset is being served), and a
        dictionary with ``{key: path}`` entries, and returns a dictionary with
        the loaded data.
    """

    def __init__(self, subsets, fieldnames, loader):
        if isinstance(subsets, dict):
            self._subsets = subsets
        else:
            # list input: key each file by its extension-less basename
            self._subsets = {}
            for path in subsets:
                stem = os.path.splitext(os.path.basename(path))[0]
                self._subsets[stem] = path
        self.fieldnames = fieldnames
        self._loader = loader

    def check(self, limit=0):
        """For each subset, check if all data can be correctly accessed.

        This function assumes each sample has a ``data`` and a ``key``
        attribute.  The ``key`` attribute should be a string, or representable
        as such.


        Parameters
        ----------

        limit : int
            Maximum number of samples to check (in each protocol/subset
            combination) in this dataset.  If set to zero, then check
            everything.


        Returns
        -------

        errors : int
            Number of errors found
        """

        logger.info("Checking dataset...")
        errors = 0
        for name in self._subsets:
            logger.info(f"Checking subset '{name}'...")
            samples = self.samples(name)
            if limit:
                logger.info(f"Checking at most first '{limit}' samples...")
                samples = samples[:limit]
            for pos, sample in enumerate(samples):
                try:
                    sample.data  # may trigger data loading
                    logger.info(f"{sample.key}: OK")
                except Exception as e:
                    logger.error(
                        f"Found error loading entry {pos} in subset {name} "
                        f"from file '{self._subsets[name]}': {e}"
                    )
                    errors += 1
        return errors

    def subsets(self):
        """Returns all available subsets at once.

        Returns
        -------

        subsets : dict
            A dictionary mapping subset names to lists of objects (respecting
            the ``key``, ``data`` interface).
        """

        return {name: self.samples(name) for name in self._subsets}

    def samples(self, subset):
        """Returns all samples in a subset.

        This method will load CSV information for a given subset and return
        all samples of the given subset after passing each entry through the
        loading function.


        Parameters
        ----------

        subset : str
            Name of the subset data to load


        Returns
        -------

        subset : list
            A lists of objects (respecting the ``key``, ``data`` interface).
        """

        source = self._subsets[subset]
        if isinstance(source, (str, bytes, pathlib.Path)):
            with open(source, newline="") as f:
                rows = list(csv.reader(f))
        else:
            # already-opened file object: parse, then rewind for future reads
            rows = list(csv.reader(source))
            source.seek(0)

        make_sample = self._loader
        names = self.fieldnames
        return [
            make_sample(dict(subset=subset, order=idx), dict(zip(names, row)))
            for idx, row in enumerate(rows)
        ]