Coverage for src/deepdraw/data/dataset.py: 79%
81 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5import csv
6import json
7import logging
8import os
9import pathlib
# Module-level logger named after this module, so applications can filter
# messages from this dataset machinery via the standard logging hierarchy.
logger = logging.getLogger(__name__)
class JSONDataset:
    """Generic multi-protocol/subset filelist dataset that yields samples.

    To create a new dataset, you need to provide one or more JSON formatted
    filelists (one per protocol) with the following contents:

    .. code-block:: json

       {
           "subset1": [
               [
                   "value1",
                   "value2",
                   "value3"
               ],
               [
                   "value4",
                   "value5",
                   "value6"
               ]
           ],
           "subset2": [
           ]
       }

    Your dataset may contain any number of subsets, but all sample entries
    must contain the same number of fields.


    Parameters
    ----------

    protocols : list, dict
        Paths to one or more JSON formatted files containing the various
        protocols to be recognized by this dataset, or a dictionary, mapping
        protocol names to paths (or opened file objects) of JSON files.
        Internally, we save a dictionary where keys default to the basename of
        paths (list input).

    fieldnames : list, tuple
        An iterable over the field names (strings) to assign to each entry in
        the JSON file.  It should have as many items as fields in each entry of
        the JSON file.

    loader : object
        A function that receives as input, a context dictionary (with at least
        a "protocol" and "subset" keys indicating which protocol and subset are
        being served), and a dictionary with ``{fieldname: value}`` entries,
        and returns an object with at least 2 attributes:

        * ``key``: which must be a unique string for every sample across
          subsets in a protocol, and
        * ``data``: which contains the data associated with this sample
    """

    def __init__(self, protocols, fieldnames, loader):
        if isinstance(protocols, dict):
            self._protocols = protocols
        else:
            # list input: key each protocol by its basename without extension
            self._protocols = {
                os.path.splitext(os.path.basename(k))[0]: k for k in protocols
            }
        self.fieldnames = fieldnames
        self._loader = loader

    def check(self, limit=0):
        """For each protocol, check if all data can be correctly accessed.

        This function assumes each sample has a ``data`` and a ``key``
        attribute.  The ``key`` attribute should be a string, or representable
        as such.


        Parameters
        ----------

        limit : int
            Maximum number of samples to check (in each protocol/subset
            combination) in this dataset.  If set to zero, then check
            everything.


        Returns
        -------

        errors : int
            Number of errors found
        """
        logger.info("Checking dataset...")
        errors = 0
        for proto in self._protocols:
            logger.info(f"Checking protocol '{proto}'...")
            for name, samples in self.subsets(proto).items():
                logger.info(f"Checking subset '{name}'...")
                if limit:
                    logger.info(f"Checking at most first '{limit}' samples...")
                    samples = samples[:limit]
                for pos, sample in enumerate(samples):
                    try:
                        sample.data  # may trigger data loading
                        logger.info(f"{sample.key}: OK")
                    except Exception as e:
                        # Note: a single except clause suffices here; a second
                        # ``except Exception`` on the same try (as previously
                        # present) could never be reached.
                        logger.error(
                            f"Found error loading entry {pos} in subset {name} "
                            f"of protocol {proto} from file "
                            f"'{self._protocols[proto]}': {e}"
                        )
                        errors += 1
        return errors

    def subsets(self, protocol):
        """Returns all subsets in a protocol.

        This method will load JSON information for a given protocol and return
        all subsets of the given protocol after converting each entry through
        the loader function.

        Parameters
        ----------

        protocol : str
            Name of the protocol data to load


        Returns
        -------

        subsets : dict
            A dictionary mapping subset names to lists of objects (respecting
            the ``key``, ``data`` interface).
        """
        fileobj = self._protocols[protocol]
        if isinstance(fileobj, (str, bytes, pathlib.Path)):
            with open(fileobj) as f:
                data = json.load(f)
        else:
            # already-opened file object: read it, then rewind so that a
            # subsequent call to subsets() can re-read the same object
            # (fix: previously read from undefined name ``f`` here)
            data = json.load(fileobj)
            fileobj.seek(0)

        retval = {}
        for subset, samples in data.items():
            retval[subset] = [
                self._loader(
                    dict(protocol=protocol, subset=subset, order=n),
                    dict(zip(self.fieldnames, k)),
                )
                for n, k in enumerate(samples)
            ]

        return retval
class CSVDataset:
    """Generic multi-subset filelist dataset that yields samples.

    To create a new dataset, you only need to provide a CSV formatted filelist
    using any separator (e.g. comma, space, semi-colon) with the following
    information:

    .. code-block:: text

       value1,value2,value3
       value4,value5,value6
       ...

    Notice that all rows must have the same number of entries.

    Parameters
    ----------

    subsets : list, dict
        Paths to one or more CSV formatted files containing the various subsets
        to be recognized by this dataset, or a dictionary, mapping subset names
        to paths (or opened file objects) of CSV files.  Internally, we save a
        dictionary where keys default to the basename of paths (list input).

    fieldnames : list, tuple
        An iterable over the field names (strings) to assign to each column in
        the CSV file.  It should have as many items as fields in each row of
        the CSV file(s).

    loader : object
        A function that receives as input, a context dictionary (with, at
        least, a "subset" key indicating which subset is being served), and a
        dictionary with ``{key: path}`` entries, and returns a dictionary with
        the loaded data.
    """

    def __init__(self, subsets, fieldnames, loader):
        if isinstance(subsets, dict):
            self._subsets = subsets
        else:
            # list input: derive each key from the file's basename, extension
            # stripped
            self._subsets = dict(
                (os.path.splitext(os.path.basename(p))[0], p) for p in subsets
            )
        self.fieldnames = fieldnames
        self._loader = loader

    def check(self, limit=0):
        """For each subset, check if all data can be correctly accessed.

        This function assumes each sample has a ``data`` and a ``key``
        attribute.  The ``key`` attribute should be a string, or representable
        as such.


        Parameters
        ----------

        limit : int
            Maximum number of samples to check (in each protocol/subset
            combination) in this dataset.  If set to zero, then check
            everything.


        Returns
        -------

        errors : int
            Number of errors found
        """
        logger.info("Checking dataset...")
        errors = 0
        for name in self._subsets:
            logger.info(f"Checking subset '{name}'...")
            entries = self.samples(name)
            if limit:
                logger.info(f"Checking at most first '{limit}' samples...")
                entries = entries[:limit]
            for pos, sample in enumerate(entries):
                try:
                    sample.data  # may trigger data loading
                    logger.info(f"{sample.key}: OK")
                except Exception as e:
                    logger.error(
                        f"Found error loading entry {pos} in subset {name} "
                        f"from file '{self._subsets[name]}': {e}"
                    )
                    errors += 1
        return errors

    def subsets(self):
        """Returns all available subsets at once.

        Returns
        -------

        subsets : dict
            A dictionary mapping subset names to lists of objects (respecting
            the ``key``, ``data`` interface).
        """
        return {name: self.samples(name) for name in self._subsets}

    def samples(self, subset):
        """Returns all samples in a subset.

        This method will load CSV information for a given subset and return
        all samples of the given subset after passing each entry through the
        loading function.


        Parameters
        ----------

        subset : str
            Name of the subset data to load


        Returns
        -------

        subset : list
            A lists of objects (respecting the ``key``, ``data`` interface).
        """
        source = self._subsets[subset]
        if isinstance(source, (str, bytes, pathlib.Path)):
            with open(source, newline="") as f:
                rows = list(csv.reader(f))
        else:
            # already-opened file object: read it all, then rewind so a later
            # call can re-read the same object
            rows = list(csv.reader(source))
            source.seek(0)

        names = self.fieldnames
        return [
            self._loader(dict(subset=subset, order=i), dict(zip(names, row)))
            for i, row in enumerate(rows)
        ]