Coverage for src/deepdraw/models/losses.py: 63%

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Loss implementations."""

import torch

from torch.nn.modules.loss import _Loss


class WeightedBCELogitsLoss(_Loss):
    """Calculates the weighted binary cross-entropy loss.

    Implements Equation 1 in [MANINIS-2016]_. The weight depends on the
    current proportion between negatives and positives in the ground-
    truth sample being analyzed.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target, mask):
        """

        Parameters
        ----------

        input : :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape ``[n, c,
            h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be used for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``


        Returns
        -------

        loss : :py:class:`torch.Tensor`
            The average loss for all input data

        """

        # calculates the ratio of negatives to positives within the masked
        # region; this ratio up-weights the (typically rarer) positive class
        valid = mask > 0.5
        num_pos = target[valid].sum()
        num_neg = valid.sum() - num_pos
        pos_weight = num_neg / num_pos

        return torch.nn.functional.binary_cross_entropy_with_logits(
            input[valid], target[valid], reduction="mean", pos_weight=pos_weight
        )
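

# A minimal usage sketch (illustrative, not part of the original module):
# evaluates the weighted BCE on a random batch, using a full mask so every
# pixel counts.  Shapes follow the docstring above; all names are made up.
def _demo_weighted_bce():
    criterion = WeightedBCELogitsLoss()
    logits = torch.randn(2, 1, 16, 16)  # model output, shape [n, c, h, w]
    target = (torch.rand(2, 1, 16, 16) > 0.7).float()  # ~30% positives
    mask = torch.ones_like(target)  # full region of interest
    # inside forward(), pos_weight = num_neg / num_pos, here roughly 7/3
    return criterion(logits, target, mask)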


class SoftJaccardBCELogitsLoss(_Loss):
    r"""Implements the generalized loss function of Equation (3) in
    [IGLOVIKOV-2018]_, with :math:`J` being the Jaccard index (so
    :math:`1-J` is the Jaccard distance), and :math:`H`, the Binary
    Cross-Entropy Loss:

    .. math::

       L = \alpha H + (1-\alpha)(1-J)


    Our implementation is based on :py:class:`torch.nn.BCEWithLogitsLoss`.


    Attributes
    ----------

    alpha : float
        determines the weighting of J and H. Default: ``0.7``
    """

    def __init__(self, alpha=0.7):
        super().__init__()
        self.alpha = alpha

    def forward(self, input, target, mask):
        """

        Parameters
        ----------

        input : :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape ``[n, c,
            h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be used for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``


        Returns
        -------

        loss : :py:class:`torch.Tensor`
            The scalar loss

        """

        eps = 1e-8
        valid = mask > 0.5
        probabilities = torch.sigmoid(input[valid])
        intersection = (probabilities * target[valid]).sum()
        sums = probabilities.sum() + target[valid].sum()
        J = intersection / (sums - intersection + eps)

        # restricting to ``valid`` implements the support for looking just
        # into the region of interest
        H = torch.nn.functional.binary_cross_entropy_with_logits(
            input[valid], target[valid], reduction="mean"
        )
        return (self.alpha * H) + ((1 - self.alpha) * (1 - J))
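

# A minimal sketch (illustrative only) showing how ``alpha`` trades off the
# BCE term against the soft-Jaccard term: alpha=1.0 reduces to plain BCE;
# alpha=0.0 keeps only the Jaccard distance (1 - J).
def _demo_soft_jaccard_bce():
    logits = torch.randn(2, 1, 16, 16)
    target = (torch.rand(2, 1, 16, 16) > 0.5).float()
    mask = torch.ones_like(target)
    return [
        SoftJaccardBCELogitsLoss(alpha=a)(logits, target, mask)
        for a in (0.0, 0.7, 1.0)
    ]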


class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss):
    """Weighted Binary Cross-Entropy Loss for multi-layered inputs (e.g. for
    Holistically-Nested Edge Detection in [XIE-2015]_)."""

    def __init__(self):
        super().__init__()

    def forward(self, input, target, mask):
        """
        Parameters
        ----------

        input : iterable over :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape ``[L,
            n, c, h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be used for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``


        Returns
        -------

        loss : :py:class:`torch.Tensor`
            The average loss for all input data

        """

        return torch.cat(
            [
                super(MultiWeightedBCELogitsLoss, self)
                .forward(i, target, mask)
                .unsqueeze(0)
                for i in input
            ]
        ).mean()
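

# Sketch for the multi-output variant (illustrative): HED-style networks emit
# one prediction per side layer; the loss averages the per-layer weighted BCE
# over that list.  Three fake side outputs stand in for a real model here.
def _demo_multi_weighted_bce():
    criterion = MultiWeightedBCELogitsLoss()
    side_outputs = [torch.randn(2, 1, 16, 16) for _ in range(3)]  # L = 3
    target = (torch.rand(2, 1, 16, 16) > 0.7).float()
    mask = torch.ones_like(target)
    return criterion(side_outputs, target, mask)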


class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss):
    """Implements Equation 3 in [IGLOVIKOV-2018]_ for multi-output networks
    such as HED or Little W-Net.

    Attributes
    ----------

    alpha : float
        determines the weighting of SoftJaccard and BCE. Default: ``0.7``
    """

    def __init__(self, alpha=0.7):
        super().__init__(alpha=alpha)

    def forward(self, input, target, mask):
        """
        Parameters
        ----------

        input : iterable over :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape ``[L,
            n, c, h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be used for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``


        Returns
        -------

        loss : :py:class:`torch.Tensor`
            The average loss for all input data

        """

        return torch.cat(
            [
                super(MultiSoftJaccardBCELogitsLoss, self)
                .forward(i, target, mask)
                .unsqueeze(0)
                for i in input
            ]
        ).mean()


class MixJacLoss(_Loss):
    """Mixed loss for semi-supervised training: combines a SoftJaccard+BCE
    term on labeled data with a BCE term on unlabeled data.

    Parameters
    ----------

    lambda_u : int
        multiplier applied to the unlabeled-data loss term

    jacalpha : float
        determines the weighting of SoftJaccard and BCE in the labeled-data
        loss (see :py:class:`SoftJaccardBCELogitsLoss`)

    """

    def __init__(
        self,
        lambda_u=100,
        jacalpha=0.7,
        size_average=None,
        reduce=None,
        reduction="mean",
        pos_weight=None,
    ):
        super().__init__(size_average, reduce, reduction)
        self.lambda_u = lambda_u
        self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha)
        self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()

    def forward(
        self, input, target, unlabeled_input, unlabeled_target, ramp_up_factor
    ):
        """
        Parameters
        ----------

        input : :py:class:`torch.Tensor`
            Model output for the labeled data

        target : :py:class:`torch.Tensor`
            Ground-truth for the labeled data

        unlabeled_input : :py:class:`torch.Tensor`
            Model output for the unlabeled data

        unlabeled_target : :py:class:`torch.Tensor`
            (Pseudo-)targets for the unlabeled data

        ramp_up_factor : float
            weight ramping up the contribution of the unlabeled loss

        Returns
        -------

        tuple
            The combined loss, the labeled loss and the unlabeled loss

        """
        # the labeled loss expects a region-of-interest mask; no mask is
        # available here, so evaluate it over the whole image
        ll = self.labeled_loss(input, target, torch.ones_like(target))
        ul = self.unlabeled_loss(unlabeled_input, unlabeled_target)

        loss = ll + self.lambda_u * ramp_up_factor * ul
        return loss, ll, ul
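

# Semi-supervised usage sketch (illustrative, all names made up): the
# unlabeled branch compares predictions against pseudo-targets (e.g. the
# output of a teacher model), and ramp_up_factor would typically grow from
# 0 to 1 over the first epochs so the unlabeled term is phased in slowly.
def _demo_mix_jac():
    criterion = MixJacLoss(lambda_u=100, jacalpha=0.7)
    logits = torch.randn(2, 1, 16, 16)
    target = (torch.rand(2, 1, 16, 16) > 0.5).float()
    u_logits = torch.randn(2, 1, 16, 16)
    u_pseudo = torch.sigmoid(torch.randn(2, 1, 16, 16))  # teacher posteriors
    loss, ll, ul = criterion(logits, target, u_logits, u_pseudo, 0.1)
    return loss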