1"""Loss implementations""" 

2 

3import torch 

4 

5from torch.nn.modules.loss import _Loss 

6 

7 

class WeightedBCELogitsLoss(_Loss):
    """Calculates the weighted binary cross-entropy loss.

    Implements Equation 1 in [MANINIS-2016]_.  The weight depends on the
    current proportion between negatives and positives in the ground-truth
    sample being analyzed.
    """

    def __init__(self):
        super(WeightedBCELogitsLoss, self).__init__()

    def forward(self, input, target, mask):
        """

        Parameters
        ----------

        input : :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape
            ``[n, c, h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be used for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``


        Returns
        -------

        loss : :py:class:`torch.Tensor`
            The average loss for all input data

        """

        # calculates the proportion of negatives to positives in the masked
        # region, to be used as the weight of the positive class
        valid = mask > 0.5
        num_pos = target[valid].sum()
        num_neg = valid.sum() - num_pos
        pos_weight = num_neg / num_pos

        return torch.nn.functional.binary_cross_entropy_with_logits(
            input[valid], target[valid], reduction="mean", pos_weight=pos_weight
        )
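

# A minimal usage sketch (not part of the original module): random tensors
# stand in for model logits and binary ground truth, and all shapes below are
# illustrative assumptions.  With roughly 75% negatives in ``target``,
# ``pos_weight`` lands near 3, up-weighting the sparse positive class.
def _example_weighted_bce():
    criterion = WeightedBCELogitsLoss()
    logits = torch.randn(4, 1, 16, 16)  # raw, pre-sigmoid model outputs
    target = (torch.rand(4, 1, 16, 16) > 0.75).float()  # sparse positives
    mask = torch.ones(4, 1, 16, 16)  # evaluate over the whole image
    return criterion(logits, target, mask)
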

class SoftJaccardBCELogitsLoss(_Loss):
    r"""
    Implements the generalized loss function of Equation (3) in
    [IGLOVIKOV-2018]_, with J being the soft Jaccard index, and H, the Binary
    Cross-Entropy Loss:

    .. math::

       L = \alpha H + (1-\alpha)(1-J)


    Our implementation is based on :py:class:`torch.nn.BCEWithLogitsLoss`.


    Attributes
    ----------

    alpha : float
        determines the weighting of J and H. Default: ``0.7``

    """

    def __init__(self, alpha=0.7):
        super(SoftJaccardBCELogitsLoss, self).__init__()
        self.alpha = alpha

    def forward(self, input, target, mask):
        """

        Parameters
        ----------

        input : :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape
            ``[n, c, h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be used for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``


        Returns
        -------

        loss : :py:class:`torch.Tensor`
            Loss, in a single entry

        """

        eps = 1e-8
        valid = mask > 0.5
        probabilities = torch.sigmoid(input[valid])
        intersection = (probabilities * target[valid]).sum()
        sums = probabilities.sum() + target[valid].sum()
        J = intersection / (sums - intersection + eps)

        # indexing with the boolean mask restricts the computation to the
        # region of interest (RoI)
        H = torch.nn.functional.binary_cross_entropy_with_logits(
            input[valid], target[valid], reduction="mean"
        )
        return (self.alpha * H) + ((1 - self.alpha) * (1 - J))
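

# A minimal sketch of the blended loss (not part of the original module):
# with ``alpha=0.7``, 70% of the value comes from the BCE term H and 30% from
# the soft Jaccard distance (1 - J).  Shapes are illustrative assumptions.
def _example_soft_jaccard_bce():
    criterion = SoftJaccardBCELogitsLoss(alpha=0.7)
    logits = torch.randn(2, 1, 32, 32)
    target = (torch.rand(2, 1, 32, 32) > 0.5).float()
    mask = torch.ones(2, 1, 32, 32)
    return criterion(logits, target, mask)
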

class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss):
    """
    Weighted Binary Cross-Entropy Loss for multi-layered inputs (e.g. for
    Holistically-Nested Edge Detection in [XIE-2015]_).
    """

    def __init__(self):
        super(MultiWeightedBCELogitsLoss, self).__init__()

    def forward(self, input, target, mask):
        """
        Parameters
        ----------

        input : iterable over :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape
            ``[L, n, c, h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be used for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``


        Returns
        -------

        loss : :py:class:`torch.Tensor`
            The average loss for all input data

        """

        return torch.cat(
            [
                super(MultiWeightedBCELogitsLoss, self)
                .forward(i, target, mask)
                .unsqueeze(0)
                for i in input
            ]
        ).mean()
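

# A minimal multi-output sketch (not part of the original module): HED-style
# models emit one prediction per side-output plus a fused map; the loss
# averages the per-output weighted BCE values.  Using five outputs here is an
# illustrative assumption.
def _example_multi_weighted_bce():
    criterion = MultiWeightedBCELogitsLoss()
    outputs = [torch.randn(2, 1, 16, 16) for _ in range(5)]  # L = 5 outputs
    target = (torch.rand(2, 1, 16, 16) > 0.7).float()
    mask = torch.ones(2, 1, 16, 16)
    return criterion(outputs, target, mask)
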

class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss):
    """
    Implements Equation 3 in [IGLOVIKOV-2018]_ for multi-output networks
    such as HED or Little W-Net.


    Attributes
    ----------

    alpha : float
        determines the weighting of SoftJaccard and BCE. Default: ``0.7``

    """

    def __init__(self, alpha=0.7):
        super(MultiSoftJaccardBCELogitsLoss, self).__init__(alpha=alpha)

    def forward(self, input, target, mask):
        """
        Parameters
        ----------

        input : iterable over :py:class:`torch.Tensor`
            Value produced by the model to be evaluated, with the shape
            ``[L, n, c, h, w]``

        target : :py:class:`torch.Tensor`
            Ground-truth information with the shape ``[n, c, h, w]``

        mask : :py:class:`torch.Tensor`
            Mask to be used for specifying the region of interest where to
            compute the loss, with the shape ``[n, c, h, w]``


        Returns
        -------

        loss : :py:class:`torch.Tensor`
            The average loss for all input data

        """

        return torch.cat(
            [
                super(MultiSoftJaccardBCELogitsLoss, self)
                .forward(i, target, mask)
                .unsqueeze(0)
                for i in input
            ]
        ).mean()
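

# MultiSoftJaccardBCELogitsLoss is exercised exactly like
# MultiWeightedBCELogitsLoss above: pass an iterable of per-output logit
# tensors together with a single target and mask.
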

class MixJacLoss(_Loss):
    """

    Parameters
    ----------

    lambda_u : int
        determines the weighting of the unlabeled-data loss with respect to
        the labeled-data loss

    jacalpha : float
        determines the weighting of J and H in the labeled-data loss (see
        :py:class:`SoftJaccardBCELogitsLoss`)

    """

    def __init__(
        self,
        lambda_u=100,
        jacalpha=0.7,
        size_average=None,
        reduce=None,
        reduction="mean",
        pos_weight=None,
    ):
        super(MixJacLoss, self).__init__(size_average, reduce, reduction)
        self.lambda_u = lambda_u
        self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha)
        self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()

    def forward(
        self, input, target, mask, unlabeled_input, unlabeled_target, ramp_up_factor
    ):
        """
        Parameters
        ----------

        input : :py:class:`torch.Tensor`
        target : :py:class:`torch.Tensor`
        mask : :py:class:`torch.Tensor`
        unlabeled_input : :py:class:`torch.Tensor`
        unlabeled_target : :py:class:`torch.Tensor`
        ramp_up_factor : float

        Returns
        -------

        tuple
            the combined loss, the labeled loss and the unlabeled loss

        """
        # the labeled loss requires the RoI mask (see SoftJaccardBCELogitsLoss)
        ll = self.labeled_loss(input, target, mask)
        ul = self.unlabeled_loss(unlabeled_input, unlabeled_target)

        loss = ll + self.lambda_u * ramp_up_factor * ul
        return loss, ll, ul
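

# A minimal semi-supervised sketch (not part of the original module): the
# ramp-up factor typically grows from 0 towards 1 during training, so the
# unlabeled term only contributes once predictions stabilize.  The tensors,
# the pseudo-targets and the ramp value below are illustrative assumptions.
def _example_mix_jac():
    criterion = MixJacLoss(lambda_u=100, jacalpha=0.7)
    labeled = torch.randn(2, 1, 16, 16)
    labels = (torch.rand(2, 1, 16, 16) > 0.5).float()
    mask = torch.ones(2, 1, 16, 16)
    unlabeled = torch.randn(2, 1, 16, 16)
    pseudo = torch.sigmoid(unlabeled).detach()  # stand-in pseudo-targets
    loss, ll, ul = criterion(labeled, labels, mask, unlabeled, pseudo, 0.5)
    return loss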