Coverage for src/bob/bio/base/database/csv_database.py: 92%
190 statements
coverage.py v7.6.5, created at 2024-11-14 21:41 +0100
import csv
import functools
import logging
import os

from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, Iterable, Optional, TextIO

import sklearn.pipeline
import sklearn.preprocessing

from sklearn.base import BaseEstimator, TransformerMixin

from bob.bio.base.pipelines.abstract_classes import Database
from bob.pipelines import (
    DelayedSample,
    FileListDatabase,
    Sample,
    SampleSet,
    check_parameters_for_validity,
)
from bob.pipelines.dataset import open_definition_file

from ..utils.annotations import read_annotation_file

logger = logging.getLogger(__name__)


def _sample_sets_to_samples(sample_sets):
    return functools.reduce(
        lambda x, y: x + y, (s.samples for s in sample_sets), []
    )


def _add_key(samples: list[Sample]) -> list[Sample]:
    """Adds a ``key`` attribute to all samples if ``key`` is not present.

    Will use ``path`` to create a unique ``key``.
    Note that this won't create unique keys if the same path appears in multiple
    samples. That would be problematic, as keys are expected to be unique.
    """

    out = []
    for sample in samples:
        if isinstance(sample, SampleSet):
            out.append(SampleSet(samples=_add_key(sample), parent=sample))
            continue
        if not hasattr(sample, "key"):
            if hasattr(sample, "path"):
                sample.key = sample.path
            else:
                raise ValueError(
                    f"Sample has no 'key' and no 'path' to infer it. {sample=}"
                )
        out.append(sample)
    return out
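

# --- Hedged usage sketch (added for illustration; not part of the original module) ---
# Shows how ``_add_key`` backfills ``key`` from ``path``; the sample values are
# invented. Wrapped in a function so nothing runs at import time.
def _example_add_key():  # pragma: no cover
    sample = Sample(None, path="subject1_image1.png", subject_id="1")
    (with_key,) = _add_key([sample])
    assert with_key.key == "subject1_image1.png"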


def validate_bio_samples(samples):
    """Validates Samples or SampleSets for backwards compatibility reasons.

    This will add a ``key`` attribute (if not already present) to each sample, copied
    from the path.
    """

    for sample in samples:
        if isinstance(sample, SampleSet):
            validate_bio_samples(sample.samples)
            if not hasattr(sample, "template_id"):
                raise ValueError(
                    f"SampleSet must have a template_id attribute, got {sample}"
                )
            if not hasattr(sample, "subject_id"):
                raise ValueError(
                    f"SampleSet must have a subject_id attribute, got {sample}"
                )
            continue

        if not hasattr(sample, "key"):
            if hasattr(sample, "path"):
                sample.key = sample.path
            else:
                raise ValueError(
                    f"Sample must have a key or a path attribute, got {sample}"
                )

        if not hasattr(sample, "subject_id"):
            raise ValueError(
                f"Sample must have a subject_id attribute, got {sample}"
            )
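

# --- Hedged usage sketch (added for illustration; not part of the original module) ---
# Shows the attributes ``validate_bio_samples`` expects: ``template_id`` and
# ``subject_id`` on each SampleSet, and a ``key`` or ``path`` plus ``subject_id``
# on each sample. The identifiers below are invented.
def _example_validate_bio_samples():  # pragma: no cover
    probe = Sample(None, path="subject5_image1.png", subject_id="5")
    template = SampleSet([probe], template_id="template_5", subject_id="5")
    validate_bio_samples([template])  # raises ValueError if an attribute is missing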


class CSVDatabase(FileListDatabase, Database):
    """A CSV file database.

    The database protocol files should provide the following files:

    .. code-block:: text

        dataset_protocols_path/
        dataset_protocols_path/my_protocol/train/for_background_model.csv
        dataset_protocols_path/my_protocol/train/for_znorm.csv
        dataset_protocols_path/my_protocol/train/for_tnorm.csv
        dataset_protocols_path/my_protocol/dev/for_enrolling.csv
        dataset_protocols_path/my_protocol/dev/for_probing.csv
        dataset_protocols_path/my_protocol/dev/for_matching.csv
        dataset_protocols_path/my_protocol/eval/for_enrolling.csv
        dataset_protocols_path/my_protocol/eval/for_probing.csv
        dataset_protocols_path/my_protocol/eval/for_matching.csv ...

    The ``for_background_model`` file should contain the following columns::

        key,subject_id
        subject1_image1.png,1
        subject1_image2.png,1
        subject2_image1.png,2
        subject2_image2.png,2

    In all the CSV files, you can have a column called ``path`` which will be
    used as ``key`` if the ``key`` is not specified. For example::

        path,subject_id
        subject1_image1.png,1
        subject1_image2.png,1
        subject2_image1.png,2
        subject2_image2.png,2

    or::

        path,subject_id,key
        subject1_audio1.wav,1,subject1_audio1_channel1
        subject1_audio1.wav,1,subject1_audio1_channel2
        subject1_audio2.wav,1,subject1_audio2_channel1
        subject1_audio2.wav,1,subject1_audio2_channel2

    The ``key`` column will be used to checkpoint each sample into a unique file and
    must therefore be unique across the whole dataset.

    The ``for_enrolling.csv`` file should contain the following columns::

        key,subject_id,template_id
        subject3_image1.png,3,template_1
        subject3_image2.png,3,template_1
        subject3_image3.png,3,template_2
        subject3_image4.png,3,template_2
        subject4_image1.png,4,template_3
        subject4_image2.png,4,template_3
        subject4_image3.png,4,template_4
        subject4_image4.png,4,template_4

    The ``for_probing.csv`` file should contain the following columns::

        key,subject_id,template_id
        subject5_image1.png,5,template_5
        subject5_image2.png,5,template_5
        subject5_image3.png,5,template_6
        subject5_image4.png,5,template_6
        subject6_image1.png,6,template_7
        subject6_image2.png,6,template_7
        subject6_image3.png,6,template_8
        subject6_image4.png,6,template_8

    The subject identity (``subject_id``) is a unique identifier for one identity (one
    person). The template identity (``template_id``) is an identifier used to group
    samples when they need to be enrolled or scored together.
    :class:`~bob.bio.base.pipelines.BioAlgorithm` will process these templates.

    By default, each enroll ``template_id`` will be compared against each
    probe ``template_id`` to produce one score per pair. If you want to specify exact
    comparisons (sparse scoring), you can add a ``for_matching.csv`` file with the
    following columns::

        enroll_template_id,probe_template_id
        template_1,template_5
        template_2,template_6
        template_3,template_5
        template_3,template_7
        template_4,template_5
        template_4,template_8

    ``for_znorm.csv`` and ``for_tnorm.csv`` files are optional and are used for score
    normalization. See :class:`~bob.bio.base.pipelines.PipelineScoreNorm`.
    ``for_znorm.csv`` has the same format as ``for_probing.csv`` and
    ``for_tnorm.csv`` has the same format as ``for_enrolling.csv``.
    """

    def __init__(
        self,
        *,
        name: str,
        protocol: str,
        dataset_protocols_path: Optional[str] = None,
        transformer: Optional[sklearn.pipeline.Pipeline] = None,
        templates_metadata: Optional[list[str]] = None,
        annotation_type: Optional[str] = None,
        fixed_positions: Optional[dict[str, tuple[float, float]]] = None,
        memory_demanding=False,
        **kwargs,
    ):
        """
        Parameters
        ----------
        name
            The name of the database.
        protocol
            Name of the protocol folder in the CSV definition structure.
        dataset_protocols_path
            Path to the CSV files structure (see :ref:`bob.bio.base.database_interface`
            for more info).
        transformer
            An sklearn pipeline or equivalent transformer that handles some light
            preprocessing of the samples (this will always run locally).
        templates_metadata
            Metadata that originates from the samples and must be present in the
            templates (SampleSet), e.g. ``["gender", "age"]``. This should be metadata
            that is common to all the samples in a template.
        annotation_type
            A string describing the annotations passed to the annotation loading
            function.
        fixed_positions
            TODO Why is it here? What does it do exactly?
            --> move it when the FaceCrop annotator is implemented correctly.
        memory_demanding
            Flag that indicates that experiments using this database should not run on
            low-memory workers.
        """
        if not hasattr(self, "name"):
            self.name = name
        transformer = sklearn.pipeline.make_pipeline(
            sklearn.preprocessing.FunctionTransformer(_add_key), transformer
        )
        super().__init__(
            name=name,  # For FileListDatabase
            protocol=protocol,
            dataset_protocols_path=dataset_protocols_path,
            transformer=transformer,
            annotation_type=annotation_type,
            fixed_positions=fixed_positions,
            memory_demanding=memory_demanding,
            **kwargs,
        )
        if self.list_file("dev", "for_matching") is None:
            self.score_all_vs_all = True
        else:
            self.score_all_vs_all = False

        self.templates_metadata = []
        if templates_metadata is not None:
            self.templates_metadata = templates_metadata

    def list_file(self, group: str, name: str) -> Optional[TextIO]:
        """Returns a definition file containing one sample per row.

        Overloads the ``bob.pipelines`` ``list_file`` as the group is a directory here.
        """

        try:
            list_file = open_definition_file(
                search_pattern=Path(group) / (name + ".csv"),
                database_name=self.name,
                protocol=self.protocol,
                database_filename=self.dataset_protocols_path.name,
                base_dir=self.dataset_protocols_path.parent,
                subdir=".",
            )
            return list_file
        except FileNotFoundError:
            return None

    def get_reader(self, group: str, name: str) -> Iterable:
        """Returns an :any:`Iterable` containing :class:`Sample` or :class:`SampleSet` objects."""
        key = (self.protocol, group, name)
        if key not in self.readers:
            list_file = self.list_file(group, name)
            self.readers[key] = None
            if list_file is not None:
                self.readers[key] = self.reader_cls(
                    list_file=list_file, transformer=self.transformer
                )

        reader = self.readers[key]
        return reader

    # cached methods should be based on protocol as well
    @functools.lru_cache(maxsize=None)
    def _background_model_samples(self, protocol):
        reader = self.get_reader("train", "for_background_model")
        if reader is None:
            return []
        samples = list(reader)
        validate_bio_samples(samples)
        return samples

    def background_model_samples(self):
        return self._background_model_samples(self.protocol)

    def _sample_sets(self, group, name):
        # not cached here; the public methods cache their results per protocol
        reader = self.get_reader(group, name)
        if reader is None:
            return []
        # create SampleSets from samples grouped by their unique template_id
        samples_grouped_by_template_id = defaultdict(list)
        for sample in reader:
            samples_grouped_by_template_id[sample.template_id].append(sample)
        sample_sets = []
        for (
            template_id,
            samples_for_template_id,
        ) in samples_grouped_by_template_id.items():
            # since all samples inside one sample set have the same subject_id,
            # we add that as well.
            samples = list(samples_for_template_id)
            subject_id = samples[0].subject_id
            metadata = {
                m: getattr(samples[0], m) for m in self.templates_metadata
            }
            sample_sets.append(
                SampleSet(
                    samples,
                    template_id=template_id,
                    subject_id=subject_id,
                    key=f"template_{template_id}",
                    **metadata,
                )
            )
        validate_bio_samples(sample_sets)
        return sample_sets

    # cached methods should be based on protocol as well
    @functools.lru_cache(maxsize=None)
    def _references(self, protocol, group):  # TODO: protocol
        return self._sample_sets(group, "for_enrolling")

    def references(self, group="dev"):
        return self._references(self.protocol, group)

    def _add_all_references(self, sample_sets, group):
        references = [s.template_id for s in self.references(group)]
        for sample_set in sample_sets:
            sample_set.references = references

    # cached methods should be based on protocol as well
    @functools.lru_cache(maxsize=None)
    def _probes(self, protocol, group):  # TODO: protocol
        sample_sets = self._sample_sets(group, "for_probing")

        # if there are no probes
        if not sample_sets:
            return sample_sets

        # populate .references for each sample set
        matching_file = self.list_file(group, "for_matching")
        if matching_file is None:
            self._add_all_references(sample_sets, group)

        # read the matching file
        else:
            # references is a dict where key is probe_template_id and value is a
            # list of enroll_template_ids
            references = defaultdict(list)
            reader = csv.DictReader(matching_file)
            for row in reader:
                references[row["probe_template_id"]].append(
                    row["enroll_template_id"]
                )

            for sample_set in sample_sets:
                sample_set.references = references[sample_set.template_id]

        return sample_sets

    def probes(self, group="dev"):
        return self._probes(self.protocol, group)

    def samples(self, groups=None):
        """Get samples of a certain group.

        Parameters
        ----------
        groups : :obj:`str`, optional
            A str or list of str to be used for filtering samples, by default None.

        Returns
        -------
        list
            A list containing the samples loaded from the CSV files.
        """
        groups = check_parameters_for_validity(
            groups, "groups", self.groups(), self.groups()
        )
        all_samples = []

        if "train" in groups:
            all_samples.extend(self.background_model_samples())
            groups.remove("train")

        for grp in groups:
            all_samples.extend(_sample_sets_to_samples(self.references(grp)))
            all_samples.extend(_sample_sets_to_samples(self.probes(grp)))

        # Add znorm samples. Returning znorm samples for one group of dev or
        # eval is enough because they are duplicated.
        for grp in groups:
            all_samples.extend(_sample_sets_to_samples(self.zprobes(grp)))
            break

        # Add tnorm samples.
        all_samples.extend(_sample_sets_to_samples(self.treferences()))

        return all_samples

    def all_samples(self, groups=None):
        return self.samples(groups)

    @functools.lru_cache(maxsize=None)
    def _zprobes(self, protocol, group):
        sample_sets = self._sample_sets("train", "for_znorm")
        if not sample_sets:
            return sample_sets

        self._add_all_references(sample_sets, group)

        return sample_sets

    def zprobes(self, group="dev", proportion=1.0):
        sample_sets = self._zprobes(self.protocol, group)
        if not sample_sets:
            return sample_sets

        sample_sets = sample_sets[: int(len(sample_sets) * proportion)]
        return sample_sets

    @functools.lru_cache(maxsize=None)
    def _treferences(self, protocol):
        sample_sets = self._sample_sets("train", "for_tnorm")
        return sample_sets

    def treferences(self, proportion=1.0):
        sample_sets = self._treferences(self.protocol)
        if not sample_sets:
            return sample_sets

        sample_sets = sample_sets[: int(len(sample_sets) * proportion)]
        return sample_sets
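

# --- Hedged usage sketch (added for illustration; not part of the original module) ---
# A minimal sketch of how a CSVDatabase is typically instantiated. The database
# name, protocol, paths and the ``load_raw`` callable are all invented; any
# callable that maps a file path to data would do for ``data_loader``.
def _example_csv_database():  # pragma: no cover
    def load_raw(path):
        with open(path, "rb") as f:
            return f.read()

    database = CSVDatabase(
        name="my_dataset",
        protocol="my_protocol",
        dataset_protocols_path="/path/to/dataset_protocols_path",
        transformer=FileSampleLoader(
            data_loader=load_raw,
            dataset_original_directory="/path/to/raw/data",
            extension=".png",
        ),
        templates_metadata=["gender"],
    )
    # enroll and probe templates of the development group
    return database.references(group="dev"), database.probes(group="dev")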


class FileSampleLoader(BaseEstimator, TransformerMixin):
    """Loads file-based samples into :class:`~bob.pipelines.DelayedSample` objects.

    Given the :attr:`sample.path` attribute, ``dataset_original_directory`` and an
    ``extension``, this transformer lazily loads the samples from file; the
    ``data_loader`` callable is used to load the data.

    The resulting :class:`~bob.pipelines.DelayedSample` objects will call
    ``data_loader`` when their :attr:`data` is accessed.

    This transformer itself does not access the data files.

    Parameters
    ----------
    data_loader
        A callable to load the sample, given the full path to the file.
    dataset_original_directory
        Path where the raw data files are stored. This will be prepended to the
        ``path`` attribute of the samples.
    extension
        File extension of the raw data files. This will be appended to the ``path``
        attribute of the samples.
    """

    def __init__(
        self,
        data_loader: Callable[[str], Any],
        dataset_original_directory: str = "",
        extension: str = "",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.data_loader = data_loader
        self.dataset_original_directory = dataset_original_directory
        self.extension = extension

    def transform(self, samples: list[Sample]) -> list[DelayedSample]:
        """Prepares the data into :class:`~bob.pipelines.DelayedSample` objects.

        Transforms :class:`~bob.pipelines.Sample` objects with a ``path`` attribute to
        :class:`~bob.pipelines.DelayedSample` objects with data ready to be loaded
        (lazily) by :attr:`data_loader`.

        When needed (access to the :class:`DelayedSample`\\ 's :attr:`data` attribute),
        :attr:`data_loader` will be called with the path (extended with
        :attr:`dataset_original_directory` and :attr:`extension`) as argument.

        Parameters
        ----------
        samples
            :class:`~bob.pipelines.Sample` objects with their ``path`` attribute
            containing a path to a file to load.
        """
        output = []
        for sample in samples:
            path = getattr(sample, "path")
            delayed_sample = DelayedSample(
                functools.partial(
                    self.data_loader,
                    os.path.join(
                        # we prepend ./ to the path to make sure that it is
                        # treated as relative to the dataset_original_directory
                        self.dataset_original_directory,
                        f"./{path + self.extension}",
                    ),
                ),
                parent=sample,
            )
            output.append(delayed_sample)
        return output

    def _more_tags(self):
        return {"requires_fit": False}
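

# --- Hedged usage sketch (added for illustration; not part of the original module) ---
# FileSampleLoader only wraps samples into DelayedSample objects; nothing is read
# from disk until ``.data`` is accessed. ``numpy.load`` is one possible
# ``data_loader`` and the paths below are invented.
def _example_file_sample_loader():  # pragma: no cover
    import numpy

    loader = FileSampleLoader(
        data_loader=numpy.load,
        dataset_original_directory="/path/to/raw/data",
        extension=".npy",
    )
    sample = Sample(None, path="subject1_image1", key="subject1_image1")
    (delayed,) = loader.transform([sample])
    return delayed  # accessing ``delayed.data`` would call numpy.load on the file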


class AnnotationsLoader(TransformerMixin, BaseEstimator):
    """Prepares annotations to be loaded lazily from a path in ``delayed_attributes``.

    Metadata loader that loads samples' annotations using
    :py:func:`~bob.bio.base.utils.annotations.read_annotation_file`. This assumes that
    the annotation files follow the same folder structure and naming as the raw data
    files. The base location and the extension, however, can differ from those and are
    specified by :attr:`annotation_directory` and :attr:`annotation_extension`.

    Parameters
    ----------
    annotation_directory
        Path where the annotations are stored.

    annotation_extension : str
        Extension of the annotation files.

    annotation_type : str
        Annotations type passed to
        :func:`~bob.bio.base.utils.annotations.read_annotation_file`.
    """

    def __init__(
        self,
        annotation_directory: Optional[str] = None,
        annotation_extension: str = ".json",
        annotation_type: str = "json",
    ):
        self.annotation_directory = annotation_directory
        self.annotation_extension = annotation_extension
        self.annotation_type = annotation_type

    def transform(self, X: list[DelayedSample]) -> list[DelayedSample]:
        """Edits the samples to lazily load annotation files.

        Parameters
        ----------
        X
            The samples to augment.
        """
        if self.annotation_directory is None:
            return None

        annotated_samples = []
        for x in X:
            # we use .key here because .path might not be unique for all
            # samples. Also, the ``bob bio annotate-samples`` command dictates
            # how annotations are stored.
            annotation_file = os.path.join(
                self.annotation_directory, x.key + self.annotation_extension
            )

            annotated_samples.append(
                DelayedSample.from_sample(
                    x,
                    delayed_attributes=dict(
                        # bind annotation_file now via functools.partial so each
                        # sample loads its own annotation file (a plain lambda
                        # would late-bind the loop variable)
                        annotations=functools.partial(
                            read_annotation_file,
                            annotation_file,
                            self.annotation_type,
                        )
                    ),
                )
            )

        return annotated_samples

    def _more_tags(self):
        return {"requires_fit": False}
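

# --- Hedged usage sketch (added for illustration; not part of the original module) ---
# A typical composition: lazily load the raw data, then attach lazily loaded
# annotations. The directories and ``load_raw`` are invented; the resulting
# pipeline would usually be passed as ``transformer=`` when building a CSVDatabase.
def _example_annotated_pipeline():  # pragma: no cover
    def load_raw(path):
        with open(path, "rb") as f:
            return f.read()

    return sklearn.pipeline.make_pipeline(
        FileSampleLoader(
            data_loader=load_raw,
            dataset_original_directory="/path/to/raw/data",
            extension=".png",
        ),
        AnnotationsLoader(
            annotation_directory="/path/to/annotations",
            annotation_extension=".json",
            annotation_type="json",
        ),
    )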