# Source code for bob.ip.common.data.dataset

#!/usr/bin/env python
# coding=utf-8

import csv
import json
import logging
import os
import pathlib

# Logger named after this module; used by the ``check`` methods below to
# report progress and per-sample loading errors.
logger = logging.getLogger(name=__name__)


class JSONDataset:
    """
    Generic multi-protocol/subset filelist dataset that yields samples

    To create a new dataset, you need to provide one or more JSON formatted
    filelists (one per protocol) with the following contents:

    .. code-block:: json

       {
           "subset1": [
               [
                   "value1",
                   "value2",
                   "value3"
               ],
               [
                   "value4",
                   "value5",
                   "value6"
               ]
           ],
           "subset2": [
           ]
       }

    Your dataset may contain any number of subsets, but all sample entries
    must contain the same number of fields.


    Parameters
    ----------

    protocols : list, dict
        Paths to one or more JSON formatted files containing the various
        protocols to be recognized by this dataset, or a dictionary, mapping
        protocol names to paths (or opened file objects) of JSON files.
        Internally, we save a dictionary where keys default to the basename
        of paths, without extension (list input).

    fieldnames : list, tuple
        An iterable over the field names (strings) to assign to each entry in
        the JSON file.  It should have as many items as fields in each entry
        of the JSON file.

    loader : object
        A function that receives as input, a context dictionary (with at least
        a "protocol" and "subset" keys indicating which protocol and subset are
        being served, plus an "order" key with the position of the entry
        inside its subset), and a dictionary with ``{fieldname: value}``
        entries, and returns an object with at least 2 attributes:

        * ``key``: which must be a unique string for every sample across
          subsets in a protocol, and
        * ``data``: which contains the data associated with this sample

    """

    def __init__(self, protocols, fieldnames, loader):
        if isinstance(protocols, dict):
            self._protocols = protocols
        else:
            # list input: name each protocol after its file, minus extension
            self._protocols = {
                os.path.splitext(os.path.basename(k))[0]: k for k in protocols
            }
        self.fieldnames = fieldnames
        self._loader = loader

    def check(self, limit=0):
        """For each protocol, check if all data can be correctly accessed

        This function assumes each sample has a ``data`` and a ``key``
        attribute.  The ``key`` attribute should be a string, or representable
        as such.


        Parameters
        ----------

        limit : int
            Maximum number of samples to check (in each protocol/subset
            combination) in this dataset.  If set to zero, then check
            everything.


        Returns
        -------

        errors : int
            Number of errors found

        """

        logger.info("Checking dataset...")
        errors = 0
        for proto in self._protocols:
            logger.info(f"Checking protocol '{proto}'...")
            for name, samples in self.subsets(proto).items():
                logger.info(f"Checking subset '{name}'...")
                if limit:
                    logger.info(f"Checking at most first '{limit}' samples...")
                    samples = samples[:limit]
                for pos, sample in enumerate(samples):
                    try:
                        sample.data  # may trigger data loading
                        logger.info(f"{sample.key}: OK")
                    except Exception as e:
                        # deliberately broad: any loading failure is counted
                        # and reported, and checking continues
                        logger.error(
                            f"Found error loading entry {pos} in subset {name} "
                            f"of protocol {proto} from file "
                            f"'{self._protocols[proto]}': {e}"
                        )
                        errors += 1
        return errors

    def subsets(self, protocol):
        """Returns all subsets in a protocol

        This method will load JSON information for a given protocol and return
        all subsets of the given protocol after converting each entry through
        the loader function.

        Parameters
        ----------

        protocol : str
            Name of the protocol data to load


        Returns
        -------

        subsets : dict
            A dictionary mapping subset names to lists of objects (respecting
            the ``key``, ``data`` interface).

        """

        fileobj = self._protocols[protocol]
        if isinstance(fileobj, (str, bytes, os.PathLike)):
            with open(fileobj, "r") as f:
                data = json.load(f)
        else:
            # an already-opened file object: read it, then rewind it so the
            # same protocol can be loaded again by later calls
            data = json.load(fileobj)
            fileobj.seek(0)

        retval = {}
        for subset, samples in data.items():
            retval[subset] = [
                self._loader(
                    dict(protocol=protocol, subset=subset, order=n),
                    dict(zip(self.fieldnames, k)),
                )
                for n, k in enumerate(samples)
            ]

        return retval
class CSVDataset:
    """
    Generic multi-subset filelist dataset that yields samples

    To create a new dataset, you only need to provide a CSV formatted filelist
    with the following information:

    .. code-block:: text

       value1,value2,value3
       value4,value5,value6
       ...

    Rows are parsed with :py:func:`csv.reader` using its default dialect, so
    fields must be separated by commas.  Notice that all rows must have the
    same number of entries.

    Parameters
    ----------

    subsets : list, dict
        Paths to one or more CSV formatted files containing the various
        subsets to be recognized by this dataset, or a dictionary, mapping
        subset names to paths (or opened file objects) of CSV files.
        Internally, we save a dictionary where keys default to the basename
        of paths, without extension (list input).

    fieldnames : list, tuple
        An iterable over the field names (strings) to assign to each column
        in the CSV file.  It should have as many items as fields in each row
        of the CSV file(s).

    loader : object
        A function that receives as input, a context dictionary (with, at
        least, a "subset" key indicating which subset is being served, and an
        "order" key with the position of the row inside that subset), and a
        dictionary with ``{fieldname: value}`` entries.  Whatever it returns
        is collected into the sample lists; :py:meth:`check` expects each
        returned object to expose ``key`` and ``data`` attributes.

    """

    def __init__(self, subsets, fieldnames, loader):
        if isinstance(subsets, dict):
            self._subsets = subsets
        else:
            # list input: name each subset after its file, minus extension
            self._subsets = {
                os.path.splitext(os.path.basename(k))[0]: k for k in subsets
            }
        self.fieldnames = fieldnames
        self._loader = loader

    def check(self, limit=0):
        """For each subset, check if all data can be correctly accessed

        This function assumes each sample has a ``data`` and a ``key``
        attribute.  The ``key`` attribute should be a string, or representable
        as such.


        Parameters
        ----------

        limit : int
            Maximum number of samples to check (in each subset) in this
            dataset.  If set to zero, then check everything.


        Returns
        -------

        errors : int
            Number of errors found

        """

        logger.info("Checking dataset...")
        errors = 0
        for name in self._subsets.keys():
            logger.info(f"Checking subset '{name}'...")
            samples = self.samples(name)
            if limit:
                logger.info(f"Checking at most first '{limit}' samples...")
                samples = samples[:limit]
            for pos, sample in enumerate(samples):
                try:
                    sample.data  # may trigger data loading
                    logger.info(f"{sample.key}: OK")
                except Exception as e:
                    # deliberately broad: any loading failure is counted and
                    # reported, and checking continues
                    logger.error(
                        f"Found error loading entry {pos} in subset {name} "
                        f"from file '{self._subsets[name]}': {e}"
                    )
                    errors += 1
        return errors

    def subsets(self):
        """Returns all available subsets at once

        Returns
        -------

        subsets : dict
            A dictionary mapping subset names to lists of objects (respecting
            the ``key``, ``data`` interface).

        """

        return {k: self.samples(k) for k in self._subsets.keys()}

    def samples(self, subset):
        """Returns all samples in a subset

        This method will load CSV information for a given subset and return
        all samples of the given subset after passing each entry through the
        loading function.


        Parameters
        ----------

        subset : str
            Name of the subset data to load


        Returns
        -------

        subset : list
            A list of objects (respecting the ``key``, ``data`` interface).

        """

        fileobj = self._subsets[subset]
        if isinstance(fileobj, (str, bytes, os.PathLike)):
            # newline="" lets the csv module handle line endings itself
            with open(fileobj, newline="") as f:
                rows = list(csv.reader(f))
        else:
            # an already-opened file object: read it, then rewind it so the
            # same subset can be loaded again by later calls
            rows = list(csv.reader(fileobj))
            fileobj.seek(0)

        return [
            self._loader(
                dict(subset=subset, order=n), dict(zip(self.fieldnames, k))
            )
            for n, k in enumerate(rows)
        ]