# NOTE: the following lines are residue from the coverage-report HTML viewer
# this file was extracted from; kept as comments so the module parses.
# Hot-keys on this page
# r m x p toggle line displays
# j k next/prev highlighted chunk
# 0 (zero) top of page
# 1 (one) first highlighted chunk
1"""Loss implementations"""
3import torch
5from torch.nn.modules.loss import _Loss
class WeightedBCELogitsLoss(_Loss):
    """Calculates sum of weighted cross entropy loss.

    Implements Equation 1 in [MANINIS-2016]_. The weight depends on the
    current proportion between negatives and positives in the ground-truth
    sample being analyzed.
    """

    def __init__(self):
        super(WeightedBCELogitsLoss, self).__init__()

    def forward(self, input, target, mask):
        """
        Parameters
        ----------

        input : :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape ``[n, c,
            h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be use for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``


        Returns
        -------

        loss : :py:class:`torch.Tensor`
            The average loss for all input data

        """

        # calculates the proportion of negatives to the total number of pixels
        # available in the masked region
        valid = mask > 0.5
        num_pos = target[valid].sum()
        num_neg = valid.sum() - num_pos
        # BUGFIX: clamp the denominator so an all-negative sample (num_pos ==
        # 0) yields a finite weight instead of propagating inf/nan through
        # the loss
        pos_weight = num_neg / num_pos.clamp(min=1)
        return torch.nn.functional.binary_cross_entropy_with_logits(
            input[valid], target[valid], reduction="mean", pos_weight=pos_weight
        )
class SoftJaccardBCELogitsLoss(_Loss):
    # BUGFIX: raw docstring — previously "\alpha" in a plain string turned
    # "\a" into the BEL escape character, corrupting the rendered math.
    r"""
    Implements the generalized loss function of Equation (3) in
    [IGLOVIKOV-2018]_, with J being the Jaccard distance, and H, the Binary
    Cross-Entropy Loss:

    .. math::

       L = \alpha H + (1-\alpha)(1-J)

    Our implementation is based on :py:class:`torch.nn.BCEWithLogitsLoss`.

    Attributes
    ----------

    alpha : float
        determines the weighting of J and H. Default: ``0.7``
    """

    def __init__(self, alpha=0.7):
        super(SoftJaccardBCELogitsLoss, self).__init__()
        self.alpha = alpha

    def forward(self, input, target, mask):
        """
        Parameters
        ----------

        input : :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape ``[n, c,
            h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be use for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``


        Returns
        -------

        loss : :py:class:`torch.Tensor`
            Loss, in a single entry

        """

        eps = 1e-8
        valid = mask > 0.5
        probabilities = torch.sigmoid(input[valid])
        intersection = (probabilities * target[valid]).sum()
        sums = probabilities.sum() + target[valid].sum()
        # soft Jaccard index; eps avoids 0/0 when both prediction and target
        # are empty inside the RoI
        J = intersection / (sums - intersection + eps)
        # this implements the support for looking just into the RoI
        H = torch.nn.functional.binary_cross_entropy_with_logits(
            input[valid], target[valid], reduction="mean"
        )
        return (self.alpha * H) + ((1 - self.alpha) * (1 - J))
class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss):
    """
    Weighted Binary Cross-Entropy Loss for multi-layered inputs (e.g. for
    Holistically-Nested Edge Detection in [XIE-2015]_).
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target, mask):
        """
        Parameters
        ----------

        input : iterable over :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape ``[L,
            n, c, h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be use for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``


        Returns
        -------

        loss : torch.Tensor
            The average loss for all input data

        """

        # evaluate the single-layer weighted BCE once per side-output, then
        # average the per-layer losses into a single scalar
        per_layer = super().forward
        losses = [per_layer(layer, target, mask) for layer in input]
        return torch.stack(losses).mean()
class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss):
    """
    Implements Equation 3 in [IGLOVIKOV-2018]_ for the multi-output networks
    such as HED or Little W-Net.

    Attributes
    ----------

    alpha : float
        determines the weighting of SoftJaccard and BCE. Default: ``0.7``
    """
    # BUGFIX (docs): the docstring previously stated a default of ``0.3``,
    # contradicting the actual ``alpha=0.7`` default below.

    def __init__(self, alpha=0.7):
        super(MultiSoftJaccardBCELogitsLoss, self).__init__(alpha=alpha)

    def forward(self, input, target, mask):
        """
        Parameters
        ----------

        input : iterable over :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape ``[L,
            n, c, h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be use for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``


        Returns
        -------

        loss : torch.Tensor
            The average loss for all input data

        """

        # one single-layer loss per side-output, averaged into one scalar
        return torch.cat(
            [
                super(MultiSoftJaccardBCELogitsLoss, self)
                .forward(i, target, mask)
                .unsqueeze(0)
                for i in input
            ]
        ).mean()
class MixJacLoss(_Loss):
    """Semi-supervised loss: SoftJaccard+BCE on labeled data plus a
    ramped-up BCE term on unlabeled data.

    Parameters
    ----------

    lambda_u : int
        determines the weighting of the unsupervised (unlabeled) loss term.

    jacalpha : float
        ``alpha`` forwarded to :py:class:`SoftJaccardBCELogitsLoss`, used on
        the labeled data. Default: ``0.7``
    """

    def __init__(
        self,
        lambda_u=100,
        jacalpha=0.7,
        size_average=None,
        reduce=None,
        reduction="mean",
        pos_weight=None,
    ):
        super(MixJacLoss, self).__init__(size_average, reduce, reduction)
        self.lambda_u = lambda_u
        self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha)
        self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()

    def forward(
        self,
        input,
        target,
        unlabeled_input,
        unlabeled_target,
        ramp_up_factor,
        mask=None,
    ):
        """
        Parameters
        ----------

        input : :py:class:`torch.Tensor`
        target : :py:class:`torch.Tensor`
        unlabeled_input : :py:class:`torch.Tensor`
        unlabeled_target : :py:class:`torch.Tensor`
        ramp_up_factor : float
        mask : :py:class:`torch.Tensor`, Optional
            RoI mask forwarded to the labeled loss; when ``None`` (the
            default, backward-compatible with the previous signature), the
            whole image is treated as the region of interest.

        Returns
        -------

        tuple
            ``(loss, labeled_loss, unlabeled_loss)``

        """
        # BUGFIX: SoftJaccardBCELogitsLoss.forward() requires a mask; this
        # method previously called it with only (input, target), which
        # always raised TypeError.  Default to an all-ones mask.
        if mask is None:
            mask = torch.ones_like(target)
        ll = self.labeled_loss(input, target, mask)
        ul = self.unlabeled_loss(unlabeled_input, unlabeled_target)
        loss = ll + self.lambda_u * ramp_up_factor * ul
        return loss, ll, ul