Source code for bob.learn.tensorflow.models.embedding_validation

import tensorflow as tf

from bob.learn.tensorflow.metrics.embedding_accuracy import accuracy_from_embeddings


class EmbeddingValidation(tf.keras.Model):
    """
    Use this model if the validation step should validate the accuracy with respect to embeddings.

    In this model, the `test_step` runs the function `bob.learn.tensorflow.metrics.embedding_accuracy.accuracy_from_embeddings`
    """

[docs] def compile( self, single_precision=False, **kwargs, ): """ Compile """ super().compile(**kwargs) self.train_loss = tf.keras.metrics.Mean(name="accuracy") self.validation_acc = tf.keras.metrics.Mean(name="accuracy")
[docs] def train_step(self, data): """ Train Step """ X, y = data with tf.GradientTape() as tape: logits, _ = self(X, training=True) loss = self.loss(y, logits) # trainable_vars = self.trainable_variables self.optimizer.minimize(loss, self.trainable_variables, tape=tape) self.compiled_metrics.update_state(y, logits, sample_weight=None) self.train_loss(loss) tf.summary.scalar("training_loss", data=loss, step=self._train_counter) return {m.name: m.result() for m in self.metrics + [self.train_loss]}
# self.optimizer.apply_gradients(zip(gradients, trainable_vars)) # self.train_loss(loss) # return {m.name: m.result() for m in [self.train_loss]}
[docs] def test_step(self, data): """ Test Step """ images, labels = data logits, prelogits = self(images, training=False) self.validation_acc(accuracy_from_embeddings(labels, prelogits)) return {m.name: m.result() for m in [self.validation_acc]}