"""Source code for bob.pipelines.dataset.database

The principles of this module are:

* one csv file -> one set
* one row -> one sample
* csv files can live in a tarball or inside a folder
* scikit-learn transformers are used to further transform samples
* several csv files (sets) compose a protocol
* several protocols compose a database
"""
import csv
import itertools
import os

from collections.abc import Iterable
from pathlib import Path
from typing import Any, Optional, TextIO, Union

import sklearn.pipeline

from bob.pipelines.dataset.protocols.retrieve import (
    list_group_names,
    list_protocol_names,
    open_definition_file,
    retrieve_protocols,
)

from ..sample import Sample
from ..utils import check_parameter_for_validity, check_parameters_for_validity

def _maybe_open_file(path, **kwargs):
    """Open ``path`` if it is a filesystem path; otherwise return it unchanged.

    Parameters
    ----------
    path
        A ``str``, ``bytes`` or :class:`os.PathLike` location, or any other
        object (e.g. an already opened file), which is returned as-is.
    **kwargs
        Forwarded to :func:`open` when ``path`` has to be opened.

    Returns
    -------
    object
        The opened file object, or ``path`` itself when it was not a path.
    """
    # os.PathLike covers pathlib.Path as well as any custom ``__fspath__`` object.
    if isinstance(path, (str, bytes, os.PathLike)):
        # The handle is returned open on purpose: closing it is up to the caller.
        return open(path, **kwargs)
    return path

class FileListToSamples(Iterable):
    """Converts a list of paths and metadata to a list of samples.

    This class reads a file containing paths and optionally metadata and returns a list
    of :py:class:`bob.pipelines.Sample`\\ s when called.

    A separator character can be set (defaults is space) to split the rows.
    No escaping is done (no quotes).

    A Transformer can be given to apply a transform on each sample. (Keep in mind this
    will not be distributed on Dask; prefer applying the transformer inside a
    pipeline instead.)
    """

    def __init__(
        self,
        list_file: str,
        separator: str = " ",
        transformer: Optional[sklearn.pipeline.Pipeline] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        list_file
            Path of the file to read; it is opened anew each time ``rows`` is used.
        separator
            String used to split every line of ``list_file``.
        transformer
            Optional transformer applied to each sample while iterating.
        **kwargs
            Forwarded to the next ``__init__`` in the MRO.
        """
        super().__init__(**kwargs)
        self.list_file = list_file
        self.transformer = transformer
        self.separator = separator

    def __iter__(self):
        """Yield one sample per row, or the transformer's output for that row."""
        for row_dict in self.rows:
            sample = Sample(None, **row_dict)
            if self.transformer is None:
                yield sample
            else:
                # The transformer might convert one sample to several samples
                yield from self.transformer.transform([sample])

    @property
    def rows(self) -> Iterable[dict[str, Any]]:
        """Lazily yield one mapping per line of ``list_file``."""
        with open(self.list_file, "rt") as f:
            for line in f:
                # NOTE(review): ``dict()`` over a list of plain strings only works
                # when every field is exactly two characters long, and the trailing
                # newline is not stripped — the intended field names are not
                # visible here, so the expression is left unchanged. Confirm.
                yield dict(line.split(self.separator))

class CSVToSamples(FileListToSamples):
    """Converts a csv file to a list of samples"""

    def __init__(
        self,
        list_file: str,
        transformer: Optional[sklearn.pipeline.Pipeline] = None,
        dict_reader_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        list_file
            Path to the csv file, or an already opened file object.
        transformer
            Optional transformer applied to each sample while iterating.
        dict_reader_kwargs
            Extra keyword arguments handed to :class:`csv.DictReader`.
        **kwargs
            Forwarded to the parent initialiser.
        """
        # ``newline=""`` is what the csv module asks for when opening files itself.
        list_file = _maybe_open_file(list_file, newline="")
        # Store the (opened) file and the transformer through the parent, since
        # ``rows`` and ``__iter__`` read them from the instance.
        super().__init__(list_file=list_file, transformer=transformer, **kwargs)
        self.dict_reader_kwargs = dict_reader_kwargs

    @property
    def rows(self) -> csv.DictReader:
        """Return a fresh :class:`csv.DictReader` over the whole file."""
        # Rewind when possible so that a second iteration does not start at
        # end-of-file and silently produce no rows.
        seekable = getattr(self.list_file, "seekable", None)
        if seekable is not None and seekable():
            self.list_file.seek(0)
        kw = self.dict_reader_kwargs or {}
        return csv.DictReader(self.list_file, **kw)

class FileListDatabase:
    """A generic database interface.

    Use this class to convert csv files to a database that outputs samples. The
    format is simple: the files must be inside a folder (or a compressed
    tarball) with the following layout::

        <dataset_protocols_path>/<protocol>/<group>.csv

    The top folders are the name of the protocols (if you only have one, you may
    name it ``default``). Inside each protocol folder, there are `<group>.csv`
    files where the name of the file specifies the name of the group. We
    recommend using the names ``train``, ``dev``, ``eval`` for your typical
    training, development, and test sets.
    """

    def __init__(
        self,
        name: str,
        protocol: str,
        dataset_protocols_path: Union[os.PathLike[str], str, None] = None,
        reader_cls: Iterable = CSVToSamples,
        transformer: Optional[sklearn.pipeline.Pipeline] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        name
            Name of the database. Ignored when the class or instance already
            defines a ``name`` that is not None.
        protocol
            The name of the protocol to be used for samples. If None, the first
            protocol found will be used.
        dataset_protocols_path
            Path to a folder or a tarball where the csv protocol files are located.
            When None, ``retrieve_dataset_protocols`` provides it.
        reader_cls
            An iterable that returns created Sample objects from a list file.
        transformer
            A scikit-learn transformer that further changes the samples.
        **kwargs
            Forwarded to the next ``__init__`` in the MRO.

        Raises
        ------
        ValueError
            If no protocol is found at ``dataset_protocols_path``.
        """
        # Tricksy trick to make protocols non-classmethod when instantiated
        self.protocols = self._instance_protocols

        if getattr(self, "name", None) is None:
            self.name = name

        if dataset_protocols_path is None:
            dataset_protocols_path = self.retrieve_dataset_protocols()

        self.dataset_protocols_path = Path(dataset_protocols_path)

        if not self.protocols():
            raise ValueError(
                f"No protocols found at `{dataset_protocols_path}`!"
            )
        self.reader_cls = reader_cls
        self._transformer = transformer
        # Readers are cached per (protocol, group) pair, see ``get_reader``.
        self.readers: dict[tuple[str, str], Iterable] = {}
        self._protocol = None
        # Goes through the property setter below, which validates the value.
        self.protocol = protocol
        super().__init__(**kwargs)

    @property
    def protocol(self) -> str:
        """The protocol currently selected."""
        return self._protocol

    @protocol.setter
    def protocol(self, value: str):
        # List the protocols once; the first one serves as the default.
        available = self.protocols()
        self._protocol = check_parameter_for_validity(
            value, "protocol", available, available[0]
        )

    @property
    def transformer(self) -> sklearn.pipeline.Pipeline:
        """The transformer handed to every reader."""
        return self._transformer

    @transformer.setter
    def transformer(self, value: sklearn.pipeline.Pipeline):
        self._transformer = value
        # Keep the readers that were already created in sync.
        for reader in self.readers.values():
            reader.transformer = value

    # NOTE(review): in the helper calls below, the ``database_name`` and
    # ``database_filename`` keywords follow the signatures of the functions in
    # ``bob.pipelines.dataset.protocols.retrieve`` — confirm against that module.

    def groups(self) -> list[str]:
        """Returns all the available groups."""
        return list_group_names(
            database_name=self.name,
            protocol=self.protocol,
            database_filename=self.dataset_protocols_path.name,
            base_dir=self.dataset_protocols_path.parent,
            subdir=".",
        )

    def _instance_protocols(self) -> list[str]:
        """Returns all the available protocols."""
        return list_protocol_names(
            database_name=self.name,
            database_filename=self.dataset_protocols_path.name,
            base_dir=self.dataset_protocols_path.parent,
            subdir=".",
        )

    @classmethod
    def protocols(cls) -> list[str]:  # pylint: disable=method-hidden
        """Returns all the available protocols.

        Raises
        ------
        ValueError
            If the class has no ``name`` attribute.
        """
        # Ensure the definition file exists locally
        loc = cls.retrieve_dataset_protocols()
        if not hasattr(cls, "name"):
            raise ValueError(f"{cls} has no attribute 'name'.")
        return list_protocol_names(
            database_name=getattr(cls, "name"),
            database_filename=loc.name,
            base_dir=loc.parent,
            subdir=".",
        )

    @classmethod
    def retrieve_dataset_protocols(cls) -> Path:
        """Return a path to the protocols definition files.

        If the files are not present locally in ``bob_data/<subdir>/<category>``,
        they will be downloaded.

        The class inheriting from CSVDatabase must have a ``name`` and a
        ``dataset_protocols_urls`` attribute.

        A ``dataset_protocols_checksum`` attribute can be used to verify the file
        and ensure the correct version is used.
        """
        # When the path is specified, just return it.
        if getattr(cls, "dataset_protocols_path", None) is not None:
            return getattr(cls, "dataset_protocols_path")

        # Save to bob_data/protocols, or if present, in a category sub directory.
        subdir = Path("protocols")
        if hasattr(cls, "category"):
            subdir = subdir / getattr(cls, "category")

        # Retrieve the file from the server (or use the local version).
        return retrieve_protocols(
            urls=getattr(cls, "dataset_protocols_urls"),
            destination_filename=getattr(cls, "dataset_protocols_name", None),
            base_dir=None,
            subdir=subdir,
            checksum=getattr(cls, "dataset_protocols_checksum", None),
        )

    def list_file(self, group: str) -> TextIO:
        """Returns the corresponding definition file of a group."""
        return open_definition_file(
            search_pattern=group + ".csv",
            database_name=self.name,
            protocol=self.protocol,
            database_filename=self.dataset_protocols_path.name,
            base_dir=self.dataset_protocols_path.parent,
            subdir=".",
        )

    def get_reader(self, group: str) -> Iterable:
        """Returns an :any:`Iterable` of :any:`Sample` objects."""
        key = (self.protocol, group)
        if key not in self.readers:
            self.readers[key] = self.reader_cls(
                list_file=self.list_file(group), transformer=self.transformer
            )
        return self.readers[key]

    def samples(self, groups=None):
        """Get samples of a certain group

        Parameters
        ----------
        groups : :obj:`str`, optional
            A str or list of str to be used for filtering samples, by default None

        Returns
        -------
        list
            A list containing the samples loaded from csv files.
        """
        # List the groups once; they are both the valid values and the default.
        available = self.groups()
        groups = check_parameters_for_validity(
            groups, "groups", available, available
        )
        all_samples = []
        for grp in groups:
            all_samples.extend(self.get_reader(grp))
        return all_samples

    @staticmethod
    def sort(samples: list[Sample], unique: bool = True):
        """Sorts samples and removes duplicates by default."""

        def key_func(x):
            return x.key

        samples = sorted(samples, key=key_func)

        if unique:
            # groupby needs sorted input; keep the first sample of every key.
            samples = [
                next(iter(v)) for _, v in itertools.groupby(samples, key=key_func)
            ]
        return samples