Source code for bob.learn.tensorflow.dataset.tfrecords

from functools import partial
import tensorflow as tf
from . import append_image_augmentation, DEFAULT_FEATURE
import os
import logging
# Module-level logger named after this module; describe_tf_record uses it
# to report which directory is being counted.
logger = logging.getLogger(__name__)


def example_parser(serialized_example, feature, data_shape, data_type):
    """
    Decode one serialized tf.Example into ``(image, label, key)`` tensors.

    **Parameters**

       serialized_example:
          Serialized record to be parsed

       feature:
          Feature specification handed to ``tf.parse_single_example``; the
          parsed result is read at the keys ``'data'``, ``'label'`` and
          ``'key'``

       data_shape:
          Shape the decoded sample is reshaped to

       data_type:
          tf data type used to decode the raw bytes stored under ``'data'``
    """
    parsed = tf.parse_single_example(serialized_example, features=feature)

    # The sample is stored as a raw byte string: turn it back into numbers
    flat_image = tf.decode_raw(parsed['data'], data_type)

    # Labels are always handed out as int64
    label = tf.cast(parsed['label'], tf.int64)

    # Restore the shape the sample had before it was serialized
    image = tf.reshape(flat_image, data_shape)

    key = tf.cast(parsed['key'], tf.string)

    return image, label, key
def image_augmentation_parser(serialized_example,
                              feature,
                              data_shape,
                              data_type,
                              gray_scale=False,
                              output_shape=None,
                              random_flip=False,
                              random_brightness=False,
                              random_contrast=False,
                              random_saturation=False,
                              random_rotate=False,
                              per_image_normalization=True):
    """
    Decode one serialized tf.Example into ``(image, label, key)`` tensors,
    running the image through ``append_image_augmentation`` on the way.

    **Parameters**

       serialized_example:
          Serialized record to be parsed

       feature:
          Feature specification handed to ``tf.parse_single_example``; the
          parsed result is read at the keys ``'data'``, ``'label'`` and
          ``'key'``

       data_shape:
          Shape the decoded sample is reshaped to before augmentation

       data_type:
          tf data type used to decode the raw bytes stored under ``'data'``

       gray_scale, output_shape, random_flip, random_brightness,
       random_contrast, random_saturation, random_rotate,
       per_image_normalization:
          Forwarded unchanged, by keyword, to ``append_image_augmentation``
    """
    # Plain python bookkeeping: gather the augmentation switches once
    augmentation_options = dict(
        gray_scale=gray_scale,
        output_shape=output_shape,
        random_flip=random_flip,
        random_brightness=random_brightness,
        random_contrast=random_contrast,
        random_saturation=random_saturation,
        random_rotate=random_rotate,
        per_image_normalization=per_image_normalization)

    parsed = tf.parse_single_example(serialized_example, features=feature)

    # Raw byte string -> numbers -> original sample shape
    flat_image = tf.decode_raw(parsed['data'], data_type)
    shaped_image = tf.reshape(flat_image, data_shape)

    image = append_image_augmentation(shaped_image, **augmentation_options)

    # Labels are always handed out as int64
    label = tf.cast(parsed['label'], tf.int64)
    key = tf.cast(parsed['key'], tf.string)

    return image, label, key
def read_and_decode(filename_queue,
                    data_shape,
                    data_type=tf.float32,
                    feature=None):
    """
    Simplest possible way to parse a tfrecord: read the next record from
    ``filename_queue`` and decode it with ``example_parser``.

    **Parameters**

       filename_queue:
          Queue of tf-record file names consumed by a ``tf.TFRecordReader``

       data_shape:
          Samples shape saved in the tf-record

       data_type:
          tf data type used to decode the raw sample

       feature:
          Feature specification used for parsing; ``DEFAULT_FEATURE`` is
          used when this is ``None``

    Returns the ``(image, label, key)`` tensors built by ``example_parser``.
    """
    feature = DEFAULT_FEATURE if feature is None else feature

    # Only the serialized record is needed; the record key is discarded
    _, serialized_example = tf.TFRecordReader().read(filename_queue)

    return example_parser(serialized_example, feature, data_shape, data_type)
def create_dataset_from_records(tfrecord_filenames,
                                data_shape,
                                data_type,
                                feature=None):
    """
    Build a ``tf.data`` dataset out of a list of tf-record files, decoding
    every record with ``example_parser``

    **Parameters**

       tfrecord_filenames:
          List containing the tf-record paths

       data_shape:
          Samples shape saved in the tf-record

       data_type:
          tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)

       feature:
          Feature specification used to parse each record;
          ``DEFAULT_FEATURE`` is used when this is ``None``
    """
    record_feature = DEFAULT_FEATURE if feature is None else feature

    records = tf.data.TFRecordDataset(tfrecord_filenames)

    # Freeze everything but the serialized example so that the parser can be
    # mapped over the dataset
    decode = partial(
        example_parser,
        feature=record_feature,
        data_shape=data_shape,
        data_type=data_type)

    return records.map(decode)
def create_dataset_from_records_with_augmentation(
        tfrecord_filenames,
        data_shape,
        data_type,
        feature=None,
        gray_scale=False,
        output_shape=None,
        random_flip=False,
        random_brightness=False,
        random_contrast=False,
        random_saturation=False,
        random_rotate=False,
        per_image_normalization=True):
    """
    Build a ``tf.data`` dataset out of a list of tf-record files, decoding
    every record with ``image_augmentation_parser``

    **Parameters**

       tfrecord_filenames:
          List containing the tf-record paths

       data_shape:
          Samples shape saved in the tf-record

       data_type:
          tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)

       feature:
          Feature specification used to parse each record;
          ``DEFAULT_FEATURE`` is used when this is ``None``

       gray_scale, output_shape, random_flip, random_brightness,
       random_contrast, random_saturation, random_rotate,
       per_image_normalization:
          Augmentation options forwarded unchanged to
          ``image_augmentation_parser``
    """
    record_feature = DEFAULT_FEATURE if feature is None else feature

    records = tf.data.TFRecordDataset(tfrecord_filenames)

    # Everything except the serialized example is fixed up front
    parser_options = dict(
        feature=record_feature,
        data_shape=data_shape,
        data_type=data_type,
        gray_scale=gray_scale,
        output_shape=output_shape,
        random_flip=random_flip,
        random_brightness=random_brightness,
        random_contrast=random_contrast,
        random_saturation=random_saturation,
        random_rotate=random_rotate,
        per_image_normalization=per_image_normalization)
    decode = partial(image_augmentation_parser, **parser_options)

    return records.map(decode)
def shuffle_data_and_labels_image_augmentation(tfrecord_filenames,
                                               data_shape,
                                               data_type,
                                               batch_size,
                                               epochs=None,
                                               buffer_size=10**3,
                                               gray_scale=False,
                                               output_shape=None,
                                               random_flip=False,
                                               random_brightness=False,
                                               random_contrast=False,
                                               random_saturation=False,
                                               random_rotate=False,
                                               per_image_normalization=True):
    """
    Dump random batches from a list of tf-record files, applying some image
    augmentation to every sample

    **Parameters**

       tfrecord_filenames:
          List containing the tf-record paths

       data_shape:
          Samples shape saved in the tf-record

       data_type:
          tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)

       batch_size:
          Size of the batch

       epochs:
          Number of epochs to be batched (handed to ``Dataset.repeat``)

       buffer_size:
          Size of the shuffle bucket

       gray_scale:
          Convert to gray scale?

       output_shape:
          If set, will randomly crop the image given the output shape

       random_flip:
          Randomly flip an image horizontally  (https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right)

       random_brightness:
           Adjust the brightness of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_brightness)

       random_contrast:
           Adjust the contrast of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_contrast)

       random_saturation:
           Adjust the saturation of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_saturation)

       random_rotate:
           Randomly rotate face images between -5 and 5 degrees

       per_image_normalization:
           Linearly scales image to have zero mean and unit norm.

    Returns ``(features, labels)`` where ``features`` is a dict holding the
    batched samples under ``'data'`` and their keys under ``'key'``.
    """
    dataset = create_dataset_from_records_with_augmentation(
        tfrecord_filenames,
        data_shape,
        data_type,
        gray_scale=gray_scale,
        output_shape=output_shape,
        random_flip=random_flip,
        random_brightness=random_brightness,
        random_contrast=random_contrast,
        random_saturation=random_saturation,
        random_rotate=random_rotate,
        per_image_normalization=per_image_normalization)

    # shuffle -> batch -> repeat, one step at a time
    shuffled = dataset.shuffle(buffer_size)
    batched = shuffled.batch(batch_size)
    iterator = batched.repeat(epochs).make_one_shot_iterator()

    data, labels, key = iterator.get_next()

    return {'data': data, 'key': key}, labels
def shuffle_data_and_labels(tfrecord_filenames,
                            data_shape,
                            data_type,
                            batch_size,
                            epochs=None,
                            buffer_size=10**3):
    """
    Dump random batches from a list of tf-record files

    **Parameters**

       tfrecord_filenames:
          List containing the tf-record paths

       data_shape:
          Samples shape saved in the tf-record

       data_type:
          tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)

       batch_size:
          Size of the batch

       epochs:
          Number of epochs to be batched (handed to ``Dataset.repeat``)

       buffer_size:
          Size of the shuffle bucket

    Returns ``(features, labels)`` where ``features`` is a dict holding the
    batched samples under ``'data'`` and their keys under ``'key'``.
    """
    dataset = create_dataset_from_records(tfrecord_filenames, data_shape,
                                          data_type)

    # shuffle -> batch -> repeat, one step at a time
    shuffled = dataset.shuffle(buffer_size)
    batched = shuffled.batch(batch_size)
    iterator = batched.repeat(epochs).make_one_shot_iterator()

    data, labels, key = iterator.get_next()

    return {'data': data, 'key': key}, labels
def batch_data_and_labels(tfrecord_filenames,
                          data_shape,
                          data_type,
                          batch_size,
                          epochs=1):
    """
    Dump in order batches from a list of tf-record files

    **Parameters**

       tfrecord_filenames:
          List containing the tf-record paths

       data_shape:
          Samples shape saved in the tf-record

       data_type:
          tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)

       batch_size:
          Size of the batch

       epochs:
          Number of epochs to be batched (handed to ``Dataset.repeat``)

    Returns ``(features, labels)`` where ``features`` is a dict holding the
    batched samples under ``'data'`` and their keys under ``'key'``.
    """
    dataset = create_dataset_from_records(tfrecord_filenames, data_shape,
                                          data_type)

    # No shuffling here: records are batched in the order they are read
    batched = dataset.batch(batch_size)
    iterator = batched.repeat(epochs).make_one_shot_iterator()

    data, labels, key = iterator.get_next()

    return {'data': data, 'key': key}, labels
def batch_data_and_labels_image_augmentation(tfrecord_filenames,
                                             data_shape,
                                             data_type,
                                             batch_size,
                                             epochs=1,
                                             gray_scale=False,
                                             output_shape=None,
                                             random_flip=False,
                                             random_brightness=False,
                                             random_contrast=False,
                                             random_saturation=False,
                                             random_rotate=False,
                                             per_image_normalization=True):
    """
    Dump in order batches from a list of tf-record files, applying some
    image augmentation to every sample

    **Parameters**

       tfrecord_filenames:
          List containing the tf-record paths

       data_shape:
          Samples shape saved in the tf-record

       data_type:
          tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)

       batch_size:
          Size of the batch

       epochs:
          Number of epochs to be batched (handed to ``Dataset.repeat``)

       gray_scale, output_shape, random_flip, random_brightness,
       random_contrast, random_saturation, random_rotate,
       per_image_normalization:
          Augmentation options forwarded unchanged to
          ``create_dataset_from_records_with_augmentation``

    Returns ``(features, labels)`` where ``features`` is a dict holding the
    batched samples under ``'data'`` and their keys under ``'key'``.
    """
    dataset = create_dataset_from_records_with_augmentation(
        tfrecord_filenames,
        data_shape,
        data_type,
        gray_scale=gray_scale,
        output_shape=output_shape,
        random_flip=random_flip,
        random_brightness=random_brightness,
        random_contrast=random_contrast,
        random_saturation=random_saturation,
        random_rotate=random_rotate,
        per_image_normalization=per_image_normalization)

    # No shuffling here: records are batched in the order they are read
    batched = dataset.batch(batch_size)
    iterator = batched.repeat(epochs).make_one_shot_iterator()

    data, labels, key = iterator.get_next()

    return {'data': data, 'key': key}, labels
def describe_tf_record(tf_record_path, shape, batch_size=1, data_type=None):
    """
    Describe the number of samples and the number of classes of a tf-record

    Parameters
    ----------

    tf_record_path: str
      Base path containing your tf-record files

    shape: tuple
      Shape inside of the tf-record

    batch_size: int
      Number of records fetched per ``session.run`` call while counting

    data_type:
      tf data type used to decode the raw samples stored under ``'data'``.
      ``None`` (the default) keeps the historical behaviour of decoding
      them as ``tf.uint8``

    Returns
    -------

    n_samples: int
       Total number of samples

    n_classes: int
       Total number of classes

    """
    if data_type is None:
        data_type = tf.uint8

    # Every entry of the directory is treated as a tf-record file
    tf_records = [
        os.path.join(tf_record_path, f) for f in os.listdir(tf_record_path)
    ]
    # num_epochs=1 makes the queue raise OutOfRangeError after a single pass,
    # which is what terminates the counting loop below
    filename_queue = tf.train.string_input_producer(
        tf_records, num_epochs=1, name="input")

    feature = {
        'data': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
        'key': tf.FixedLenFeature([], tf.string)
    }

    # Define a reader and read the next record
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # Decode the record read by the reader
    features = tf.parse_single_example(serialized_example, features=feature)

    # Convert the image data from string back to the numbers
    image = tf.decode_raw(features['data'], data_type)

    # Cast label data into int64
    label = tf.cast(features['label'], tf.int64)
    img_name = tf.cast(features['key'], tf.string)

    # Reshape image data into the original shape
    image = tf.reshape(image, shape)

    # Getting the batches in order.
    # allow_smaller_final_batch=True is required for a correct count: with
    # the default (False) the trailing records that do not fill a whole
    # batch are silently discarded whenever batch_size > 1.
    data_ph, label_ph, img_name_ph = tf.train.batch(
        [image, label, img_name],
        batch_size=batch_size,
        capacity=1000,
        num_threads=5,
        allow_smaller_final_batch=True,
        name="shuffle_batch")

    labels = set()
    counter = 0

    # The context manager guarantees the session is closed on every exit path
    with tf.Session() as session:
        # The epoch counter of string_input_producer is a local variable,
        # so local variables must be initialized before reading
        tf.local_variables_initializer().run(session=session)
        tf.global_variables_initializer().run(session=session)

        # Preparing the batches
        thread_pool = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(
            coord=thread_pool, sess=session)

        logger.info("Counting in %s", tf_record_path)
        try:
            while True:
                _, label, _ = session.run([data_ph, label_ph, img_name_ph])
                counter += len(label)
                labels.update(label)
        except tf.errors.OutOfRangeError:
            # Raised once every record has been consumed: normal termination
            pass
        finally:
            # Stop the queue runners and wait for them before the session
            # goes away, so no reader thread is left behind
            thread_pool.request_stop()
            thread_pool.join(threads)

    return counter, len(labels)