Source code for bob.learn.tensorflow.losses.pixel_wise

import tensorflow as tf

from ..utils import tf_repeat
from .balanced_cross_entropy import balanced_sigmoid_cross_entropy_loss_weights


def get_pixel_wise_labels(labels, n_pixels):
    labels = tf.reshape(labels, (-1, 1))
    labels = tf_repeat(labels, [n_pixels, 1])
    labels = tf.reshape(labels, (-1, n_pixels))
    return labels


class PixelwiseBinaryCrossentropy(tf.keras.losses.Loss):
    """A pixel wise loss which is just a cross entropy loss but applied to all pixels.
    Appeared in::

        @inproceedings{GeorgeICB2019,
            author = {Anjith George, Sebastien Marcel},
            title = {Deep Pixel-wise Binary Supervision for Face Presentation Attack Detection},
            year = {2019},
            booktitle = {ICB 2019},
        }
    """

    def __init__(
        self,
        balance_weights=True,
        label_smoothing=0.5,
        name="pixel_wise_binary_cross_entropy",
        **kwargs
    ):
        """
        Parameters
        ----------
        balance_weights : bool, optional
            Whether the loss should be balanced per samples of different classes in each batch, by default True
        label_smoothing : float, optional
            Label smoothing, by default 0.5
        """
        super().__init__(name=name, **kwargs)
        self.balance_weights = balance_weights
        self.label_smoothing = label_smoothing

[docs] def call(self, labels, logits): n_pixels = logits.shape[-1] pixel_wise_labels = get_pixel_wise_labels(labels, n_pixels) # per batch weighting, different from Keras's sample weights. weights = 1.0 if self.balance_weights: # use labels to figure out the required loss weights weights = balanced_sigmoid_cross_entropy_loss_weights( labels, dtype=logits.dtype ) loss_pixel_wise = tf.keras.losses.binary_crossentropy( y_true=pixel_wise_labels, y_pred=logits, label_smoothing=self.label_smoothing, from_logits=True, ) loss_pixel_wise = loss_pixel_wise * weights return loss_pixel_wise