Source code for bob.learn.tensorflow.dataset.tfrecords

"""Utilities for TFRecords
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial
import json
import logging
import os
import sys

import tensorflow as tf

from . import append_image_augmentation, DEFAULT_FEATURE


logger = logging.getLogger(__name__)
TFRECORDS_EXT = ".tfrecords"


def tfrecord_name_and_json_name(output):
    output = normalize_tfrecords_path(output)
    json_output = output[: -len(TFRECORDS_EXT)] + ".json"
    return output, json_output


def normalize_tfrecords_path(output):
    if not output.endswith(TFRECORDS_EXT):
        output += TFRECORDS_EXT
    return output


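# A quick sketch of the two path helpers above (the paths here are
# hypothetical):
#
#   normalize_tfrecords_path("/tmp/train")
#   # -> "/tmp/train.tfrecords"
#   tfrecord_name_and_json_name("/tmp/train")
#   # -> ("/tmp/train.tfrecords", "/tmp/train.json")

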
def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


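# A minimal sketch of how the two feature helpers above combine into a
# tf.train.Example, mirroring the layout used by write_a_sample below
# (the sample values are hypothetical):
#
#   import numpy as np
#   data = np.zeros((112, 112, 3), dtype=np.uint8)
#   feature = {
#       "data": bytes_feature(data.tostring()),
#       "label": int64_feature(0),
#       "key": bytes_feature(b"sample-0"),
#   }
#   example = tf.train.Example(features=tf.train.Features(feature=feature))

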
def dataset_to_tfrecord(dataset, output):
    """Writes a tf.data.Dataset into a TFRecord file.

    Parameters
    ----------
    dataset : ``tf.data.Dataset``
        The tf.data.Dataset that you want to write into a TFRecord file.
    output : str
        Path to the TFRecord file. Besides this file, a .json file is also
        created. This json file is needed when you want to convert the
        TFRecord file back into a dataset.

    Returns
    -------
    ``tf.Operation``
        A tf.Operation that, when run, writes the contents of the dataset to
        a file. When running in eager mode, calling this function will write
        the file. Otherwise, you have to call session.run() on the returned
        operation.
    """
    output, json_output = tfrecord_name_and_json_name(output)
    # dump the structure so that we can read it back
    meta = {
        "output_types": repr(dataset.output_types),
        "output_shapes": repr(dataset.output_shapes),
    }
    with open(json_output, "w") as f:
        json.dump(meta, f)

    # create a custom map function that serializes the dataset
    def serialize_example_pyfunction(*args):
        feature = {}
        for i, f in enumerate(args):
            key = f"feature{i}"
            feature[key] = bytes_feature(f)
        example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
        return example_proto.SerializeToString()

    def tf_serialize_example(*args):
        args = tf.contrib.framework.nest.flatten(args)
        args = [tf.serialize_tensor(f) for f in args]
        tf_string = tf.py_func(serialize_example_pyfunction, args, tf.string)
        return tf.reshape(tf_string, ())  # the result is a scalar

    dataset = dataset.map(tf_serialize_example)
    writer = tf.data.experimental.TFRecordWriter(output)
    return writer.write(dataset)


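# A minimal usage sketch for dataset_to_tfrecord under TF 1.x graph mode
# (the dataset contents and output path are hypothetical):
#
#   import numpy as np
#   dataset = tf.data.Dataset.from_tensor_slices(
#       (np.zeros((10, 2), np.float32), np.arange(10, dtype=np.int64))
#   )
#   write_op = dataset_to_tfrecord(dataset, "/tmp/samples.tfrecords")
#   with tf.Session() as sess:
#       sess.run(write_op)  # also creates /tmp/samples.json

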
def dataset_from_tfrecord(tfrecord):
    """Reads TFRecords and returns a dataset.

    The TFRecord file must have been created using the
    :any:`dataset_to_tfrecord` function.

    Parameters
    ----------
    tfrecord : str or list
        Path to the TFRecord file. Pass a list if you are sure several
        tfrecords need the same map function.

    Returns
    -------
    ``tf.data.Dataset``
        A dataset that contains the data from the TFRecord file.
    """
    # these imports are needed so that eval can work
    from tensorflow import TensorShape, Dimension

    if isinstance(tfrecord, str):
        tfrecord = [tfrecord]
    tfrecord = [tfrecord_name_and_json_name(path) for path in tfrecord]
    json_output = tfrecord[0][1]
    tfrecord = [path[0] for path in tfrecord]
    raw_dataset = tf.data.TFRecordDataset(tfrecord)

    with open(json_output) as f:
        meta = json.load(f)
    for k, v in meta.items():
        meta[k] = eval(v)
    output_types = tf.contrib.framework.nest.flatten(meta["output_types"])
    output_shapes = tf.contrib.framework.nest.flatten(meta["output_shapes"])

    feature_description = {}
    for i in range(len(output_types)):
        key = f"feature{i}"
        feature_description[key] = tf.FixedLenFeature([], tf.string)

    def _parse_function(example_proto):
        # Parse the input tf.Example proto using the dictionary above.
        args = tf.parse_single_example(example_proto, feature_description)
        args = tf.contrib.framework.nest.flatten(args)
        args = [tf.parse_tensor(v, t) for v, t in zip(args, output_types)]
        args = [tf.reshape(v, s) for v, s in zip(args, output_shapes)]
        return tf.contrib.framework.nest.pack_sequence_as(meta["output_types"], args)

    return raw_dataset.map(_parse_function)


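# A round-trip sketch for dataset_from_tfrecord, assuming the file and its
# .json sidecar were produced by dataset_to_tfrecord as in the sketch above:
#
#   dataset = dataset_from_tfrecord("/tmp/samples.tfrecords")
#   data, label = dataset.make_one_shot_iterator().get_next()
#   with tf.Session() as sess:
#       print(sess.run([data, label]))

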
def write_a_sample(writer, data, label, key, feature=None, size_estimate=False):
    """Writes a single (data, label, key) sample into an open TFRecord writer.

    If ``size_estimate`` is True, nothing is written; only the approximate
    size (in bytes) of the serialized example is returned.
    """
    if feature is None:
        feature = {
            "data": bytes_feature(data.tostring()),
            "label": int64_feature(label),
            "key": bytes_feature(key),
        }

    example = tf.train.Example(features=tf.train.Features(feature=feature))
    example = example.SerializeToString()
    if not size_estimate:
        writer.write(example)
    return sys.getsizeof(example)


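# A minimal sketch for write_a_sample using the TF 1.x record writer (the
# output path and sample values are hypothetical):
#
#   import numpy as np
#   with tf.python_io.TFRecordWriter("/tmp/train.tfrecords") as writer:
#       data = np.zeros((112, 112, 3), dtype=np.uint8)
#       write_a_sample(writer, data, label=0, key=b"sample-0")

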
def example_parser(serialized_example, feature, data_shape, data_type):
    """
    Parses a single tf.Example into image, label, and key tensors.
    """
    # Decode the record read by the reader
    features = tf.parse_single_example(serialized_example, features=feature)
    # Convert the image data from string back to the numbers
    image = tf.decode_raw(features["data"], data_type)
    # Cast label data into int64
    label = tf.cast(features["label"], tf.int64)
    # Reshape image data into the original shape
    image = tf.reshape(image, data_shape)
    key = tf.cast(features["key"], tf.string)
    return image, label, key


def image_augmentation_parser(
    serialized_example,
    feature,
    data_shape,
    data_type,
    gray_scale=False,
    output_shape=None,
    random_flip=False,
    random_brightness=False,
    random_contrast=False,
    random_saturation=False,
    random_rotate=False,
    per_image_normalization=True,
    random_gamma=False,
    random_crop=False,
):
    """
    Parses a single tf.Example into image, label, and key tensors and applies
    image augmentation to the image.
    """
    # Decode the record read by the reader
    features = tf.parse_single_example(serialized_example, features=feature)
    # Convert the image data from string back to the numbers
    image = tf.decode_raw(features["data"], data_type)
    # Reshape image data into the original shape
    image = tf.reshape(image, data_shape)
    # Apply image augmentation
    image = append_image_augmentation(
        image,
        gray_scale=gray_scale,
        output_shape=output_shape,
        random_flip=random_flip,
        random_brightness=random_brightness,
        random_contrast=random_contrast,
        random_saturation=random_saturation,
        random_rotate=random_rotate,
        per_image_normalization=per_image_normalization,
        random_gamma=random_gamma,
        random_crop=random_crop,
    )
    # Cast label data into int64
    label = tf.cast(features["label"], tf.int64)
    key = tf.cast(features["key"], tf.string)
    return image, label, key


def read_and_decode(filename_queue, data_shape, data_type=tf.float32, feature=None):
    """
    Simplest possible parser for a tfrecord. It assumes that the records hold
    the pair **train/data** and **train/label**.
    """
    if feature is None:
        feature = DEFAULT_FEATURE
    # Define a reader and read the next record
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    return example_parser(serialized_example, feature, data_shape, data_type)


def create_dataset_from_records(
    tfrecord_filenames, data_shape, data_type, feature=None
):
    """
    Create a dataset from a list of tf-record files.

    **Parameters**

    tfrecord_filenames:
       List containing the tf-record paths

    data_shape:
       Shape of the samples saved in the tf-record

    data_type:
       tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)

    feature:
       Feature description used to parse the records; defaults to
       ``DEFAULT_FEATURE``
    """
    if feature is None:
        feature = DEFAULT_FEATURE
    dataset = tf.data.TFRecordDataset(tfrecord_filenames)
    parser = partial(
        example_parser, feature=feature, data_shape=data_shape, data_type=data_type
    )
    dataset = dataset.map(parser)
    return dataset


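# A minimal usage sketch, assuming the records follow DEFAULT_FEATURE (the
# path and shape are hypothetical):
#
#   dataset = create_dataset_from_records(
#       ["/tmp/train.tfrecords"], data_shape=(112, 112, 3), data_type=tf.uint8
#   )
#   image, label, key = dataset.make_one_shot_iterator().get_next()

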
def create_dataset_from_records_with_augmentation(
    tfrecord_filenames,
    data_shape,
    data_type,
    feature=None,
    gray_scale=False,
    output_shape=None,
    random_flip=False,
    random_brightness=False,
    random_contrast=False,
    random_saturation=False,
    random_rotate=False,
    per_image_normalization=True,
    random_gamma=False,
    random_crop=False,
):
    """
    Create a dataset from a list of tf-record files, applying image
    augmentation to each sample.

    **Parameters**

    tfrecord_filenames:
       List containing the tf-record paths, or the path to a directory
       containing tf-record files

    data_shape:
       Shape of the samples saved in the tf-record

    data_type:
       tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)

    feature:
       Feature description used to parse the records; defaults to
       ``DEFAULT_FEATURE``
    """
    if feature is None:
        feature = DEFAULT_FEATURE
    if isinstance(tfrecord_filenames, str) and os.path.isdir(tfrecord_filenames):
        tfrecord_filenames = [
            os.path.join(tfrecord_filenames, f) for f in os.listdir(tfrecord_filenames)
        ]
    dataset = tf.data.TFRecordDataset(tfrecord_filenames)
    parser = partial(
        image_augmentation_parser,
        feature=feature,
        data_shape=data_shape,
        data_type=data_type,
        gray_scale=gray_scale,
        output_shape=output_shape,
        random_flip=random_flip,
        random_brightness=random_brightness,
        random_contrast=random_contrast,
        random_saturation=random_saturation,
        random_rotate=random_rotate,
        per_image_normalization=per_image_normalization,
        random_gamma=random_gamma,
        random_crop=random_crop,
    )
    dataset = dataset.map(parser)
    return dataset


def shuffle_data_and_labels_image_augmentation(
    tfrecord_filenames,
    data_shape,
    data_type,
    batch_size,
    epochs=None,
    buffer_size=10 ** 3,
    gray_scale=False,
    output_shape=None,
    random_flip=False,
    random_brightness=False,
    random_contrast=False,
    random_saturation=False,
    random_rotate=False,
    per_image_normalization=True,
    random_gamma=False,
    random_crop=False,
    drop_remainder=False,
):
    """Dump random batches from a list of tf-record files and apply image
    augmentation.

    Parameters
    ----------

    tfrecord_filenames:
       List containing the tf-record paths

    data_shape:
       Shape of the samples saved in the tf-record

    data_type:
       tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)

    batch_size:
       Size of the batch

    epochs:
       Number of epochs to be batched

    buffer_size:
       Size of the shuffle buffer

    gray_scale:
       Convert to gray scale?

    output_shape:
       If set, the image will be randomly cropped to this shape

    random_flip:
       Randomly flip an image horizontally (https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right)

    random_brightness:
       Adjust the brightness of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_brightness)

    random_contrast:
       Adjust the contrast of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_contrast)

    random_saturation:
       Adjust the saturation of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_saturation)

    random_rotate:
       Randomly rotate face images between -5 and 5 degrees

    per_image_normalization:
       Linearly scale the image to have zero mean and unit norm.

    drop_remainder:
       If True, the last batch is dropped when it has fewer than batch_size
       elements.
    """
    dataset = create_dataset_from_records_with_augmentation(
        tfrecord_filenames,
        data_shape,
        data_type,
        gray_scale=gray_scale,
        output_shape=output_shape,
        random_flip=random_flip,
        random_brightness=random_brightness,
        random_contrast=random_contrast,
        random_saturation=random_saturation,
        random_rotate=random_rotate,
        per_image_normalization=per_image_normalization,
        random_gamma=random_gamma,
        random_crop=random_crop,
    )
    dataset = dataset.shuffle(buffer_size)
    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
    dataset = dataset.repeat(epochs)
    dataset = dataset.map(lambda d, l, k: ({"data": d, "key": k}, l))
    return dataset


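# Since this returns a batched dataset of ({"data", "key"}, label) pairs, it
# can serve as a tf.estimator input_fn. A hedged sketch (paths, shape, and
# the estimator itself are hypothetical):
#
#   def train_input_fn():
#       return shuffle_data_and_labels_image_augmentation(
#           ["/tmp/train.tfrecords"],
#           data_shape=(112, 112, 3),
#           data_type=tf.uint8,
#           batch_size=32,
#           epochs=None,
#           random_flip=True,
#       )
#
#   # estimator.train(input_fn=train_input_fn)

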
def shuffle_data_and_labels(
    tfrecord_filenames,
    data_shape,
    data_type,
    batch_size,
    epochs=None,
    buffer_size=10 ** 3,
):
    """
    Dump random batches from a list of tf-record files.

    **Parameters**

    tfrecord_filenames:
       List containing the tf-record paths

    data_shape:
       Shape of the samples saved in the tf-record

    data_type:
       tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)

    batch_size:
       Size of the batch

    epochs:
       Number of epochs to be batched

    buffer_size:
       Size of the shuffle buffer
    """
    dataset = create_dataset_from_records(tfrecord_filenames, data_shape, data_type)
    dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
    data, labels, key = dataset.make_one_shot_iterator().get_next()
    features = dict()
    features["data"] = data
    features["key"] = key
    return features, labels


def batch_data_and_labels(
    tfrecord_filenames, data_shape, data_type, batch_size, epochs=1
):
    """
    Dump in-order batches from a list of tf-record files.

    Parameters
    ----------

    tfrecord_filenames:
       List containing the tf-record paths

    data_shape:
       Shape of the samples saved in the tf-record

    data_type:
       tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)

    batch_size:
       Size of the batch

    epochs:
       Number of epochs to be batched
    """
    dataset = create_dataset_from_records(tfrecord_filenames, data_shape, data_type)
    dataset = dataset.batch(batch_size).repeat(epochs)
    data, labels, key = dataset.make_one_shot_iterator().get_next()
    features = dict()
    features["data"] = data
    features["key"] = key
    return features, labels


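# A minimal sketch for in-order batches, e.g. for evaluation (path and shape
# are hypothetical):
#
#   features, labels = batch_data_and_labels(
#       ["/tmp/dev.tfrecords"],
#       data_shape=(112, 112, 3),
#       data_type=tf.uint8,
#       batch_size=32,
#       epochs=1,
#   )
#   # features["data"] and features["key"] are batched tensors.

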
def batch_data_and_labels_image_augmentation(
    tfrecord_filenames,
    data_shape,
    data_type,
    batch_size,
    epochs=1,
    gray_scale=False,
    output_shape=None,
    random_flip=False,
    random_brightness=False,
    random_contrast=False,
    random_saturation=False,
    random_rotate=False,
    per_image_normalization=True,
    random_gamma=False,
    random_crop=False,
    drop_remainder=False,
):
    """
    Dump in-order batches from a list of tf-record files and apply image
    augmentation.

    Parameters
    ----------

    tfrecord_filenames:
       List containing the tf-record paths

    data_shape:
       Shape of the samples saved in the tf-record

    data_type:
       tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)

    batch_size:
       Size of the batch

    epochs:
       Number of epochs to be batched

    drop_remainder:
       If True, the last batch is dropped when it has fewer than batch_size
       elements.
    """
    dataset = create_dataset_from_records_with_augmentation(
        tfrecord_filenames,
        data_shape,
        data_type,
        gray_scale=gray_scale,
        output_shape=output_shape,
        random_flip=random_flip,
        random_brightness=random_brightness,
        random_contrast=random_contrast,
        random_saturation=random_saturation,
        random_rotate=random_rotate,
        per_image_normalization=per_image_normalization,
        random_gamma=random_gamma,
        random_crop=random_crop,
    )
    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
    dataset = dataset.repeat(epochs)
    data, labels, key = dataset.make_one_shot_iterator().get_next()
    features = dict()
    features["data"] = data
    features["key"] = key
    return features, labels


def describe_tf_record(tf_record_path, shape, batch_size=1):
    """
    Describe the number of samples and the number of classes of a tf-record.

    Parameters
    ----------

    tf_record_path: str
       Base path containing your tf-record files

    shape: tuple
       Shape of the samples inside the tf-record

    batch_size: int
       Batch size used while counting

    Returns
    -------

    n_samples: int
       Total number of samples

    n_classes: int
       Total number of classes
    """
    tf_records = [os.path.join(tf_record_path, f) for f in os.listdir(tf_record_path)]
    filename_queue = tf.train.string_input_producer(
        tf_records, num_epochs=1, name="input"
    )

    feature = {
        "data": tf.FixedLenFeature([], tf.string),
        "label": tf.FixedLenFeature([], tf.int64),
        "key": tf.FixedLenFeature([], tf.string),
    }

    # Define a reader and read the next record
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # Decode the record read by the reader
    features = tf.parse_single_example(serialized_example, features=feature)

    # Convert the image data from string back to the numbers
    image = tf.decode_raw(features["data"], tf.uint8)

    # Cast label data into int64
    label = tf.cast(features["label"], tf.int64)
    img_name = tf.cast(features["key"], tf.string)

    # Reshape image data into the original shape
    image = tf.reshape(image, shape)

    # Getting the batches in order
    data_ph, label_ph, img_name_ph = tf.train.batch(
        [image, label, img_name],
        batch_size=batch_size,
        capacity=1000,
        num_threads=5,
        name="shuffle_batch",
    )

    # Start the reading
    session = tf.Session()
    tf.local_variables_initializer().run(session=session)
    tf.global_variables_initializer().run(session=session)

    # Preparing the batches
    thread_pool = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=thread_pool, sess=session)

    logger.info("Counting in %s", tf_record_path)
    labels = set()
    counter = 0
    try:
        while True:
            _, label, _ = session.run([data_ph, label_ph, img_name_ph])
            counter += len(label)
            for i in set(label):
                labels.add(i)
    except tf.errors.OutOfRangeError:
        pass

    thread_pool.request_stop()
    return counter, len(labels)
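

# A minimal usage sketch (the directory and shape are hypothetical):
#
#   n_samples, n_classes = describe_tf_record(
#       "/tmp/tfrecords_dir", shape=(112, 112, 3), batch_size=64
#   )
#   logger.info("%d samples and %d classes", n_samples, n_classes)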