# Source code for bob.ip.binseg.models.losses

"""Loss implementations"""

import torch

from torch.nn.modules.loss import _Loss


class WeightedBCELogitsLoss(_Loss):
    """Weighted binary cross-entropy loss computed on logits.

    Implements Equation 1 in [MANINIS-2016]_.  The weight applied to the
    positive class is the ratio between the number of negatives and the
    number of positives found in the masked region of the ground-truth
    sample being analyzed.
    """

    def __init__(self):
        super(WeightedBCELogitsLoss, self).__init__()

    def forward(self, input, target, mask):
        """Evaluate the loss inside the region of interest.

        Parameters
        ----------

        input : :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape ``[n,
            c, h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be use for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``.  Entries
            greater than ``0.5`` are considered inside the region.


        Returns
        -------

        loss : :py:class:`torch.Tensor`
            The average loss for all input data inside the mask

        """

        # only pixels inside the region of interest take part in the loss
        valid = mask > 0.5
        masked_target = target[valid]

        # ratio of negatives to positives within the masked region
        num_pos = masked_target.sum()
        num_neg = valid.sum() - num_pos
        ratio = num_neg / num_pos

        # Without any positive pixel the ratio is infinite (or NaN) and would
        # turn the whole loss into NaN.  The positive weight is irrelevant in
        # that case (it only scales positive targets), so use a neutral 1.
        pos_weight = torch.where(num_pos > 0, ratio, torch.ones_like(ratio))

        return torch.nn.functional.binary_cross_entropy_with_logits(
            input[valid],
            masked_target,
            reduction="mean",
            pos_weight=pos_weight,
        )
class SoftJaccardBCELogitsLoss(_Loss):
    r"""Mix of binary cross-entropy and soft Jaccard losses on logits.

    Implements the generalized loss function of Equation (3) in
    [IGLOVIKOV-2018]_, with J being the (soft) Jaccard index, so that
    ``1 - J`` is the Jaccard distance, and H, the Binary Cross-Entropy Loss:

    .. math::

       L = \alpha H + (1-\alpha)(1-J)


    Our implementation is based on :py:class:`torch.nn.BCEWithLogitsLoss`.


    Attributes
    ----------

    alpha : float
        determines the weighting of J and H. Default: ``0.7``

    """

    def __init__(self, alpha=0.7):
        super(SoftJaccardBCELogitsLoss, self).__init__()
        self.alpha = alpha

    def forward(self, input, target, mask):
        """Evaluate the loss inside the region of interest.

        Parameters
        ----------

        input : :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape ``[n,
            c, h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be use for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``.  Entries
            greater than ``0.5`` are considered inside the region.


        Returns
        -------

        loss : :py:class:`torch.Tensor`
            Loss, in a single entry

        """

        # guards the Jaccard ratio against a zero denominator
        eps = 1e-8

        # restrict every quantity to the region of interest
        valid = mask > 0.5
        logits = input[valid]
        truth = target[valid]

        # soft Jaccard index: |A . B| / (|A| + |B| - |A . B|)
        probabilities = torch.sigmoid(logits)
        intersection = (probabilities * truth).sum()
        sums = probabilities.sum() + truth.sum()
        J = intersection / (sums - intersection + eps)

        # binary cross-entropy averaged over the pixels inside the mask
        H = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, truth, reduction="mean"
        )

        return (self.alpha * H) + ((1 - self.alpha) * (1 - J))
class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss):
    """Weighted Binary Cross-Entropy Loss averaged over several outputs.

    Intended for multi-layered inputs (e.g. for Holistically-Nested Edge
    Detection in [XIE-2015]_): the parent loss is evaluated once per output
    and the results are averaged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target, mask):
        """Average the parent loss over every output in ``input``.

        Parameters
        ----------

        input : iterable over :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape ``[L,
            n, c, h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be use for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``


        Returns
        -------

        loss : torch.Tensor
            The average loss for all input data

        """

        # bound once, outside the comprehension, so it resolves to the parent
        single_output_loss = super().forward

        per_output = [
            single_output_loss(output, target, mask) for output in input
        ]

        # one scalar loss per output -> 1-D tensor -> mean over outputs
        return torch.stack(per_output).mean()
class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss):
    """Soft-Jaccard/BCE loss averaged over several outputs.

    Implements Equation 3 in [IGLOVIKOV-2018]_ for the multi-output networks
    such as HED or Little W-Net: the parent loss is evaluated once per output
    and the results are averaged.


    Attributes
    ----------

    alpha : float
        determines the weighting of SoftJaccard and BCE. Default: ``0.7``

    """

    def __init__(self, alpha=0.7):
        super().__init__(alpha=alpha)

    def forward(self, input, target, mask):
        """Average the parent loss over every output in ``input``.

        Parameters
        ----------

        input : iterable over :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape ``[L,
            n, c, h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be use for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``


        Returns
        -------

        loss : torch.Tensor
            The average loss for all input data

        """

        # bound once, outside the comprehension, so it resolves to the parent
        single_output_loss = super().forward

        per_output = [
            single_output_loss(output, target, mask) for output in input
        ]

        # one scalar loss per output -> 1-D tensor -> mean over outputs
        return torch.stack(per_output).mean()
class MixJacLoss(_Loss):
    """Combined loss over labeled and unlabeled data.

    The labeled part uses :py:class:`SoftJaccardBCELogitsLoss`; the unlabeled
    part uses :py:class:`torch.nn.BCEWithLogitsLoss`.  The total is
    ``labeled + lambda_u * ramp_up_factor * unlabeled``.


    Parameters
    ----------

    lambda_u : int
        weight applied to the unlabeled loss term. Default: ``100``

    jacalpha : float
        ``alpha`` passed to the labeled :py:class:`SoftJaccardBCELogitsLoss`
        (weighting of SoftJaccard and BCE). Default: ``0.7``

    size_average, reduce, reduction
        forwarded unchanged to the ``_Loss`` base-class constructor

    pos_weight
        accepted for interface compatibility; not used by this class

    """

    def __init__(
        self,
        lambda_u=100,
        jacalpha=0.7,
        size_average=None,
        reduce=None,
        reduction="mean",
        pos_weight=None,
    ):
        super(MixJacLoss, self).__init__(size_average, reduce, reduction)
        self.lambda_u = lambda_u
        self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha)
        self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()

    def forward(
        self,
        input,
        target,
        unlabeled_input,
        unlabeled_target,
        ramp_up_factor,
        mask=None,
    ):
        """Evaluate the combined labeled/unlabeled loss.

        Parameters
        ----------

        input : :py:class:`torch.Tensor`
            Model output (logits) for the labeled samples

        target : :py:class:`torch.Tensor`
            Ground-truth for the labeled samples

        unlabeled_input : :py:class:`torch.Tensor`
            Model output (logits) for the unlabeled samples

        unlabeled_target : :py:class:`torch.Tensor`
            Target to compare ``unlabeled_input`` against

        ramp_up_factor : float
            Multiplier applied, together with ``lambda_u``, to the unlabeled
            loss

        mask : :py:class:`torch.Tensor`, optional
            Region of interest for the labeled loss, with the same shape as
            ``target``.  When ``None`` (default), every entry of ``target``
            is considered inside the region of interest.


        Returns
        -------

        tuple
            ``(loss, labeled_loss, unlabeled_loss)``

        """

        # The labeled loss requires a region-of-interest mask; without one,
        # evaluate it over the whole target.
        if mask is None:
            mask = torch.ones_like(target)

        ll = self.labeled_loss(input, target, mask)
        ul = self.unlabeled_loss(unlabeled_input, unlabeled_target)

        loss = ll + self.lambda_u * ramp_up_factor * ul
        return loss, ll, ul