# Source code for bob.learn.tensorflow.data.generator

import logging
import random

import tensorflow as tf

logger = logging.getLogger(__name__)


class Generator:
    """A generator class which wraps samples so that they can
    be used with tf.data.Dataset.from_generator

    Attributes
    ----------
    epoch : int
        The number of epochs that have been passed so far.

    multiple_samples : :obj:`bool`, optional
        If true, it assumes that the bio database's samples actually contain
        multiple samples. This is useful for when you want to for example treat
        video databases as image databases.

    reader : :obj:`object`
        A callable with the signature of ``data, label, key = reader(sample)``
        which takes a sample and loads it.

    samples : [:obj:`object`]
        A list of samples to be given to ``reader`` to load the data.

    shuffle_on_epoch_end : :obj:`bool`, optional
        If True, it shuffles the samples at the end of each epoch.
    """

    def __init__(
        self,
        samples,
        reader,
        multiple_samples=False,
        shuffle_on_epoch_end=False,
        **kwargs
    ):
        """Store the configuration and infer the output types and shapes.

        Parameters
        ----------
        samples : iterable
            The samples to be given to ``reader``; copied into a list.
        reader : callable
            Loads one sample (or, with ``multiple_samples``, an indexable
            sequence or an iterator of sub-samples).
        multiple_samples : bool, optional
            Whether ``reader`` returns several sub-samples per sample.
        shuffle_on_epoch_end : bool, optional
            Whether ``samples`` is shuffled in place after each full pass.
        **kwargs
            Forwarded to the parent class constructor.

        Raises
        ------
        ValueError
            If no sample could be loaded to infer the types and shapes.
        """
        super().__init__(**kwargs)
        self.reader = reader
        self.samples = list(samples)
        self.multiple_samples = multiple_samples
        self.epoch = 0
        self.shuffle_on_epoch_end = shuffle_on_epoch_end

        # load samples until one of them is not empty
        # this data is used to get the type and shape
        dlk = self._load_first_sample()

        # Creating a "fake" dataset just to get the types and shapes
        dataset = tf.data.Dataset.from_tensors(dlk)
        self._output_types = tf.compat.v1.data.get_output_types(dataset)
        self._output_shapes = tf.compat.v1.data.get_output_shapes(dataset)

        logger.info(
            "Initializing a dataset with %d %s and %s types and %s shapes",
            len(self.samples),
            "multi-samples" if self.multiple_samples else "samples",
            self.output_types,
            self.output_shapes,
        )

    def _load_first_sample(self):
        """Return the first loadable (sub-)sample found in ``self.samples``.

        Samples are tried in order. When ``multiple_samples`` is set, the
        first sub-sample is returned and samples without any sub-sample are
        skipped.

        Raises
        ------
        ValueError
            If ``self.samples`` is empty or every sample turned out empty.
        """
        for sample in self.samples:
            try:
                dlk = self.reader(sample)
                if not self.multiple_samples:
                    return dlk
                try:
                    return dlk[0]
                except TypeError:
                    # if the data is a generator
                    return next(dlk)
                except IndexError:
                    # an empty sequence of sub-samples; try the next sample
                    continue
            except StopIteration:
                # an exhausted iterator of sub-samples; try the next sample
                continue
        raise ValueError(
            "Could not load any non-empty sample out of %d sample(s) to "
            "infer the output types and shapes." % len(self.samples)
        )

    @property
    def output_types(self):
        "The types of the returned samples"
        return self._output_types

    @property
    def output_shapes(self):
        "The shapes of the returned samples"
        return self._output_shapes

    def __call__(self):
        """A generator function that when called will yield the samples.

        Once all samples have been yielded, ``epoch`` is incremented and, if
        ``shuffle_on_epoch_end`` is set, ``samples`` is shuffled in place.
        This bookkeeping only runs when the generator is fully consumed.

        Yields
        ------
        object
            Samples one by one.
        """
        for sample in self.samples:
            dlk = self.reader(sample)
            if self.multiple_samples:
                # each loaded item holds several sub-samples; flatten them
                yield from dlk
            else:
                yield dlk
        self.epoch += 1
        logger.info("Elapsed %d epoch(s)", self.epoch)
        if self.shuffle_on_epoch_end:
            logger.info("Shuffling samples")
            random.shuffle(self.samples)


def dataset_using_generator(samples, reader, **kwargs):
    """Create a tf.data.Dataset that yields the samples loaded by ``reader``.

    The samples are wrapped in a :any:`Generator` which is then given to
    ``tf.data.Dataset.from_generator``.

    Parameters
    ----------
    samples : [:obj:`object`]
        A list of samples to be given to ``reader`` to load the data.

    reader : :obj:`object`
        A callable with the signature of ``data, label, key = reader(sample)``
        which takes a sample and loads it.

    **kwargs
        Extra keyword arguments are passed to Generator

    Returns
    -------
    object
        A tf.data.Dataset
    """
    generator = Generator(samples, reader, **kwargs)
    dataset = tf.data.Dataset.from_generator(
        generator, generator.output_types, generator.output_shapes
    )
    return dataset