Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4import contextlib 

5import logging 

6 

7from itertools import cycle 

8 

9import matplotlib 

10import matplotlib.pyplot as plt 

11import numpy 

12 

13matplotlib.use("agg") 

14logger = logging.getLogger(__name__) 

15 

16 

17def _concave_hull(x, y, lx, ux, ly, uy): 

18 """Calculates a approximate (concave) hull from arc centers and sizes 

19 

20 Each ellipse is approximated as a number of discrete points distributed 

21 over the ellipse border following an homogeneous angle distribution. 

22 

23 

24 Parameters 

25 ---------- 

26 

27 x : numpy.ndarray 

28 1D array with x coordinates of ellipse centers 

29 

30 y : numpy.ndarray 

31 1D array with y coordinates of ellipse centers 

32 

33 lx, ux, ly, uy : numpy.ndarray 

34 1D array(s) with upper and lower widths and heights for your deformed 

35 ellipse 

36 

37 

38 Returns 

39 ------- 

40 

41 points : numpy.ndarray 

42 2D array containing the ``(x, y)`` coordinates of the concave hull 

43 encompassing all defined arcs. 

44 

45 """ 

46 

47 def _irregular_ellipse_points(_x, _y, _lx, _ux, _ly, _uy, steps=100): 

48 """Generates border points for an irregular ellipse 

49 

50 This functions distributes points according to a rotation angle rather 

51 than uniformily with respect to a particular axis. The result is a 

52 more homogeneous border representation for the ellipse. 

53 """ 

54 up = _uy - _y 

55 down = _y - _ly 

56 left = _x - _lx 

57 right = _ux - _x 

58 

59 angles = numpy.arange(0, numpy.pi / 2, step=2 * numpy.pi / steps) 

60 points = numpy.ndarray((0, 2)) 

61 

62 # upper left part (90 -> 180 degrees) 

63 px = 2 * left * numpy.cos(angles) 

64 py = (up / left) * numpy.sqrt(numpy.square(2 * left) - numpy.square(px)) 

65 # order: x and y increase 

66 points = numpy.vstack((points, numpy.array([_x - px, _y + py]).T)) 

67 

68 # upper right part (0 -> 90 degrees) 

69 px = 2 * right * numpy.cos(angles) 

70 py = (up / right) * numpy.sqrt( 

71 numpy.square(2 * right) - numpy.square(px) 

72 ) 

73 # order: x increases and y decreases 

74 points = numpy.vstack( 

75 (points, numpy.flipud(numpy.array([_x + px, _y + py]).T)) 

76 ) 

77 

78 # lower right part (180 -> 270 degrees) 

79 px = 2 * right * numpy.cos(angles) 

80 py = (down / right) * numpy.sqrt( 

81 numpy.square(2 * right) - numpy.square(px) 

82 ) 

83 # order: x increases and y decreases 

84 points = numpy.vstack((points, numpy.array([_x + px, _y - py]).T)) 

85 

86 # lower left part (180 -> 270 degrees) 

87 px = 2 * left * numpy.cos(angles) 

88 py = (down / left) * numpy.sqrt( 

89 numpy.square(2 * left) - numpy.square(px) 

90 ) 

91 # order: x decreases and y increases 

92 points = numpy.vstack( 

93 (points, numpy.flipud(numpy.array([_x - px, _y - py]).T)) 

94 ) 

95 

96 return points 

97 

98 retval = numpy.ndarray((0, 2)) 

99 for (k, l, m, n, o, p) in zip(x, y, lx, ux, ly, uy): 

100 retval = numpy.vstack( 

101 ( 

102 retval, 

103 [numpy.nan, numpy.nan], 

104 _irregular_ellipse_points(k, l, m, n, o, p), 

105 ) 

106 ) 

107 return retval 

108 

109 

@contextlib.contextmanager
def _precision_recall_canvas(title=None):
    """Generates a canvas to draw precision-recall curves

    Works like a context manager, yielding a figure and an axes set in which
    the precision-recall curves should be added to.  The figure already
    contains F1-ISO lines and is preset to a 0-1 square region.  Once the
    context is finished, ``fig.tight_layout()`` is called.  If the body of the
    ``with`` statement raises, the layout is not tightened and the exception
    propagates.


    Parameters
    ----------

    title : :py:class:`str`, Optional
        Optional title to add to this plot


    Yields
    ------

    figure : matplotlib.figure.Figure
        The figure that should be finally returned to the user

    axes : matplotlib.axes.Axes
        An axis set where to precision-recall plots should be added to

    """

    fig, axes1 = plt.subplots(1)

    # Names and bounds
    axes1.set_xlabel("Recall")
    axes1.set_ylabel("Precision")
    axes1.set_xlim([0.0, 1.0])
    axes1.set_ylim([0.0, 1.0])

    if title is not None:
        axes1.set_title(title)

    axes1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)

    # twin axes, sharing the x axis: holds iso-lines and their tick labels
    axes2 = axes1.twinx()

    # Annotates plot with F1-score iso-lines.  From F1 = 2PR / (P + R), the
    # precision along an iso-line is P = F1 * R / (2R - F1).
    recall = numpy.linspace(0.01, 1)  # same recall grid for all iso-lines
    tick_locs = []
    tick_labels = []
    for f_score in numpy.linspace(0.1, 0.9, num=9):
        precision = f_score * recall / (2 * recall - f_score)
        valid = precision >= 0  # negative precision is meaningless
        # draws explicitly on the twin axes, instead of relying on pyplot's
        # notion of "current" axes
        axes2.plot(recall[valid], precision[valid], color="green", alpha=0.1)
        # labels each iso-line where it reaches recall == 1
        tick_locs.append(precision[-1])
        tick_labels.append("%.1f" % f_score)
    axes2.tick_params(axis="y", which="both", pad=0, right=False, left=False)
    axes2.set_ylabel("iso-F", color="green", alpha=0.3)
    axes2.set_ylim([0.0, 1.0])
    axes2.yaxis.set_label_coords(1.015, 0.97)
    axes2.set_yticks(tick_locs)  # notice these are invisible
    for k in axes2.set_yticklabels(tick_labels):
        k.set_color("green")
        k.set_alpha(0.3)
        k.set_size(8)

    # we should see some of axes 1 axes
    axes1.spines["right"].set_visible(False)
    axes1.spines["top"].set_visible(False)
    axes1.spines["left"].set_position(("data", -0.015))
    axes1.spines["bottom"].set_position(("data", -0.015))

    # we shouldn't see any of axes 2 axes
    axes2.spines["right"].set_visible(False)
    axes2.spines["top"].set_visible(False)
    axes2.spines["left"].set_visible(False)
    axes2.spines["bottom"].set_visible(False)

    # yield execution, lets user draw precision-recall plots, and the legend
    # before tightening the layout
    yield fig, axes1

    # tightens the layout of *this* figure, even if the caller created or
    # activated another pyplot figure in the meantime
    fig.tight_layout()

189 

190 

def precision_recall_f1iso(data, credible=True):
    """Creates a precision-recall plot with credible intervals

    This function creates and returns a Matplotlib figure with a
    precision-recall plot containing shaded credible intervals (on the
    precision-recall measurements).  The plot will be annotated with F1-score
    iso-lines (in which the F1-score maintains the same value).

    This function specially supports "second-annotator" entries by plotting a
    line showing the comparison between the default annotator being analyzed
    and a second "opinion".  Second annotator dataframes contain a single entry
    (threshold=0.5), given the nature of the binary map comparisons.


    Parameters
    ----------

    data : dict
        A dictionary in which keys are strings defining plot labels and values
        are dictionaries with two entries:

        * ``df``: :py:class:`pandas.DataFrame`

          A dataframe that is produced by our evaluator engine, indexed by
          integer "thresholds", containing the following columns:
          ``threshold``, ``tp``, ``fp``, ``tn``, ``fn``, ``mean_precision``,
          ``mode_precision``, ``lower_precision``, ``upper_precision``,
          ``mean_recall``, ``mode_recall``, ``lower_recall``, ``upper_recall``,
          ``mean_specificity``, ``mode_specificity``, ``lower_specificity``,
          ``upper_specificity``, ``mean_accuracy``, ``mode_accuracy``,
          ``lower_accuracy``, ``upper_accuracy``, ``mean_jaccard``,
          ``mode_jaccard``, ``lower_jaccard``, ``upper_jaccard``,
          ``mean_f1_score``, ``mode_f1_score``, ``lower_f1_score``,
          ``upper_f1_score``, ``frequentist_precision``,
          ``frequentist_recall``, ``frequentist_specificity``,
          ``frequentist_accuracy``, ``frequentist_jaccard``,
          ``frequentist_f1_score``.

        * ``threshold``: :py:class:`float`

          A threshold to graph with a dot for each set.  It is multiplied by
          the number of rows in ``df`` (and clipped to the last row) to find
          the entry to highlight.  Specific threshold values do not affect
          "second-annotator" dataframes.

    credible : :py:class:`bool`, Optional
        If set, draw credible intervals for each line, using ``upper_*`` and
        ``lower_*`` entries.


    Returns
    -------

    figure : matplotlib.figure.Figure
        A matplotlib figure you can save or display (uses an ``agg`` backend)

    """

    lines = ["-", "--", "-.", ":"]
    colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    colorcycler = cycle(colors)
    linecycler = cycle(lines)

    with _precision_recall_canvas(title=None) as (fig, axes):

        # each entry is a pair ([artists...], label)
        legend = []

        for name, value in data.items():

            df = value["df"]
            threshold = value["threshold"]

            # plots only from the point where recall reaches its maximum,
            # otherwise, we don't see a curve...
            # NOTE(review): ``idxmax()`` returns an index *label*, which is
            # then used to slice the series - presumably fine as the index
            # is documented to be integer thresholds; confirm.
            max_recall = df.mean_recall.idxmax()
            pi = df.mean_precision[max_recall:]
            ri = df.mean_recall[max_recall:]

            # optimal point along the curve
            bins = len(df)
            index = int(round(bins * threshold))
            index = min(index, bins - 1)  # avoids out of range indexing

            # coordinates of the point highlighted for the chosen threshold
            recall_at_index = df.mean_recall[index]
            precision_at_index = df.mean_precision[index]

            # plots Recall/Precision as threshold changes
            label = f"{name} (F1={df.mean_f1_score[index]:.4f})"
            color = next(colorcycler)

            if bins == 1:
                # single entry (e.g. second annotator): plots a star for the
                # F1-score at the selected threshold
                (marker,) = axes.plot(
                    recall_at_index,
                    precision_at_index,
                    marker="*",
                    markersize=6,
                    color=color,
                    alpha=0.8,
                    linestyle="None",
                )
                # marker-less and line-less artist, only composes the legend
                (line,) = axes.plot(
                    recall_at_index,
                    precision_at_index,
                    linestyle="None",
                    color=color,
                    alpha=0.2,
                )
                legend.append(([marker, line], label))
            else:
                # line first, so marker gets on top
                style = next(linecycler)
                (line,) = axes.plot(
                    ri[pi > 0], pi[pi > 0], color=color, linestyle=style
                )
                (marker,) = axes.plot(
                    recall_at_index,
                    precision_at_index,
                    marker="o",
                    linestyle=style,
                    markersize=4,
                    color=color,
                    alpha=0.8,
                )
                legend.append(([marker, line], label))

            if credible:

                hull = _concave_hull(
                    df.mean_recall,
                    df.mean_precision,
                    df.lower_recall,
                    df.upper_recall,
                    df.lower_precision,
                    df.upper_precision,
                )
                p = plt.Polygon(
                    hull,
                    facecolor=color,
                    alpha=0.2,
                    edgecolor="none",
                    lw=0.2,
                    closed=True,
                )
                axes.add_patch(p)
                legend[-1][0].append(p)

        # only draws a legend if something was plotted.  (Testing the last
        # ``label`` string, as done previously, was always true for
        # non-empty inputs and raised NameError for an empty ``data``.)
        if legend:
            axes.legend(
                [tuple(k[0]) for k in legend],
                [k[1] for k in legend],
                loc="lower left",
                fancybox=True,
                framealpha=0.7,
            )

    return fig

355 

356 

def loss_curve(df):
    """Creates a loss curve in a Matplotlib figure.

    The median loss is plotted against the "epoch" column on the primary
    (left) y-axis, and the learning rate on a secondary (right) y-axis of the
    same plot.

    Parameters
    ----------

    df : :py:class:`pandas.DataFrame`
        A dataframe containing, at least, "epoch", "median-loss" and
        "learning-rate" columns, that will be plotted.

    Returns
    -------

    figure : matplotlib.figure.Figure
        A figure, that may be saved or displayed

    """

    ax1 = df.plot(x="epoch", y="median-loss", grid=True)
    ax1.set_ylabel("Median Loss")
    ax1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)

    # Explicitly targets the axes created above, instead of relying on
    # whatever pyplot considers to be the "current" axes.
    # NOTE(review): the learning rate is plotted against the dataframe index,
    # not against the "epoch" column - presumably both coincide; confirm.
    ax2 = df["learning-rate"].plot(
        ax=ax1,
        secondary_y=True,
        legend=True,
        grid=True,
    )
    ax2.set_ylabel("Learning Rate")
    ax1.set_xlabel("Epoch")

    # tightens the layout of the figure owning the axes, not of pyplot's
    # "current" figure
    fig = ax1.get_figure()
    fig.tight_layout()
    return fig