Coverage for /scratch/builds/bob/bob.ip.binseg/miniconda/conda-bld/bob.ip.binseg_1673966692152/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_p/lib/python3.10/site-packages/bob/ip/common/utils/plot.py: 89%

89 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 15:03 +0000

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4import contextlib 

5import logging 

6 

7from itertools import cycle 

8 

9import matplotlib 

10import matplotlib.pyplot as plt 

11import numpy 

12 

13matplotlib.use("agg") 

14logger = logging.getLogger(__name__) 

15 

16 

@contextlib.contextmanager
def _precision_recall_canvas(title=None, limits=(0.0, 1.0, 0.0, 1.0)):
    """Generates a canvas to draw precision-recall curves

    Works like a context manager, yielding a figure and an axes set in which
    the precision-recall curves should be added to.  The figure already
    contains F1-ISO lines and is preset to a 0-1 square region.  Once the
    context is finished (without errors), ``fig.tight_layout()`` is called.


    Parameters
    ----------

    title : :py:class:`str`, Optional
        Optional title to add to this plot

    limits : :py:class:`tuple`, Optional
        If set, a 4-tuple containing the bounds of the plot for the x and y
        axis respectively (format: ``[x_low, x_high, y_low, y_high]``).  If
        not set (``None`` or an empty sequence), use normal bounds
        (``[0, 1, 0, 1]``).


    Yields
    ------

    figure : matplotlib.figure.Figure
        The figure that should be finally returned to the user

    axes : matplotlib.axes.Axes
        An axis set where to precision-recall plots should be added to

    """

    # fall back to the unit square when the caller did not set any bounds
    if limits is None or len(limits) == 0:
        limits = (0.0, 1.0, 0.0, 1.0)
    x_low, x_high, y_low, y_high = limits

    fig, axes1 = plt.subplots(1)

    # Names and bounds
    axes1.set_xlabel("Recall")
    axes1.set_ylabel("Precision")
    axes1.set_xlim(x_low, x_high)
    axes1.set_ylim(y_low, y_high)

    if title is not None:
        axes1.set_title(title)

    axes1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
    axes2 = axes1.twinx()

    # Annotates plot with F1-score iso-lines.  These are drawn explicitly on
    # the twin axes (the axes pyplot considers "current" right after
    # ``twinx()``), which always spans [0, 1] on y so that the tick
    # locations computed below match the end of each iso-line.
    recall = numpy.linspace(0.01, 1)  # same grid for every iso-line
    tick_locs = []
    tick_labels = []
    for f_score in numpy.linspace(0.1, 0.9, num=9):
        precision = f_score * recall / (2 * recall - f_score)
        valid = precision >= 0  # drops the branch below the asymptote
        axes2.plot(recall[valid], precision[valid], color="green", alpha=0.1)
        tick_locs.append(precision[-1])  # iso-line value at recall == 1
        tick_labels.append("%.1f" % f_score)
    axes2.tick_params(axis="y", which="both", pad=0, right=False, left=False)
    axes2.set_ylabel("iso-F", color="green", alpha=0.3)
    axes2.set_ylim([0.0, 1.0])
    axes2.yaxis.set_label_coords(1.015, 0.97)
    axes2.set_yticks(tick_locs)  # notice these are invisible
    for tick_label in axes2.set_yticklabels(tick_labels):
        tick_label.set_color("green")
        tick_label.set_alpha(0.3)
        tick_label.set_size(8)

    # we should see some of axes 1 axes; the visible spines are pushed
    # outwards by 1.5% of the plotted range
    axes1.spines["right"].set_visible(False)
    axes1.spines["top"].set_visible(False)
    axes1.spines["left"].set_position(
        ("data", x_low - (0.015 * (x_high - x_low)))
    )
    axes1.spines["bottom"].set_position(
        ("data", y_low - (0.015 * (y_high - y_low)))
    )

    # we shouldn't see any of axes 2 axes
    for side in ("right", "top", "left", "bottom"):
        axes2.spines[side].set_visible(False)

    # yield execution, lets user draw precision-recall plots, and the legend
    # before tightening the layout
    yield fig, axes1

    # tighten *this* figure, not whichever figure pyplot considers current
    fig.tight_layout()

105 

106 

def precision_recall_f1iso(data, limits):
    """Creates a precision-recall plot

    This function creates and returns a Matplotlib figure with a
    precision-recall plot.  The plot will be annotated with F1-score iso-lines
    (in which the F1-score maintains the same value).

    This function specially supports "second-annotator" entries by plotting a
    line showing the comparison between the default annotator being analyzed
    and a second "opinion".  Second annotator dataframes contain a single entry
    (threshold=0.5), given the nature of the binary map comparisons.


    Parameters
    ----------

    data : dict
        A dictionary in which keys are strings defining plot labels and values
        are dictionaries with two entries:

        * ``df``: :py:class:`pandas.DataFrame`

          A dataframe that is produced by our evaluator engine, indexed by
          integer "thresholds", containing the following columns:
          ``threshold``, ``tp``, ``fp``, ``tn``, ``fn``, ``mean_precision``,
          ``mode_precision``, ``lower_precision``, ``upper_precision``,
          ``mean_recall``, ``mode_recall``, ``lower_recall``, ``upper_recall``,
          ``mean_specificity``, ``mode_specificity``, ``lower_specificity``,
          ``upper_specificity``, ``mean_accuracy``, ``mode_accuracy``,
          ``lower_accuracy``, ``upper_accuracy``, ``mean_jaccard``,
          ``mode_jaccard``, ``lower_jaccard``, ``upper_jaccard``,
          ``mean_f1_score``, ``mode_f1_score``, ``lower_f1_score``,
          ``upper_f1_score``, ``frequentist_precision``,
          ``frequentist_recall``, ``frequentist_specificity``,
          ``frequentist_accuracy``, ``frequentist_jaccard``,
          ``frequentist_f1_score``.

          Only ``mean_precision``, ``mean_recall`` and ``mean_f1_score`` are
          read by this function.

        * ``threshold``: :py:class:`float`

          A threshold to graph with a dot for each set.  It is multiplied by
          the number of rows in ``df`` and rounded to select the row that is
          highlighted.  Specific threshold values do not affect
          "second-annotator" dataframes.

    limits : tuple
        A 4-tuple containing the bounds of the plot for the x and y axis
        respectively (format: ``[x_low, x_high, y_low, y_high]``).  If not set
        (``None`` or an empty sequence), use normal bounds (``[0, 1, 0, 1]``).


    Returns
    -------

    figure : matplotlib.figure.Figure
        A matplotlib figure you can save or display (uses an ``agg`` backend)

    """

    # the canvas subscripts ``limits``, so normalise "not set" here
    if limits is None or len(limits) == 0:
        limits = (0.0, 1.0, 0.0, 1.0)

    linecycler = cycle(["-", "--", "-.", ":"])
    colorcycler = cycle(
        [
            "#1f77b4",
            "#ff7f0e",
            "#2ca02c",
            "#d62728",
            "#9467bd",
            "#8c564b",
            "#e377c2",
            "#7f7f7f",
            "#bcbd22",
            "#17becf",
        ]
    )

    with _precision_recall_canvas(title=None, limits=limits) as (fig, axes):

        legend = []

        for name, value in data.items():

            df = value["df"]
            threshold = value["threshold"]

            # plots only from the point where recall reaches its maximum,
            # otherwise, we don't see a curve...  ``idxmax()`` returns an
            # index *label*, hence the label-based ``.loc`` slicing.
            max_recall = df.mean_recall.idxmax()
            pi = df.mean_precision.loc[max_recall:]
            ri = df.mean_recall.loc[max_recall:]

            # optimal point along the curve: a row *position*, clamped to
            # avoid out of range indexing, hence ``.iloc`` below
            index = min(int(round(len(df) * threshold)), len(df) - 1)
            recall_at = df.mean_recall.iloc[index]
            precision_at = df.mean_precision.iloc[index]

            # plots Recall/Precision as threshold changes
            label = f"{name} (F1={df.mean_f1_score.iloc[index]:.4f})"
            color = next(colorcycler)

            if len(df) == 1:
                # single-entry (e.g. second annotator): star marker only
                (marker,) = axes.plot(
                    recall_at,
                    precision_at,
                    marker="*",
                    markersize=6,
                    color=color,
                    alpha=0.8,
                    linestyle="None",
                )
                # invisible companion artist, so the legend entry has the
                # same (marker, line) shape as the multi-row case
                (line,) = axes.plot(
                    recall_at,
                    precision_at,
                    linestyle="None",
                    color=color,
                    alpha=0.2,
                )
            else:
                # line first, so marker gets on top
                style = next(linecycler)
                positive = pi > 0
                (line,) = axes.plot(
                    ri[positive], pi[positive], color=color, linestyle=style
                )
                (marker,) = axes.plot(
                    recall_at,
                    precision_at,
                    marker="o",
                    linestyle=style,
                    markersize=4,
                    color=color,
                    alpha=0.8,
                )

            legend.append(((marker, line), label))

        # re-assert the requested bounds after all curves were added
        axes.set_xlim(limits[:2])
        axes.set_ylim(limits[2:])

        # a legend is only drawn when there is more than one entry
        if len(legend) > 1:
            axes.legend(
                [handles for handles, _ in legend],
                [text for _, text in legend],
                loc="lower left",
                fancybox=True,
                framealpha=0.7,
            )

    return fig

254 

255 

def loss_curve(df):
    """Creates a loss curve in a Matplotlib figure.

    The median loss is drawn against the left y-axis and the learning rate
    against a secondary (right) y-axis.  Both curves use the ``epoch`` column
    as x coordinate, so they stay aligned even when the dataframe index does
    not coincide with the epoch numbers.

    Parameters
    ----------

    df : :py:class:`pandas.DataFrame`
        A dataframe containing, at least, "epoch", "median-loss" and
        "learning-rate" columns, that will be plotted.

    Returns
    -------

    figure : matplotlib.figure.Figure
        A figure, that may be saved or displayed

    """

    ax1 = df.plot(x="epoch", y="median-loss", grid=True)
    ax1.set_ylabel("Median Loss")
    ax1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)

    # draw on the axes created above explicitly, instead of relying on
    # pyplot's notion of "current axes"
    ax2 = df.plot(
        x="epoch",
        y="learning-rate",
        ax=ax1,
        secondary_y=True,
        legend=True,
        grid=True,
    )
    ax2.set_ylabel("Learning Rate")

    # set last, so it replaces the column name pandas uses as x label
    ax1.set_xlabel("Epoch")

    # tighten the figure that owns these axes, not pyplot's current figure
    fig = ax1.get_figure()
    fig.tight_layout()
    return fig

286 return fig