1#!/usr/bin/env python
2# coding=utf-8
3
4import csv
5import json
6import logging
7import os
8import pathlib
9
# Module-level logger used by the dataset ``check`` methods to report
# per-sample progress and loading errors.
logger = logging.getLogger(__name__)
11
12
class JSONDataset:
    """
    Generic multi-protocol/subset filelist dataset that yields samples

    To create a new dataset, you need to provide one or more JSON formatted
    filelists (one per protocol) with the following contents:

    .. code-block:: json

       {
           "subset1": [
               [
                   "value1",
                   "value2",
                   "value3"
               ],
               [
                   "value4",
                   "value5",
                   "value6"
               ]
           ],
           "subset2": [
           ]
       }

    Your dataset may contain any number of subsets, but all sample entries
    must contain the same number of fields.


    Parameters
    ----------

    protocols : list, dict
        Paths to one or more JSON formatted files containing the various
        protocols to be recognized by this dataset, or a dictionary, mapping
        protocol names to paths (or opened file objects) of JSON files.
        Internally, we save a dictionary where keys default to the basename of
        paths, without their extension (list input).

    fieldnames : list, tuple
        An iterable over the field names (strings) to assign to each entry in
        the JSON file. It should have as many items as fields in each entry of
        the JSON file.

    loader : object
        A function that receives as input, a context dictionary (with at least
        "protocol", "subset" and "order" keys indicating which protocol and
        subset are being served, and the position of the entry within that
        subset), and a dictionary with ``{fieldname: value}`` entries, and
        returns an object with at least 2 attributes:

        * ``key``: which must be a unique string for every sample across
          subsets in a protocol, and
        * ``data``: which contains the data associated with this sample

    """

    def __init__(self, protocols, fieldnames, loader):
        if isinstance(protocols, dict):
            self._protocols = protocols
        else:
            # Name each protocol after its file: no directory, no extension.
            self._protocols = {
                os.path.splitext(os.path.basename(path))[0]: path
                for path in protocols
            }
        self.fieldnames = fieldnames
        self._loader = loader

    def check(self, limit=0):
        """For each protocol, check if all data can be correctly accessed

        Every subset of every protocol is materialized through
        :py:meth:`subsets` and the ``data`` attribute of each sample is
        accessed (which may trigger data loading).  Per-sample failures are
        logged and counted instead of being raised.

        This function assumes each sample has a ``data`` and a ``key``
        attribute. The ``key`` attribute should be a string, or representable
        as such.


        Parameters
        ----------

        limit : int
            Maximum number of samples to check (in each protocol/subset
            combination) in this dataset. If set to zero, then check
            everything.


        Returns
        -------

        errors : int
            Number of errors found

        """

        logger.info("Checking dataset...")
        errors = 0
        for proto in self._protocols:
            logger.info("Checking protocol '%s'...", proto)
            for name, samples in self.subsets(proto).items():
                logger.info("Checking subset '%s'...", name)
                if limit:
                    logger.info("Checking at most first '%s' samples...", limit)
                    samples = samples[:limit]
                for pos, sample in enumerate(samples):
                    try:
                        sample.data  # may trigger data loading
                        logger.info("%s: OK", sample.key)
                    except Exception as e:
                        # Deliberately broad: this is a diagnostic boundary,
                        # any loader failure is reported and counted.
                        logger.error(
                            "Found error loading entry %s in subset %s "
                            "of protocol %s from file '%s': %s",
                            pos,
                            name,
                            proto,
                            self._protocols[proto],
                            e,
                        )
                        errors += 1
        return errors

    def subsets(self, protocol):
        """Returns all subsets in a protocol

        This method will load JSON information for a given protocol and return
        all subsets of the given protocol after converting each entry through
        the loader function.

        Parameters
        ----------

        protocol : str
            Name of the protocol data to load


        Returns
        -------

        subsets : dict
            A dictionary mapping subset names to lists of objects (respecting
            the ``key``, ``data`` interface).

        """

        fileobj = self._protocols[protocol]
        if isinstance(fileobj, (str, bytes, os.PathLike)):
            # RFC 8259: JSON exchanged between systems is UTF-8 encoded.
            with open(fileobj, "r", encoding="utf-8") as f:
                data = json.load(f)
        else:
            # Already-opened file object: parse it, then rewind so the same
            # protocol can be loaded again later (e.g. by check()).
            data = json.load(fileobj)
            fileobj.seek(0)

        return {
            subset: [
                self._loader(
                    dict(protocol=protocol, subset=subset, order=n),
                    dict(zip(self.fieldnames, entry)),
                )
                for n, entry in enumerate(samples)
            ]
            for subset, samples in data.items()
        }
173
174
class CSVDataset:
    """
    Generic multi-subset filelist dataset that yields samples

    To create a new dataset, you only need to provide a CSV formatted filelist
    with the following information:

    .. code-block:: text

       value1,value2,value3
       value4,value5,value6
       ...

    Fields are separated by commas by default; pass ``delimiter`` to use
    another single-character separator (e.g. space or semi-colon).  Notice
    that all rows must have the same number of entries.  Completely blank
    lines are ignored.

    Parameters
    ----------

    subsets : list, dict
        Paths to one or more CSV formatted files containing the various subsets
        to be recognized by this dataset, or a dictionary, mapping subset names
        to paths (or opened file objects) of CSV files. Internally, we save a
        dictionary where keys default to the basename of paths, without their
        extension (list input).

    fieldnames : list, tuple
        An iterable over the field names (strings) to assign to each column in
        the CSV file. It should have as many items as fields in each row of
        the CSV file(s).

    loader : object
        A function that receives as input, a context dictionary (with, at
        least, "subset" and "order" keys indicating which subset is being
        served and the position of the row within it), and a dictionary with
        ``{fieldname: value}`` entries, and returns an object with at least
        2 attributes: ``key`` (a unique string per sample) and ``data`` (the
        data associated with the sample).

    delimiter : str
        Single character separating fields within a row.  Defaults to ``","``.

    """

    # Class-level default so subclasses that bypass ``__init__`` keep the
    # original comma-separated behaviour.
    _delimiter = ","

    def __init__(self, subsets, fieldnames, loader, delimiter=","):
        if isinstance(subsets, dict):
            self._subsets = subsets
        else:
            # Name each subset after its file: no directory, no extension.
            self._subsets = {
                os.path.splitext(os.path.basename(path))[0]: path
                for path in subsets
            }
        self.fieldnames = fieldnames
        self._loader = loader
        self._delimiter = delimiter

    def check(self, limit=0):
        """For each subset, check if all data can be correctly accessed

        Every subset is materialized through :py:meth:`samples` and the
        ``data`` attribute of each sample is accessed (which may trigger data
        loading).  Per-sample failures are logged and counted instead of
        being raised.

        This function assumes each sample has a ``data`` and a ``key``
        attribute. The ``key`` attribute should be a string, or representable
        as such.


        Parameters
        ----------

        limit : int
            Maximum number of samples to check (in each subset) in this
            dataset. If set to zero, then check everything.


        Returns
        -------

        errors : int
            Number of errors found

        """

        logger.info("Checking dataset...")
        errors = 0
        for name in self._subsets:
            logger.info("Checking subset '%s'...", name)
            samples = self.samples(name)
            if limit:
                logger.info("Checking at most first '%s' samples...", limit)
                samples = samples[:limit]
            for pos, sample in enumerate(samples):
                try:
                    sample.data  # may trigger data loading
                    logger.info("%s: OK", sample.key)
                except Exception as e:
                    # Deliberately broad: this is a diagnostic boundary, any
                    # loader failure is reported and counted.
                    logger.error(
                        "Found error loading entry %s in subset %s "
                        "from file '%s': %s",
                        pos,
                        name,
                        self._subsets[name],
                        e,
                    )
                    errors += 1
        return errors

    def subsets(self):
        """Returns all available subsets at once

        Returns
        -------

        subsets : dict
            A dictionary mapping subset names to lists of objects (respecting
            the ``key``, ``data`` interface).

        """

        return {name: self.samples(name) for name in self._subsets}

    def samples(self, subset):
        """Returns all samples in a subset

        This method will load CSV information for a given subset and return
        all samples of the given subset after passing each entry through the
        loading function.  Blank lines in the CSV input are skipped, so
        ``order`` counts only real rows.


        Parameters
        ----------

        subset : str
            Name of the subset data to load


        Returns
        -------

        subset : list
            A lists of objects (respecting the ``key``, ``data`` interface).

        """

        fileobj = self._subsets[subset]
        if isinstance(fileobj, (str, bytes, os.PathLike)):
            # newline="" is required by the csv module for correct parsing.
            with open(fileobj, newline="") as f:
                rows = list(csv.reader(f, delimiter=self._delimiter))
        else:
            rows = list(csv.reader(fileobj, delimiter=self._delimiter))
            # Rewind so the same subset can be read again later.
            fileobj.seek(0)

        # csv.reader yields [] for blank lines; they carry no sample and would
        # otherwise reach the loader as an empty field dictionary.
        rows = [row for row in rows if row]

        return [
            self._loader(
                dict(subset=subset, order=n), dict(zip(self.fieldnames, row))
            )
            for n, row in enumerate(rows)
        ]