Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1637571489937/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.8/site-packages/bob/med/tb/utils/plot.py: 98%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

121 statements  

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4import contextlib 

5from itertools import cycle 

6 

7import numpy 

8import pandas 

9from sklearn.metrics import auc, precision_recall_curve as pr_curve, roc_curve as r_curve 

10 

11import matplotlib 

12matplotlib.use("agg") 

13import matplotlib.pyplot as plt 

14 

15import logging 

16 

17logger = logging.getLogger(__name__) 

18 

@contextlib.contextmanager
def _precision_recall_canvas(title=None):
    """Generates a canvas to draw precision-recall curves

    Works like a context manager, yielding a figure and an axes set in which
    the precision-recall curves should be added.  The figure already
    contains F1-ISO lines and is preset to a 0-1 square region.  Once the
    context is finished, ``fig.tight_layout()`` is called on the yielded
    figure (not on whatever pyplot figure happens to be current).


    Parameters
    ----------

    title : :py:class:`str`, Optional
        Optional title to add to this plot


    Yields
    ------

    figure : matplotlib.figure.Figure
        The figure that should be finally returned to the user

    axes : matplotlib.axes.Axes
        An axes set to which precision-recall plots should be added

    """

    fig, axes1 = plt.subplots(1)

    # Names and bounds
    axes1.set_xlabel("Recall")
    axes1.set_ylabel("Precision")
    axes1.set_xlim([0.0, 1.0])
    axes1.set_ylim([0.0, 1.0])

    if title is not None:
        axes1.set_title(title)

    axes1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)

    # twin axes (sharing the x-axis with axes1) that carries the iso-F
    # curves and their labels on the right-hand side
    axes2 = axes1.twinx()

    # Annotates plot with F1-score iso-lines.  Solving
    # F = 2PR / (P + R) for P gives P = F * R / (2R - F).
    f_scores = numpy.linspace(0.1, 0.9, num=9)
    tick_locs = []
    tick_labels = []
    for f_score in f_scores:
        x = numpy.linspace(0.01, 1)
        y = f_score * x / (2 * x - f_score)
        # negative values lie left of the asymptote R = F/2 and are not part
        # of the iso-curve
        keep = y >= 0
        axes2.plot(x[keep], y[keep], color="green", alpha=0.1)
        # each curve is labelled where it meets the right border (recall = 1)
        tick_locs.append(y[-1])
        tick_labels.append("%.1f" % f_score)
    axes2.tick_params(axis="y", which="both", pad=0, right=False, left=False)
    axes2.set_ylabel("iso-F", color="green", alpha=0.3)
    axes2.set_ylim([0.0, 1.0])
    axes2.yaxis.set_label_coords(1.015, 0.97)
    axes2.set_yticks(tick_locs)  # tick marks are hidden above; labels stay
    for k in axes2.set_yticklabels(tick_labels):
        k.set_color("green")
        k.set_alpha(0.3)
        k.set_size(8)

    # we should see some of axes 1 axes
    axes1.spines["right"].set_visible(False)
    axes1.spines["top"].set_visible(False)
    axes1.spines["left"].set_position(("data", -0.015))
    axes1.spines["bottom"].set_position(("data", -0.015))

    # we shouldn't see any of axes 2 axes
    axes2.spines["right"].set_visible(False)
    axes2.spines["top"].set_visible(False)
    axes2.spines["left"].set_visible(False)
    axes2.spines["bottom"].set_visible(False)

    # yield execution, lets user draw precision-recall plots, and the legend
    # before tightening the layout
    yield fig, axes1

    # act on the figure we created, independently of pyplot's "current
    # figure", which the caller may have changed inside the context
    fig.tight_layout()

98 

99 

def precision_recall_f1iso(data):
    """Creates a precision-recall plot

    This function creates and returns a Matplotlib figure with a
    precision-recall plot.  The plot will be annotated with F1-score
    iso-lines (in which the F1-score maintains the same value).  Each curve
    is listed in the legend together with its area under the curve (AUC).


    Parameters
    ----------

    data : dict
        A dictionary in which keys are strings defining plot labels and values
        are dictionaries with two entries:

        * ``df``: :py:class:`pandas.DataFrame`

          A dataframe that is produced by our predictor engine containing
          the following columns: ``filename``, ``likelihood``,
          ``ground_truth``.

        * ``threshold``: :py:class:`list`

          A threshold for each set.  Not used here.

        An empty dictionary yields an empty (but annotated) canvas.


    Returns
    -------

    figure : matplotlib.figure.Figure
        A matplotlib figure you can save or display (uses an ``agg`` backend)

    """

    lines = ["-", "--", "-.", ":"]
    colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    colorcycler = cycle(colors)
    linecycler = cycle(lines)

    with _precision_recall_canvas(title=None) as (fig, axes):

        legend = []

        for name, value in data.items():

            df = value["df"]

            # plots Recall/Precision curve
            prec, recall, _ = pr_curve(df['ground_truth'], df['likelihood'])
            _auc = auc(recall, prec)
            label = f"{name} (AUC={_auc:.2f})"

            (line,) = axes.plot(
                recall,
                prec,
                color=next(colorcycler),
                linestyle=next(linecycler),
            )
            legend.append((line, label))

        # add a legend whenever at least one curve was drawn; checking the
        # list (not the last label string) also keeps empty ``data`` working
        if legend:
            axes.legend(
                [k[0] for k in legend],
                [k[1] for k in legend],
                loc="lower left",
                fancybox=True,
                framealpha=0.7,
            )

    return fig

183 

184 

def roc_curve(data, title=None):
    """Creates a ROC plot

    This function creates and returns a Matplotlib figure with a
    ROC plot.  Each curve is listed in the legend together with its area
    under the curve (AUC).


    Parameters
    ----------

    data : dict
        A dictionary in which keys are strings defining plot labels and values
        are dictionaries with two entries:

        * ``df``: :py:class:`pandas.DataFrame`

          A dataframe that is produced by our predictor engine containing
          the following columns: ``filename``, ``likelihood``,
          ``ground_truth``.

        * ``threshold``: :py:class:`list`

          A threshold for each set.  Not used here.

        An empty dictionary yields an empty (but labelled) canvas.

    title : :py:class:`str`, Optional
        Optional title to add to this plot


    Returns
    -------

    figure : matplotlib.figure.Figure
        A matplotlib figure you can save or display (uses an ``agg`` backend)

    """

    fig, axes = plt.subplots(1)

    # Names and bounds
    axes.set_xlabel("1 - specificity")
    axes.set_ylabel("Sensitivity")
    axes.set_xlim([0.0, 1.0])
    axes.set_ylim([0.0, 1.0])

    # we should see some of axes 1 axes
    axes.spines["right"].set_visible(False)
    axes.spines["top"].set_visible(False)
    axes.spines["left"].set_position(("data", -0.015))
    axes.spines["bottom"].set_position(("data", -0.015))

    if title is not None:
        axes.set_title(title)

    axes.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)

    lines = ["-", "--", "-.", ":"]
    colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    colorcycler = cycle(colors)
    linecycler = cycle(lines)

    legend = []

    for name, value in data.items():

        df = value["df"]

        # plots roc curve
        fpr, tpr, _ = r_curve(df['ground_truth'], df['likelihood'])
        _auc = auc(fpr, tpr)
        label = f"{name} (AUC={_auc:.2f})"

        (line,) = axes.plot(
            fpr,
            tpr,
            color=next(colorcycler),
            linestyle=next(linecycler),
        )
        legend.append((line, label))

    # add a legend whenever at least one curve was drawn; checking the list
    # (not the last label string) also keeps empty ``data`` working
    if legend:
        axes.legend(
            [k[0] for k in legend],
            [k[1] for k in legend],
            loc="lower right",
            fancybox=True,
            framealpha=0.7,
        )

    # lay out the finished figure we own, not pyplot's "current figure"
    fig.tight_layout()

    return fig

286 

287 

def relevance_analysis_plot(data, title=None):
    """Create an histogram plot to show the relative importance of features

    One bar is drawn per feature, in the fixed order of the 14 feature
    labels hard-coded below.  Bars are colour-coded in three groups labelled
    "likely", "could be" and "unlikely" (presumably relative to tuberculosis;
    NOTE(review): confirm the intended meaning of these groups).


    Parameters
    ----------

    data : :py:class:`list`
        The list of values (one for each feature), in the same order as the
        feature labels: Cardiomegaly, Emphysema, Pleural effusion, Hernia,
        Infiltration, Mass, Nodule, Atelectasis, Pneumothorax, Pleural
        thickening, Pneumonia, Fibrosis, Edema, Consolidation

    title : :py:class:`str`, Optional
        Optional title to add to this plot


    Returns
    -------

    figure : matplotlib.figure.Figure
        A matplotlib figure you can save or display (uses an ``agg`` backend)


    Raises
    ------

    ValueError
        If ``data`` does not contain exactly one value per feature

    """

    labels = ['Cardiomegaly', 'Emphysema', 'Pleural effusion',
              'Hernia', 'Infiltration', 'Mass', 'Nodule',
              'Atelectasis', 'Pneumothorax', 'Pleural thickening',
              'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation']

    # validate before creating a figure; otherwise ``bar`` would silently
    # broadcast a single value to all bars, or fail after the figure exists
    if len(data) != len(labels):
        raise ValueError(
            f"expected {len(labels)} values (one per feature), "
            f"got {len(data)}"
        )

    likely = '#818C2E'
    could_be = '#F2921D'
    unlikely = '#8C3503'  # default colour for every bar not listed below

    highlight = {
        'Pleural effusion': likely,
        'Infiltration': likely,
        'Pneumonia': likely,
        'Mass': could_be,
        'Nodule': could_be,
        'Atelectasis': could_be,
        'Fibrosis': could_be,
        'Consolidation': could_be,
    }

    fig, axes = plt.subplots(1, 1, figsize=(6, 6))

    # Names and bounds
    axes.set_xlabel("Features")
    axes.set_ylabel("Importance")

    # hide the top and right spines
    axes.spines["right"].set_visible(False)
    axes.spines["top"].set_visible(False)

    if title is not None:
        axes.set_title(title)

    bars = axes.bar(labels, data, color=unlikely)

    # recolour the highlighted features (``set_color`` sets both face and
    # edge colour, exactly as the per-index calls did before)
    for bar, label in zip(bars, labels):
        if label in highlight:
            bar.set_color(highlight[label])

    for tick in axes.get_xticklabels():
        tick.set_rotation(90)

    fig.tight_layout()

    return fig