1#!/usr/bin/env python
2# coding=utf-8
3
4import os
5import csv
6import copy
7import json
8import pathlib
9import numpy
10import torch
11from torch.autograd import Variable
12
13import logging
14
15logger = logging.getLogger(__name__)
16
17
class JSONDataset:
    """
    Generic multi-protocol/subset filelist dataset that yields samples

    To create a new dataset, you need to provide one or more JSON formatted
    filelists (one per protocol) with the following contents:

    .. code-block:: json

       {
           "subset1": [
               [
                   "value1",
                   "value2",
                   "value3"
               ],
               [
                   "value4",
                   "value5",
                   "value6"
               ]
           ],
           "subset2": [
           ]
       }

    Your dataset may contain any number of subsets, but all sample entries
    must contain the same number of fields.


    Parameters
    ----------

    protocols : list, dict
        Paths to one or more JSON formatted files containing the various
        protocols to be recognized by this dataset, or a dictionary, mapping
        protocol names to paths (or opened file objects) of CSV files.
        Internally, we save a dictionary where keys default to the basename of
        paths (list input).

    fieldnames : list, tuple
        An iterable over the field names (strings) to assign to each entry in
        the JSON file.  It should have as many items as fields in each entry of
        the JSON file.

    loader : object
        A function that receives as input, a context dictionary (with at least
        a "protocol" and "subset" keys indicating which protocol and subset are
        being served), and a dictionary with ``{fieldname: value}`` entries,
        and returns an object with at least 2 attributes:

        * ``key``: which must be a unique string for every sample across
          subsets in a protocol, and
        * ``data``: which contains the data associated with this sample

    """

    def __init__(self, protocols, fieldnames, loader):

        if isinstance(protocols, dict):
            self._protocols = protocols
        else:
            # list input: key each protocol by the basename of its path,
            # stripped of the extension
            self._protocols = dict(
                (os.path.splitext(os.path.basename(k))[0], k) for k in protocols
            )
        self.fieldnames = fieldnames
        self._loader = loader

    def check(self, limit=0):
        """For each protocol, check if all data can be correctly accessed

        This function assumes each sample has a ``data`` and a ``key``
        attribute.  The ``key`` attribute should be a string, or representable
        as such.


        Parameters
        ----------

        limit : int
            Maximum number of samples to check (in each protocol/subset
            combination) in this dataset.  If set to zero, then check
            everything.


        Returns
        -------

        errors : int
            Number of errors found

        """

        logger.info("Checking dataset...")
        errors = 0
        for proto in self._protocols:
            logger.info(f"Checking protocol '{proto}'...")
            for name, samples in self.subsets(proto).items():
                logger.info(f"Checking subset '{name}'...")
                if limit:
                    logger.info(f"Checking at most first '{limit}' samples...")
                    samples = samples[:limit]
                for pos, sample in enumerate(samples):
                    # NOTE: a second, identical ``except Exception`` clause
                    # existed here before; it was unreachable and was removed
                    try:
                        sample.data  # may trigger data loading
                        logger.info(f"{sample.key}: OK")
                    except Exception as e:
                        logger.error(
                            f"Found error loading entry {pos} in subset {name} "
                            f"of protocol {proto} from file "
                            f"'{self._protocols[proto]}': {e}"
                        )
                        errors += 1
        return errors

    def subsets(self, protocol):
        """Returns all subsets in a protocol

        This method will load JSON information for a given protocol and return
        all subsets of the given protocol after converting each entry through
        the loader function.

        Parameters
        ----------

        protocol : str
            Name of the protocol data to load


        Returns
        -------

        subsets : dict
            A dictionary mapping subset names to lists of objects (respecting
            the ``key``, ``data`` interface).

        """

        fileobj = self._protocols[protocol]
        if isinstance(fileobj, (str, bytes, pathlib.Path)):
            with open(fileobj, "r") as f:
                data = json.load(f)
        else:
            # BUGFIX: previously read from an undefined name ``f``; read from
            # the provided file object, then rewind it so repeated calls work
            data = json.load(fileobj)
            fileobj.seek(0)

        retval = {}
        for subset, samples in data.items():
            retval[subset] = [
                self._loader(
                    dict(protocol=protocol, subset=subset, order=n),
                    dict(zip(self.fieldnames, k)),
                )
                for n, k in enumerate(samples)
            ]

        return retval
178
179
class CSVDataset:
    """
    Generic multi-subset filelist dataset that yields samples

    To create a new dataset, you only need to provide a CSV formatted filelist
    using any separator (e.g. comma, space, semi-colon) with the following
    information:

    .. code-block:: text

       value1,value2,value3
       value4,value5,value6
       ...

    Notice that all rows must have the same number of entries.

    Parameters
    ----------

    subsets : list, dict
        Paths to one or more CSV formatted files containing the various subsets
        to be recognized by this dataset, or a dictionary, mapping subset names
        to paths (or opened file objects) of CSV files.  Internally, we save a
        dictionary where keys default to the basename of paths (list input).

    fieldnames : list, tuple
        An iterable over the field names (strings) to assign to each column in
        the CSV file.  It should have as many items as fields in each row of
        the CSV file(s).

    loader : object
        A function that receives as input, a context dictionary (with, at
        least, a "subset" key indicating which subset is being served), and a
        dictionary with ``{key: path}`` entries, and returns a dictionary with
        the loaded data.

    """

    def __init__(self, subsets, fieldnames, loader):

        if isinstance(subsets, dict):
            self._subsets = subsets
        else:
            # list input: key each subset by the basename of its path,
            # stripped of the extension
            self._subsets = dict(
                (os.path.splitext(os.path.basename(k))[0], k) for k in subsets
            )
        self.fieldnames = fieldnames
        self._loader = loader

    def check(self, limit=0):
        """For each subset, check if all data can be correctly accessed

        This function assumes each sample has a ``data`` and a ``key``
        attribute.  The ``key`` attribute should be a string, or representable
        as such.


        Parameters
        ----------

        limit : int
            Maximum number of samples to check (in each protocol/subset
            combination) in this dataset.  If set to zero, then check
            everything.


        Returns
        -------

        errors : int
            Number of errors found

        """

        logger.info("Checking dataset...")
        errors = 0
        for name in self._subsets.keys():
            logger.info(f"Checking subset '{name}'...")
            samples = self.samples(name)
            if limit:
                logger.info(f"Checking at most first '{limit}' samples...")
                samples = samples[:limit]
            for pos, sample in enumerate(samples):
                try:
                    sample.data  # may trigger data loading
                    logger.info(f"{sample.key}: OK")
                except Exception as e:
                    logger.error(
                        f"Found error loading entry {pos} in subset {name} "
                        f"from file '{self._subsets[name]}': {e}"
                    )
                    errors += 1
        return errors

    def subsets(self):
        """Returns all available subsets at once

        Returns
        -------

        subsets : dict
            A dictionary mapping subset names to lists of objects (respecting
            the ``key``, ``data`` interface).

        """

        return dict((k, self.samples(k)) for k in self._subsets.keys())

    def samples(self, subset):
        """Returns all samples in a subset

        This method will load CSV information for a given subset and return
        all samples of the given subset after passing each entry through the
        loading function.


        Parameters
        ----------

        subset : str
            Name of the subset data to load


        Returns
        -------

        subset : list
            A lists of objects (respecting the ``key``, ``data`` interface).

        """

        fileobj = self._subsets[subset]
        if isinstance(fileobj, (str, bytes, pathlib.Path)):
            with open(fileobj, newline="") as f:
                samples = list(csv.reader(f))
        else:
            samples = list(csv.reader(fileobj))
            # rewind the provided file object so repeated calls work
            fileobj.seek(0)

        return [
            self._loader(
                dict(subset=subset, order=n), dict(zip(self.fieldnames, k))
            )
            for n, k in enumerate(samples)
        ]