Coverage for src/deepdraw/utils/plot.py: 89%

89 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> 

2# 

3# SPDX-License-Identifier: GPL-3.0-or-later 

4 

5import contextlib 

6import logging 

7 

8from itertools import cycle 

9 

10import matplotlib 

11import matplotlib.pyplot as plt 

12import numpy 

13 

# Select matplotlib's "agg" backend when this module is imported, so the
# figures built by the functions below can be created and saved without an
# interactive display.
matplotlib.use("agg")

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

16 

17 

@contextlib.contextmanager
def _precision_recall_canvas(title=None, limits=(0.0, 1.0, 0.0, 1.0)):
    """Generates a canvas to draw precision-recall curves.

    Works like a context manager, yielding a figure and an axes set to which
    the precision-recall curves should be added.  The figure already
    contains F1 iso-lines and is preset to the region given by ``limits``
    (a 0-1 square by default).  Once the context exits normally,
    ``fig.tight_layout()`` is called.  If the body of the ``with`` statement
    raises, the figure is closed before the exception propagates, so it is
    not left behind in pyplot's list of open figures.


    Parameters
    ----------

    title : :py:class:`str`, Optional
        Optional title to add to this plot

    limits : :py:class:`tuple`, Optional
        A 4-sequence containing the bounds of the plot for the x and y
        axis respectively (format: ``[x_low, x_high, y_low, y_high]``).  If
        not set (or ``None``), use normal bounds (``[0, 1, 0, 1]``).


    Yields
    ------

    figure : matplotlib.figure.Figure
        The figure that should be finally returned to the user

    axes : matplotlib.axes.Axes
        The axes set to which precision-recall plots should be added
    """
    if limits is None:
        limits = (0.0, 1.0, 0.0, 1.0)

    fig, axes1 = plt.subplots(1)
    try:
        _setup_precision_recall_axes(axes1, title, limits)
        _add_f1_iso_lines(axes1, limits)
        # yield execution: lets the caller draw precision-recall plots and
        # the legend before the layout is tightened
        yield fig, axes1
    except BaseException:
        # the figure never reaches the caller: release it from pyplot
        # instead of leaking it, then let the error propagate
        plt.close(fig)
        raise

    fig.tight_layout()


def _setup_precision_recall_axes(axes, title, limits):
    """Labels, bounds, titles and styles the main precision-recall axes."""
    axes.set_xlabel("Recall")
    axes.set_ylabel("Precision")
    axes.set_xlim(limits[:2])
    axes.set_ylim(limits[2:])

    if title is not None:
        axes.set_title(title)

    axes.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)

    # only the left and bottom spines stay visible; they are pushed slightly
    # outwards (1.5% of the plotted range) from the data region
    axes.spines["right"].set_visible(False)
    axes.spines["top"].set_visible(False)
    axes.spines["left"].set_position(
        ("data", limits[0] - (0.015 * (limits[1] - limits[0])))
    )
    axes.spines["bottom"].set_position(
        ("data", limits[2] - (0.015 * (limits[3] - limits[2])))
    )


def _add_f1_iso_lines(axes, limits):
    """Overlays labelled F1-score iso-lines on a twin (right) y-axis of ``axes``."""
    twin = axes.twinx()

    # For a fixed F1 score f, precision p and recall r satisfy
    # f = 2pr / (p + r), i.e. p = f * r / (2r - f).  Only the p >= 0 branch
    # is drawn.
    recall = numpy.linspace(0.01, 1)
    tick_locs = []
    tick_labels = []
    for f_score in numpy.linspace(0.1, 0.9, num=9):
        precision = f_score * recall / (2 * recall - f_score)
        keep = precision >= 0
        twin.plot(recall[keep], precision[keep], color="green", alpha=0.1)
        # each iso-line is labelled where it reaches recall = 1
        tick_locs.append(precision[-1])
        tick_labels.append(f"{f_score:.1f}")

    twin.tick_params(axis="y", which="both", pad=0, right=False, left=False)
    twin.set_ylabel("iso-F", color="green", alpha=0.3)
    twin.yaxis.set_label_coords(1.015, 0.97)
    twin.set_yticks(tick_locs)  # tick marks are hidden by tick_params above
    for text in twin.set_yticklabels(tick_labels):
        text.set_color("green")
        text.set_alpha(0.3)
        text.set_size(8)

    # The twin must span the same y-range as ``axes``, or the iso-lines would
    # not line up with the precision values drawn there.  This is set *after*
    # ``set_yticks``, which would otherwise widen the view to include every
    # tick.
    twin.set_ylim(limits[2:])

    # the twin only carries annotations: none of its spines should show
    for side in ("right", "top", "left", "bottom"):
        twin.spines[side].set_visible(False)

105 

106 

def precision_recall_f1iso(data, limits=None):
    """Creates a precision-recall plot.

    This function creates and returns a Matplotlib figure with a
    precision-recall plot. The plot will be annotated with F1-score iso-lines
    (in which the F1-score maintains the same value).

    This function specially supports "second-annotator" entries by plotting a
    line showing the comparison between the default annotator being analyzed
    and a second "opinion". Second annotator dataframes contain a single entry
    (threshold=0.5), given the nature of the binary map comparisons.


    Parameters
    ----------

    data : dict
        A dictionary in which keys are strings defining plot labels and values
        are dictionaries with two entries:

        * ``df``: :py:class:`pandas.DataFrame`

          A dataframe that is produced by our evaluator engine, with one row
          per threshold (rows are accessed by position), containing the
          following columns:
          ``threshold``, ``tp``, ``fp``, ``tn``, ``fn``, ``mean_precision``,
          ``mode_precision``, ``lower_precision``, ``upper_precision``,
          ``mean_recall``, ``mode_recall``, ``lower_recall``, ``upper_recall``,
          ``mean_specificity``, ``mode_specificity``, ``lower_specificity``,
          ``upper_specificity``, ``mean_accuracy``, ``mode_accuracy``,
          ``lower_accuracy``, ``upper_accuracy``, ``mean_jaccard``,
          ``mode_jaccard``, ``lower_jaccard``, ``upper_jaccard``,
          ``mean_f1_score``, ``mode_f1_score``, ``lower_f1_score``,
          ``upper_f1_score``, ``frequentist_precision``,
          ``frequentist_recall``, ``frequentist_specificity``,
          ``frequentist_accuracy``, ``frequentist_jaccard``,
          ``frequentist_f1_score``.  Only ``mean_precision``,
          ``mean_recall`` and ``mean_f1_score`` are read here.  Entries with
          an empty dataframe are skipped with a warning.

        * ``threshold``: :py:class:`float`

          The threshold to mark with a dot on each curve.  The marked row is
          ``round(len(df) * threshold)``, clamped to the valid row range.
          Specific threshold values do not affect "second-annotator"
          (single-row) dataframes.

    limits : tuple, Optional
        A 4-tuple containing the bounds of the plot for the x and y axis
        respectively (format: ``[x_low, x_high, y_low, y_high]``). If not set
        (``None``), use normal bounds (``[0, 1, 0, 1]``).


    Returns
    -------

    figure : matplotlib.figure.Figure
        A matplotlib figure you can save or display (uses an ``agg`` backend)
    """

    if limits is None:
        limits = (0.0, 1.0, 0.0, 1.0)

    lines = ["-", "--", "-.", ":"]
    colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    colorcycler = cycle(colors)
    linecycler = cycle(lines)

    with _precision_recall_canvas(title=None, limits=limits) as (fig, axes):
        legend = []

        for name, value in data.items():
            df = value["df"]
            threshold = value["threshold"]

            if df.empty:
                logger.warning("Skipping %r: its dataframe is empty", name)
                continue

            # plots only from the row where recall reaches its maximum,
            # otherwise, we don't see a curve...
            start = df.mean_recall.argmax()
            pi = df.mean_precision.iloc[start:]
            ri = df.mean_recall.iloc[start:]

            # row marking the requested threshold, clamped so that any
            # threshold value yields a valid position
            index = int(round(len(df) * threshold))
            index = min(max(index, 0), len(df) - 1)
            recall_at = df.mean_recall.iloc[index]
            precision_at = df.mean_precision.iloc[index]

            label = f"{name} (F1={df.mean_f1_score.iloc[index]:.4f})"
            color = next(colorcycler)

            if len(df) == 1:
                # second-annotator entry: a single star at its operating point
                (marker,) = axes.plot(
                    recall_at,
                    precision_at,
                    marker="*",
                    markersize=6,
                    color=color,
                    alpha=0.8,
                    linestyle="None",
                )
                # drawn with no line style and no explicit marker: it only
                # serves as the legend's companion handle
                (line,) = axes.plot(
                    recall_at,
                    precision_at,
                    linestyle="None",
                    color=color,
                    alpha=0.2,
                )
            else:
                # line first, so marker gets on top
                style = next(linecycler)
                positive = pi > 0
                (line,) = axes.plot(
                    ri[positive], pi[positive], color=color, linestyle=style
                )
                (marker,) = axes.plot(
                    recall_at,
                    precision_at,
                    marker="o",
                    linestyle=style,
                    markersize=4,
                    color=color,
                    alpha=0.8,
                )
            legend.append(([marker, line], label))

        # re-apply the requested bounds after all curves have been drawn
        axes.set_xlim(limits[:2])
        axes.set_ylim(limits[2:])

        if len(legend) > 1:
            axes.legend(
                [tuple(handles) for handles, _ in legend],
                [text for _, text in legend],
                loc="lower left",
                fancybox=True,
                framealpha=0.7,
            )

    return fig

251 

252 

def loss_curve(df):
    """Creates a loss curve in a Matplotlib figure.

    The median loss is drawn against the ``epoch`` column on the left y-axis.
    The learning rate is drawn against the same ``epoch`` column on a
    secondary (right-hand) y-axis, so both curves share the same x values.

    Parameters
    ----------

    df : :py:class:`pandas.DataFrame`
        A dataframe containing, at least, "epoch", "median-loss" and
        "learning-rate" columns, that will be plotted.

    Returns
    -------

    figure : matplotlib.figure.Figure
        A figure, that may be saved or displayed
    """

    ax1 = df.plot(x="epoch", y="median-loss", grid=True)
    ax1.set_ylabel("Median Loss")
    ax1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)

    # Plot on ``ax1`` with ``secondary_y`` so pandas creates (and returns) a
    # right-hand twin axis.  Using ``x="epoch"`` keeps the learning rate
    # aligned with the loss curve instead of plotting it against the
    # dataframe index.
    ax2 = df.plot(
        x="epoch",
        y="learning-rate",
        ax=ax1,
        secondary_y=True,
        legend=True,
        grid=True,
    )
    ax2.set_ylabel("Learning Rate")
    ax1.set_xlabel("Epoch")

    fig = ax1.get_figure()
    fig.tight_layout()
    return fig