Coverage for src/deepdraw/models/losses.py: 63%
46 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5"""Loss implementations."""
7import torch
9from torch.nn.modules.loss import _Loss
class WeightedBCELogitsLoss(_Loss):
    """Calculates sum of weighted cross entropy loss.

    Implements Equation 1 in [MANINIS-2016]_.  The weight depends on the
    current proportion between negatives and positives in the ground-
    truth sample being analyzed.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target, mask):
        """
        Parameters
        ----------

        input : :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape ``[n, c,
            h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be use for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``


        Returns
        -------

        loss : :py:class:`torch.Tensor`
            The average loss for all input data
        """
        # calculates the proportion of negatives to the total number of pixels
        # available in the masked region
        valid = mask > 0.5
        num_pos = target[valid].sum()
        num_neg = valid.sum() - num_pos
        # clamp the denominator so a sample without any positive pixels does
        # not produce an inf/NaN loss (for any valid sample with >= 1 positive
        # pixel, target sums are integral so this is a no-op)
        pos_weight = num_neg / num_pos.clamp(min=1)

        return torch.nn.functional.binary_cross_entropy_with_logits(
            input[valid], target[valid], reduction="mean", pos_weight=pos_weight
        )
class SoftJaccardBCELogitsLoss(_Loss):
    r"""Implements the generalized loss function of Equation (3) in.

    [IGLOVIKOV-2018]_, with J being the Jaccard distance, and H, the Binary
    Cross-Entropy Loss:

    .. math::

       L = \alpha H + (1-\alpha)(1-J)

    Our implementation is based on :py:class:`torch.nn.BCEWithLogitsLoss`.


    Attributes
    ----------

    alpha : float
        determines the weighting of J and H. Default: ``0.7``
    """

    def __init__(self, alpha=0.7):
        super().__init__()
        self.alpha = alpha

    def forward(self, input, target, mask):
        """
        Parameters
        ----------

        input : :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape ``[n, c,
            h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be use for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``


        Returns
        -------

        loss : :py:class:`torch.Tensor`
            Loss, in a single entry
        """
        eps = 1e-8

        # restrict every term to the region of interest
        roi = mask > 0.5
        logits = input[roi]
        labels = target[roi]

        # soft Jaccard index over predicted probabilities
        probs = torch.sigmoid(logits)
        overlap = (probs * labels).sum()
        total = probs.sum() + labels.sum()
        jaccard = overlap / (total - overlap + eps)

        # BCE term, also limited to the RoI
        bce = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, labels, reduction="mean"
        )

        return (self.alpha * bce) + ((1 - self.alpha) * (1 - jaccard))
class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss):
    """Weighted Binary Cross-Entropy Loss for multi-layered inputs (e.g. for
    Holistically-Nested Edge Detection in [XIE-2015]_)."""

    def __init__(self):
        super().__init__()

    def forward(self, input, target, mask):
        """
        Parameters
        ----------

        input : iterable over :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape ``[L,
            n, c, h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be use for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``


        Returns
        -------

        loss : torch.Tensor
            The average loss for all input data
        """
        # evaluate the single-layer loss once per prediction level, then
        # average the per-level scalars
        single_level = super().forward
        levels = []
        for prediction in input:
            levels.append(single_level(prediction, target, mask))
        return torch.stack(levels).mean()
class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss):
    """Implements Equation 3 in [IGLOVIKOV-2018]_ for the multi-output
    networks such as HED or Little W-Net.

    Attributes
    ----------

    alpha : float
        determines the weighting of SoftJaccard and BCE. Default: ``0.7``
    """

    def __init__(self, alpha=0.7):
        super().__init__(alpha=alpha)

    def forward(self, input, target, mask):
        """
        Parameters
        ----------

        input : iterable over :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape ``[L,
            n, c, h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be use for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``


        Returns
        -------

        loss : torch.Tensor
            The average loss for all input data
        """
        # compute the single-layer loss for every prediction level and
        # average the resulting scalars
        return torch.cat(
            [
                super(MultiSoftJaccardBCELogitsLoss, self)
                .forward(i, target, mask)
                .unsqueeze(0)
                for i in input
            ]
        ).mean()
class MixJacLoss(_Loss):
    """Mixed loss for semi-supervised training: a SoftJaccard+BCE term on
    labeled data plus a ramped-up BCE term on unlabeled data.

    Parameters
    ----------

    lambda_u : int
        determines the weighting of the unlabeled-data loss term.

    jacalpha : float
        ``alpha`` forwarded to :py:class:`SoftJaccardBCELogitsLoss` for the
        labeled-data term. Default: ``0.7``

    pos_weight : :py:class:`torch.Tensor`, optional
        accepted for interface compatibility; currently unused.
    """

    def __init__(
        self,
        lambda_u=100,
        jacalpha=0.7,
        size_average=None,
        reduce=None,
        reduction="mean",
        pos_weight=None,
    ):
        super().__init__(size_average, reduce, reduction)
        self.lambda_u = lambda_u
        self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha)
        self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()

    def forward(
        self, input, target, unlabeled_input, unlabeled_target, ramp_up_factor
    ):
        """
        Parameters
        ----------

        input : :py:class:`torch.Tensor`
        target : :py:class:`torch.Tensor`
        unlabeled_input : :py:class:`torch.Tensor`
        unlabeled_target : :py:class:`torch.Tensor`
        ramp_up_factor : float
            multiplier (typically ramped from 0 to 1 over training) applied
            to the unlabeled loss term.

        Returns
        -------

        tuple
            ``(loss, ll, ul)`` — the combined loss, the labeled-data loss and
            the unlabeled-data loss, each a scalar :py:class:`torch.Tensor`.
        """
        # SoftJaccardBCELogitsLoss.forward() requires a mask argument; no
        # mask is provided to this loss, so consider every pixel part of the
        # region of interest (calling it with only two arguments would raise
        # a TypeError)
        full_mask = torch.ones_like(target)
        ll = self.labeled_loss(input, target, full_mask)
        ul = self.unlabeled_loss(unlabeled_input, unlabeled_target)

        loss = ll + self.lambda_u * ramp_up_factor * ul
        return loss, ll, ul