# Source code for bob.learn.tensorflow.estimators

#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>

import tensorflow as tf


def check_features(features):
    """Validate that the input function produced the mandatory feature keys.

    Parameters
    ----------
    features : dict-like
        The features returned by the estimator's input function. It must
        support the ``in`` operator.

    Returns
    -------
    bool
        Always ``True``. A failed check raises instead of returning.

    Raises
    ------
    ValueError
        If ``"data"`` or ``"key"`` is missing from ``features``.
    """
    required_keys = ("data", "key")
    # ``all`` short-circuits, so "data" is probed before "key", as in a
    # chained ``or`` of the two membership tests.
    if all(name in features for name in required_keys):
        return True
    raise ValueError(
        "The input function needs to contain a dictionary with the keys `data` and `key` "
    )
def get_trainable_variables(extra_checkpoint, mode=tf.estimator.ModeKeys.TRAIN):
    """Return the ``trainable_variables`` entry of the estimator's extra_checkpoint.

    The returned value is meant to be interpreted as follows:

    * ``None`` (entry not provided): every end point is trainable.
    * ``[]``: no end point is trainable.
    * A list of end points: ONLY those end points are trainable.

    Parameters
    ----------
    extra_checkpoint : dict or None
        The ``extra_checkpoint`` dictionary given to the estimator.
    mode : tf.estimator.ModeKeys value, optional
        The estimator mode (TRAIN, EVAL or PREDICT). Any mode other than
        ``tf.estimator.ModeKeys.TRAIN`` makes this function return ``None``.

    Returns
    -------
    object or None
        ``extra_checkpoint["trainable_variables"]`` when in TRAIN mode and the
        key is present. Otherwise ``None``.
    """
    # The ``and`` chain short-circuits: the mode is checked first, so
    # extra_checkpoint is not inspected at all outside of TRAIN mode.
    has_selection = (
        mode == tf.estimator.ModeKeys.TRAIN
        and extra_checkpoint is not None
        and "trainable_variables" in extra_checkpoint
    )
    # Nothing configured -> None, which means "everything is trainable".
    return extra_checkpoint["trainable_variables"] if has_selection else None
from .utils import MovingAverageOptimizer, learning_rate_decay_fn
from .Logits import Logits, LogitsCenterLoss
from .Siamese import Siamese
from .Triplet import Triplet
from .Regressor import Regressor


# Required for Sphinx autodoc to document these objects here - don't remove it.
def __appropriate__(*args):
    """Make the given objects look as if they were declared in this module.

    Each object's ``__module__`` attribute is overwritten with this module's
    ``__name__``, instead of the submodule it was imported from.

    Parameters
    ----------
    *args
        The objects whose ``__module__`` attribute is rewritten.

    Resolves `Sphinx referencing issues
    <https://github.com/sphinx-doc/sphinx/issues/3048>`_
    """
    for public_object in args:
        setattr(public_object, "__module__", __name__)


__appropriate__(
    Logits,
    LogitsCenterLoss,
    Siamese,
    Triplet,
    Regressor,
    MovingAverageOptimizer,
)

# Export every module-level name that does not start with an underscore.
__all__ = [name for name in dir() if not name.startswith("_")]