1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import contextlib
5from itertools import cycle
6
7import numpy
8import pandas
9from sklearn.metrics import auc, precision_recall_curve as pr_curve, roc_curve as r_curve
10
11import matplotlib
12matplotlib.use("agg")
13import matplotlib.pyplot as plt
14
15import logging
16
# Module-level logger, named after this module.
# NOTE(review): none of the plotting helpers defined below
# (_precision_recall_canvas, precision_recall_f1iso, roc_curve,
# relevance_analysis_plot) log through it; confirm it is used elsewhere
# before removing.
logger = logging.getLogger(__name__)
18
@contextlib.contextmanager
def _precision_recall_canvas(title=None):
    """Generates a canvas to draw precision-recall curves

    Works like a context manager, yielding a figure and an axes set to which
    the precision-recall curves should be added. The figure already
    contains F1-ISO lines and is preset to a 0-1 square region. Once the
    context is finished, ``fig.tight_layout()`` is called. If the body of
    the ``with`` statement raises, the figure is closed (so it does not stay
    registered with :py:mod:`matplotlib.pyplot`) and the exception is
    re-raised.


    Parameters
    ----------

    title : :py:class:`str`, Optional
        Optional title to add to this plot


    Yields
    ------

    figure : matplotlib.figure.Figure
        The figure that should be finally returned to the user

    axes : matplotlib.axes.Axes
        An axes set to which precision-recall plots should be added

    """

    fig, axes1 = plt.subplots(1)

    # Names and bounds
    axes1.set_xlabel("Recall")
    axes1.set_ylabel("Precision")
    axes1.set_xlim([0.0, 1.0])
    axes1.set_ylim([0.0, 1.0])

    if title is not None:
        axes1.set_title(title)

    axes1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)

    # A twin axes (sharing x with axes1) carries the iso-F lines and their
    # labels on the right-hand side, separate from the precision axis.
    axes2 = axes1.twinx()

    # Annotates plot with F1-score iso-lines. For a fixed F1 score f, recall x
    # and precision y satisfy f = 2xy / (x + y), i.e. y = f x / (2x - f).
    # Points with y < 0 (x < f/2) are not part of the curve and are dropped.
    f_scores = numpy.linspace(0.1, 0.9, num=9)
    x = numpy.linspace(0.01, 1)
    tick_locs = []
    tick_labels = []
    for f_score in f_scores:
        y = f_score * x / (2 * x - f_score)
        keep = y >= 0
        # draw explicitly on the twin axes (previously done implicitly via
        # plt.plot, which targets whatever axes is current)
        axes2.plot(x[keep], y[keep], color="green", alpha=0.1)
        # label each iso-line where it meets the right border (recall = 1)
        tick_locs.append(y[-1])
        tick_labels.append("%.1f" % f_score)

    axes2.tick_params(axis="y", which="both", pad=0, right=False, left=False)
    axes2.set_ylabel("iso-F", color="green", alpha=0.3)
    axes2.set_ylim([0.0, 1.0])
    axes2.yaxis.set_label_coords(1.015, 0.97)
    # tick marks are hidden by tick_params above; only the labels show
    axes2.set_yticks(tick_locs)
    for tick_label in axes2.set_yticklabels(tick_labels):
        tick_label.set_color("green")
        tick_label.set_alpha(0.3)
        tick_label.set_size(8)

    # only the left/bottom spines of axes1 are shown, nudged slightly outside
    # the data region
    axes1.spines["right"].set_visible(False)
    axes1.spines["top"].set_visible(False)
    axes1.spines["left"].set_position(("data", -0.015))
    axes1.spines["bottom"].set_position(("data", -0.015))

    # none of the twin axes' spines should be visible
    for side in ("right", "top", "left", "bottom"):
        axes2.spines[side].set_visible(False)

    # yield execution, letting the caller draw precision-recall plots and the
    # legend before tightening the layout
    try:
        yield fig, axes1
    except BaseException:
        # don't leave a half-built figure registered with pyplot
        plt.close(fig)
        raise

    fig.tight_layout()
98
99
def precision_recall_f1iso(data, title=None):
    """Creates a precision-recall plot

    This function creates and returns a Matplotlib figure with a
    precision-recall plot. The plot will be annotated with F1-score
    iso-lines (in which the F1-score maintains the same value). Each curve
    is labelled with the area under it, computed with
    :py:func:`sklearn.metrics.auc`.


    Parameters
    ----------

    data : dict
        A dictionary in which keys are strings defining plot labels and values
        are dictionaries with two entries:

        * ``df``: :py:class:`pandas.DataFrame`

          A dataframe that is produced by our predictor engine containing
          the following columns: ``filename``, ``likelihood``,
          ``ground_truth``.

        * ``threshold``: :py:class:`list`

          A threshold for each set. Not used here.

        If empty, an empty (but fully annotated) canvas is returned.

    title : :py:class:`str`, Optional
        Optional title to add to this plot


    Returns
    -------

    figure : matplotlib.figure.Figure
        A matplotlib figure you can save or display (uses an ``agg`` backend)

    """

    lines = ["-", "--", "-.", ":"]
    colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    colorcycler = cycle(colors)
    linecycler = cycle(lines)

    with _precision_recall_canvas(title=title) as (fig, axes):

        legend = []

        for name, value in data.items():

            df = value["df"]

            # sklearn returns (precision, recall, thresholds)
            prec, recall, _ = pr_curve(df["ground_truth"], df["likelihood"])
            _auc = auc(recall, prec)
            label = f"{name} (AUC={_auc:.2f})"

            (line,) = axes.plot(
                recall,
                prec,
                color=next(colorcycler),
                linestyle=next(linecycler),
            )
            legend.append((line, label))

        # Only add a legend when something was plotted. (The previous check,
        # ``len(label) > 1``, tested the last label string and raised
        # NameError when ``data`` was empty.)
        if legend:
            axes.legend(
                [handle for handle, _ in legend],
                [text for _, text in legend],
                loc="lower left",
                fancybox=True,
                framealpha=0.7,
            )

    return fig
183
184
def roc_curve(data, title=None):
    """Creates a ROC plot

    This function creates and returns a Matplotlib figure with a
    ROC plot: sensitivity (true positive rate) against 1 - specificity
    (false positive rate), one curve per entry in ``data``, each labelled
    with the area under it computed with :py:func:`sklearn.metrics.auc`.


    Parameters
    ----------

    data : dict
        A dictionary in which keys are strings defining plot labels and values
        are dictionaries with two entries:

        * ``df``: :py:class:`pandas.DataFrame`

          A dataframe that is produced by our predictor engine containing
          the following columns: ``filename``, ``likelihood``,
          ``ground_truth``.

        * ``threshold``: :py:class:`list`

          A threshold for each set. Not used here.

        If empty, an empty (but fully annotated) canvas is returned.

    title : :py:class:`str`, Optional
        Optional title to add to this plot


    Returns
    -------

    figure : matplotlib.figure.Figure
        A matplotlib figure you can save or display (uses an ``agg`` backend)

    """

    lines = ["-", "--", "-.", ":"]
    colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    colorcycler = cycle(colors)
    linecycler = cycle(lines)

    fig, axes = plt.subplots(1)

    try:
        # Names and bounds
        axes.set_xlabel("1 - specificity")
        axes.set_ylabel("Sensitivity")
        axes.set_xlim([0.0, 1.0])
        axes.set_ylim([0.0, 1.0])

        # only the left/bottom spines are shown, nudged slightly outside the
        # data region
        axes.spines["right"].set_visible(False)
        axes.spines["top"].set_visible(False)
        axes.spines["left"].set_position(("data", -0.015))
        axes.spines["bottom"].set_position(("data", -0.015))

        if title is not None:
            axes.set_title(title)

        axes.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)

        legend = []

        for name, value in data.items():

            df = value["df"]

            # sklearn returns (fpr, tpr, thresholds)
            fpr, tpr, _ = r_curve(df["ground_truth"], df["likelihood"])
            _auc = auc(fpr, tpr)
            label = f"{name} (AUC={_auc:.2f})"

            (line,) = axes.plot(
                fpr,
                tpr,
                color=next(colorcycler),
                linestyle=next(linecycler),
            )
            legend.append((line, label))

        # Only add a legend when something was plotted. (The previous check,
        # ``len(label) > 1``, tested the last label string and raised
        # NameError when ``data`` was empty.)
        if legend:
            axes.legend(
                [handle for handle, _ in legend],
                [text for _, text in legend],
                loc="lower right",
                fancybox=True,
                framealpha=0.7,
            )

        # tighten the layout of *this* figure once everything is drawn
        # (previously done via plt.tight_layout() before any curve existed)
        fig.tight_layout()
    except BaseException:
        # don't leave a half-built figure registered with pyplot
        plt.close(fig)
        raise

    return fig
286
287
def relevance_analysis_plot(data, title=None):
    """Create a bar plot to show the relative importance of features

    One bar is drawn per feature, for a fixed list of 14 features (in this
    order): Cardiomegaly, Emphysema, Pleural effusion, Hernia, Infiltration,
    Mass, Nodule, Atelectasis, Pneumothorax, Pleural thickening, Pneumonia,
    Fibrosis, Edema, Consolidation.

    Bars are colored by a hard-coded grouping of those features:

    * ``#818C2E`` ("likely"): Pleural effusion, Infiltration, Pneumonia
    * ``#F2921D`` ("could be"): Mass, Nodule, Atelectasis, Fibrosis,
      Consolidation
    * ``#8C3503`` ("unlikely"): all remaining features

    NOTE(review): what "likely"/"could be"/"unlikely" refer to is not
    documented here; confirm with the caller.


    Parameters
    ----------

    data : :py:class:`list`
        The list of values (one for each feature, in the order above)

    title : :py:class:`str`, Optional
        Optional title to add to this plot


    Returns
    -------

    figure : matplotlib.figure.Figure
        A matplotlib figure you can save or display (uses an ``agg`` backend)


    Raises
    ------

    ValueError
        If ``data`` does not contain exactly one value per feature

    """

    labels = ['Cardiomegaly', 'Emphysema', 'Pleural effusion',
              'Hernia', 'Infiltration', 'Mass', 'Nodule',
              'Atelectasis', 'Pneumothorax', 'Pleural thickening',
              'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation']

    # Validate before creating the figure, so a bad input neither leaks a
    # pyplot figure nor gets silently broadcast by matplotlib (a length-1
    # input would otherwise fill every bar with the same value).
    if len(data) != len(labels):
        raise ValueError(
            f"expected {len(labels)} relevance values (one per feature: "
            f"{', '.join(labels)}), but got {len(data)}"
        )

    unlikely_color = '#8C3503'
    # features whose bars get a non-default color, keyed by name rather than
    # by position so the grouping stays readable
    highlight = {}
    for feature in ('Pleural effusion', 'Infiltration', 'Pneumonia'):
        highlight[feature] = '#818C2E'  # likely
    for feature in ('Mass', 'Nodule', 'Atelectasis', 'Fibrosis',
                    'Consolidation'):
        highlight[feature] = '#F2921D'  # could be

    fig, axes = plt.subplots(1, 1, figsize=(6, 6))

    # Names and bounds
    axes.set_xlabel("Features")
    axes.set_ylabel("Importance")

    # only the left/bottom spines are shown
    axes.spines["right"].set_visible(False)
    axes.spines["top"].set_visible(False)

    if title is not None:
        axes.set_title(title)

    bars = axes.bar(labels, data, color=unlikely_color)
    for bar, label in zip(bars, labels):
        color = highlight.get(label)
        if color is not None:
            bar.set_color(color)

    # vertical feature names so the 14 labels don't overlap
    for tick in axes.get_xticklabels():
        tick.set_rotation(90)

    fig.tight_layout()

    return fig