Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1637571489937/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.8/site-packages/bob/med/tb/data/dataset.py: 78%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

85 statements  

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4import os 

5import csv 

6import copy 

7import json 

8import pathlib 

9import numpy 

10import torch 

11from torch.autograd import Variable 

12 

13import logging 

14 

15logger = logging.getLogger(__name__) 

16 

17 

class JSONDataset:
    """
    Generic multi-protocol/subset filelist dataset that yields samples

    To create a new dataset, you need to provide one or more JSON formatted
    filelists (one per protocol) with the following contents:

    .. code-block:: json

       {
           "subset1": [
               [
                   "value1",
                   "value2",
                   "value3"
               ],
               [
                   "value4",
                   "value5",
                   "value6"
               ]
           ],
           "subset2": [
           ]
       }

    Your dataset may contain any number of subsets, but all sample entries
    must contain the same number of fields.


    Parameters
    ----------

    protocols : list, dict
        Paths to one or more JSON formatted files containing the various
        protocols to be recognized by this dataset, or a dictionary, mapping
        protocol names to paths (or opened file objects) of JSON files.
        Internally, we save a dictionary where keys default to the basename of
        paths (list input).

    fieldnames : list, tuple
        An iterable over the field names (strings) to assign to each entry in
        the JSON file.  It should have as many items as fields in each entry of
        the JSON file.

    loader : object
        A function that receives as input, a context dictionary (with at least
        a "protocol" and "subset" keys indicating which protocol and subset are
        being served), and a dictionary with ``{fieldname: value}`` entries,
        and returns an object with at least 2 attributes:

        * ``key``: which must be a unique string for every sample across
          subsets in a protocol, and
        * ``data``: which contains the data associated with this sample

    """

    def __init__(self, protocols, fieldnames, loader):

        if isinstance(protocols, dict):
            self._protocols = protocols
        else:
            # list input: key each path by its basename without extension
            self._protocols = dict(
                (os.path.splitext(os.path.basename(k))[0], k) for k in protocols
            )
        self.fieldnames = fieldnames
        self._loader = loader

    def check(self, limit=0):
        """For each protocol, check if all data can be correctly accessed

        This function assumes each sample has a ``data`` and a ``key``
        attribute.  The ``key`` attribute should be a string, or representable
        as such.


        Parameters
        ----------

        limit : int
            Maximum number of samples to check (in each protocol/subset
            combination) in this dataset.  If set to zero, then check
            everything.


        Returns
        -------

        errors : int
            Number of errors found

        """

        logger.info("Checking dataset...")
        errors = 0
        for proto in self._protocols:
            logger.info(f"Checking protocol '{proto}'...")
            for name, samples in self.subsets(proto).items():
                logger.info(f"Checking subset '{name}'...")
                if limit:
                    logger.info(f"Checking at most first '{limit}' samples...")
                    samples = samples[:limit]
                for pos, sample in enumerate(samples):
                    # NOTE: the original code carried a second, unreachable
                    # ``except Exception`` clause on this try; it was dead code
                    # and has been removed.
                    try:
                        sample.data  # may trigger data loading
                        logger.info(f"{sample.key}: OK")
                    except Exception as e:
                        logger.error(
                            f"Found error loading entry {pos} in subset {name} "
                            f"of protocol {proto} from file "
                            f"'{self._protocols[proto]}': {e}"
                        )
                        errors += 1
        return errors

    def subsets(self, protocol):
        """Returns all subsets in a protocol

        This method will load JSON information for a given protocol and return
        all subsets of the given protocol after converting each entry through
        the loader function.

        Parameters
        ----------

        protocol : str
            Name of the protocol data to load


        Returns
        -------

        subsets : dict
            A dictionary mapping subset names to lists of objects (respecting
            the ``key``, ``data`` interface).

        """

        fileobj = self._protocols[protocol]
        if isinstance(fileobj, (str, bytes, pathlib.Path)):
            with open(fileobj, "r") as f:
                data = json.load(f)
        else:
            # opened file object: read from it directly (the original code
            # erroneously referenced an undefined name ``f`` here), then
            # rewind so a later call can re-read it
            data = json.load(fileobj)
            fileobj.seek(0)

        retval = {}
        for subset, samples in data.items():
            retval[subset] = [
                self._loader(
                    dict(protocol=protocol, subset=subset, order=n),
                    dict(zip(self.fieldnames, k))
                )
                for n, k in enumerate(samples)
            ]

        return retval

178 

179 

class CSVDataset:
    """
    Generic multi-subset filelist dataset that yields samples

    To create a new dataset, you only need to provide a CSV formatted filelist
    using any separator (e.g. comma, space, semi-colon) with the following
    information:

    .. code-block:: text

       value1,value2,value3
       value4,value5,value6
       ...

    Notice that all rows must have the same number of entries.

    Parameters
    ----------

    subsets : list, dict
        Paths to one or more CSV formatted files containing the various subsets
        to be recognized by this dataset, or a dictionary, mapping subset names
        to paths (or opened file objects) of CSV files.  Internally, we save a
        dictionary where keys default to the basename of paths (list input).

    fieldnames : list, tuple
        An iterable over the field names (strings) to assign to each column in
        the CSV file.  It should have as many items as fields in each row of
        the CSV file(s).

    loader : object
        A function that receives as input, a context dictionary (with, at
        least, a "subset" key indicating which subset is being served), and a
        dictionary with ``{key: path}`` entries, and returns a dictionary with
        the loaded data.

    """

    def __init__(self, subsets, fieldnames, loader):

        if isinstance(subsets, dict):
            self._subsets = subsets
        else:
            # list input: key each path by its basename without extension
            self._subsets = {
                os.path.splitext(os.path.basename(path))[0]: path
                for path in subsets
            }
        self.fieldnames = fieldnames
        self._loader = loader

    def check(self, limit=0):
        """For each subset, check if all data can be correctly accessed

        This function assumes each sample has a ``data`` and a ``key``
        attribute.  The ``key`` attribute should be a string, or representable
        as such.


        Parameters
        ----------

        limit : int
            Maximum number of samples to check (in each protocol/subset
            combination) in this dataset.  If set to zero, then check
            everything.


        Returns
        -------

        errors : int
            Number of errors found

        """

        logger.info(f"Checking dataset...")
        errors = 0
        for name in self._subsets:
            logger.info(f"Checking subset '{name}'...")
            samples = self.samples(name)
            if limit:
                logger.info(f"Checking at most first '{limit}' samples...")
                samples = samples[:limit]
            for pos, sample in enumerate(samples):
                try:
                    sample.data  # force (possibly lazy) data loading
                    logger.info(f"{sample.key}: OK")
                except Exception as e:
                    logger.error(
                        f"Found error loading entry {pos} in subset {name} "
                        f"from file '{self._subsets[name]}': {e}"
                    )
                    errors += 1
        return errors

    def subsets(self):
        """Returns all available subsets at once

        Returns
        -------

        subsets : dict
            A dictionary mapping subset names to lists of objects (respecting
            the ``key``, ``data`` interface).

        """

        return {name: self.samples(name) for name in self._subsets}

    def samples(self, subset):
        """Returns all samples in a subset

        This method will load CSV information for a given subset and return
        all samples of the given subset after passing each entry through the
        loading function.


        Parameters
        ----------

        subset : str
            Name of the subset data to load


        Returns
        -------

        subset : list
            A lists of objects (respecting the ``key``, ``data`` interface).

        """

        source = self._subsets[subset]
        if isinstance(source, (str, bytes, pathlib.Path)):
            with open(source, newline="") as f:
                rows = list(csv.reader(f))
        else:
            # already-opened file object: read all rows, then rewind so a
            # later call can re-read it
            rows = list(csv.reader(source))
            source.seek(0)

        return [
            self._loader(
                dict(subset=subset, order=order),
                dict(zip(self.fieldnames, row)),
            )
            for order, row in enumerate(rows)
        ]