Source code for bob.learn.tensorflow.dataset.bio

from bob.bio.base import read_original_data
from .generator import Generator
import logging

# Module-level logger named after this module; nothing visible in this module
# currently logs through it.
logger = logging.getLogger(name=__name__)


class BioGenerator(Generator):
    """A generator class which wraps bob.bio.base databases so that they can be
    used with tf.data.Dataset.from_generator

    Each produced item is a ``(data, label, key)`` tuple.

    Parameters
    ----------
    database : :any:`bob.bio.base.database.BioDatabase`
        The database that you want to use.
    biofiles : [:any:`bob.bio.base.database.BioFile`]
        The bio files to iterate over; forwarded to the base ``Generator``.
    load_data : :obj:`object`, optional
        A callable with the signature of ``data = load_data(database, biofile)``.
        By default :any:`bob.bio.base.read_original_data` is called with the
        database's ``original_directory`` and ``original_extension``.
    biofile_to_label : :obj:`object`, optional
        A callable with the signature of ``label = biofile_to_label(biofile)``.
        Its result is cast with ``int``. By default -1 is returned as label.
    multiple_samples : bool, optional
        If True, ``load_data`` must return an iterable. Each of its elements
        is yielded as a separate ``(element, label, key)`` tuple sharing the
        biofile's label and key. Also forwarded to the base ``Generator``.
    **kwargs
        Forwarded unchanged to the base ``Generator``.

    Attributes
    ----------
    biofile_to_label : :obj:`object`
        The label callable in use (user-supplied or the default).
    database : :any:`bob.bio.base.database.BioDatabase`
        The database that you want to use.
    load_data : :obj:`object`
        The loader callable in use (user-supplied or the default).
    biofiles : [:any:`bob.bio.base.database.BioFile`]
        The bio files (an alias of ``self.samples``).
    keys : iterator of bytes
        The UTF-8 encoded result of ``str(biofile.make_path("", ""))`` for
        every biofile. A new generator is created on every access.
    labels : iterator of int
        ``int(biofile_to_label(biofile))`` for every biofile. A new generator
        is created on every access.
    """

    def __init__(
        self,
        database,
        biofiles,
        load_data=None,
        biofile_to_label=None,
        multiple_samples=False,
        **kwargs
    ):
        if load_data is None:

            def load_data(database, biofile):
                # Default loader: read the raw sample using the database's
                # original directory and original file extension.
                data = read_original_data(
                    biofile, database.original_directory, database.original_extension
                )
                return data

        if biofile_to_label is None:

            def biofile_to_label(biofile):
                # Default: every biofile gets the placeholder label -1.
                return -1

        self.database = database
        self.load_data = load_data
        self.biofile_to_label = biofile_to_label

        def _reader(f):
            # Attributes are looked up on ``self`` at call time, so later
            # reassignment of load_data / biofile_to_label takes effect.
            label = self._biofile_label(f)
            data = self.load_data(self.database, f)
            key = self._biofile_key(f)
            return data, label, key

        if multiple_samples:

            def reader(f):
                # One biofile holds several samples: emit each one with the
                # biofile's shared label and key.
                data, label, key = _reader(f)
                for d in data:
                    yield (d, label, key)

        else:
            reader = _reader

        super(BioGenerator, self).__init__(
            biofiles, reader, multiple_samples=multiple_samples, **kwargs
        )

    def _biofile_label(self, f):
        """Return the integer label of biofile ``f``."""
        return int(self.biofile_to_label(f))

    @staticmethod
    def _biofile_key(f):
        """Return the UTF-8 encoded key (``make_path("", "")``) of biofile ``f``."""
        return str(f.make_path("", "")).encode("utf-8")

    @property
    def labels(self):
        """Yield the integer label of every biofile, in ``biofiles`` order."""
        for f in self.biofiles:
            yield self._biofile_label(f)

    @property
    def keys(self):
        """Yield the UTF-8 encoded key of every biofile, in ``biofiles`` order."""
        for f in self.biofiles:
            yield self._biofile_key(f)

    @property
    def biofiles(self):
        """Alias of ``self.samples``.

        Presumably ``samples`` holds the ``biofiles`` passed to the base
        ``Generator``; confirm against ``Generator``.
        """
        return self.samples

    def __len__(self):
        # Number of biofiles.
        # NOTE(review): with multiple_samples=True a single biofile may yield
        # several items, so this can differ from the number of produced
        # samples. Confirm that this is intended.
        return len(self.biofiles)