1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import contextlib
5import logging
6
7from itertools import cycle
8
9import matplotlib
10import matplotlib.pyplot as plt
11import numpy
12
logger = logging.getLogger(__name__)
# Force the non-interactive "agg" backend so figures can be built and saved
# without a display attached.
matplotlib.use("agg")
15
16
@contextlib.contextmanager
def _precision_recall_canvas(title=None, limits=None):
    """Generates a canvas to draw precision-recall curves

    Works like a context manager, yielding a figure and an axes set in which
    the precision-recall curves should be added to.  The figure already
    contains F1-ISO lines and is preset to the region given by ``limits``
    (a 0-1 square by default).  Once the context is finished,
    ``fig.tight_layout()`` is called.


    Parameters
    ----------

    title : :py:class:`str`, Optional
        Optional title to add to this plot

    limits : :py:class:`tuple`, Optional
        If set, a 4-tuple containing the bounds of the plot for the x and y
        axis respectively (format: ``[x_low, x_high, y_low, y_high]``).  If
        not set (``None``), use normal bounds (``[0, 1, 0, 1]``).


    Yields
    ------

    figure : matplotlib.figure.Figure
        The figure that should be finally returned to the user

    axes : matplotlib.axes.Axes
        An axis set where precision-recall plots should be added to

    """

    # ``None`` replaces a shared mutable list default; the bounds are the same.
    if limits is None:
        limits = (0.0, 1.0, 0.0, 1.0)

    fig, axes1 = plt.subplots(1)

    # Names and bounds
    axes1.set_xlabel("Recall")
    axes1.set_ylabel("Precision")
    axes1.set_xlim(limits[:2])
    axes1.set_ylim(limits[2:])

    if title is not None:
        axes1.set_title(title)

    axes1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)

    # Twin axes (sharing x with axes1) that carries the iso-F curves and
    # their labels on the right-hand side.
    axes2 = axes1.twinx()

    # Annotates plot with F1-score iso-lines.  Solving F = 2PR / (P + R) for
    # the precision P gives P = F * R / (2R - F), with R the recall (x axis).
    f_scores = numpy.linspace(0.1, 0.9, num=9)
    x = numpy.linspace(0.01, 1)
    tick_locs = []
    tick_labels = []
    for f_score in f_scores:
        y = f_score * x / (2 * x - f_score)
        valid = y >= 0  # negative "precision" values are not plottable
        axes2.plot(x[valid], y[valid], color="green", alpha=0.1)
        # label each curve where it meets the right edge (x[-1] == 1.0)
        tick_locs.append(y[-1])
        tick_labels.append("%.1f" % f_score)
    axes2.tick_params(axis="y", which="both", pad=0, right=False, left=False)
    axes2.set_ylabel("iso-F", color="green", alpha=0.3)
    # The iso-F curves are drawn in axes2 data coordinates, so its y-range
    # must match axes1's; otherwise they are misaligned with the precision
    # axis whenever custom ``limits`` are used.
    axes2.set_ylim(limits[2:])
    axes2.yaxis.set_label_coords(1.015, 0.97)
    axes2.set_yticks(tick_locs)  # tick marks are hidden by tick_params above
    for k in axes2.set_yticklabels(tick_labels):
        k.set_color("green")
        k.set_alpha(0.3)
        k.set_size(8)

    # Keep only axes1's left and bottom spines, offset slightly (1.5% of the
    # visible range) outside the data area.
    axes1.spines["right"].set_visible(False)
    axes1.spines["top"].set_visible(False)
    axes1.spines["left"].set_position(
        ("data", limits[0] - (0.015 * (limits[1] - limits[0])))
    )
    axes1.spines["bottom"].set_position(
        ("data", limits[2] - (0.015 * (limits[3] - limits[2])))
    )

    # None of axes2's spines should be visible
    axes2.spines["right"].set_visible(False)
    axes2.spines["top"].set_visible(False)
    axes2.spines["left"].set_visible(False)
    axes2.spines["bottom"].set_visible(False)

    # Yield execution so the caller can draw precision-recall plots and the
    # legend before the layout is tightened.
    yield fig, axes1

    # Act on this figure explicitly rather than on pyplot's "current" figure.
    fig.tight_layout()
105
106
def precision_recall_f1iso(data, limits=None):
    """Creates a precision-recall plot

    This function creates and returns a Matplotlib figure with a
    precision-recall plot.  The plot will be annotated with F1-score iso-lines
    (in which the F1-score maintains the same value).

    This function specially supports "second-annotator" entries by plotting a
    line showing the comparison between the default annotator being analyzed
    and a second "opinion".  Second annotator dataframes contain a single
    entry (threshold=0.5), given the nature of the binary map comparisons.


    Parameters
    ----------

    data : dict
        A dictionary in which keys are strings defining plot labels and values
        are dictionaries with two entries:

        * ``df``: :py:class:`pandas.DataFrame`

          A dataframe that is produced by our evaluator engine, indexed by
          integer "thresholds", containing the following columns:
          ``threshold``, ``tp``, ``fp``, ``tn``, ``fn``, ``mean_precision``,
          ``mode_precision``, ``lower_precision``, ``upper_precision``,
          ``mean_recall``, ``mode_recall``, ``lower_recall``, ``upper_recall``,
          ``mean_specificity``, ``mode_specificity``, ``lower_specificity``,
          ``upper_specificity``, ``mean_accuracy``, ``mode_accuracy``,
          ``lower_accuracy``, ``upper_accuracy``, ``mean_jaccard``,
          ``mode_jaccard``, ``lower_jaccard``, ``upper_jaccard``,
          ``mean_f1_score``, ``mode_f1_score``, ``lower_f1_score``,
          ``upper_f1_score``, ``frequentist_precision``,
          ``frequentist_recall``, ``frequentist_specificity``,
          ``frequentist_accuracy``, ``frequentist_jaccard``,
          ``frequentist_f1_score``.  Empty dataframes are skipped with a
          warning.

        * ``threshold``: :py:class:`float`

          A threshold to graph with a dot for each set.  It selects row
          ``round(len(df) * threshold)`` of the dataframe (clamped to the
          valid row range), so it is expected to lie in ``[0, 1]``.  Specific
          threshold values do not affect "second-annotator" dataframes.

    limits : tuple, Optional
        A 4-tuple containing the bounds of the plot for the x and y axis
        respectively (format: ``[x_low, x_high, y_low, y_high]``).  If not set
        (``None`` or empty), use normal bounds (``[0, 1, 0, 1]``).


    Returns
    -------

    figure : matplotlib.figure.Figure
        A matplotlib figure you can save or display (uses an ``agg`` backend)

    """

    # Resolve the default here so the canvas always receives concrete bounds.
    if not limits:
        limits = [0.0, 1.0, 0.0, 1.0]

    lines = ["-", "--", "-.", ":"]
    colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    colorcycler = cycle(colors)
    linecycler = cycle(lines)

    with _precision_recall_canvas(title=None, limits=limits) as (fig, axes):

        legend = []

        for name, value in data.items():

            df = value["df"]
            threshold = value["threshold"]

            if len(df) == 0:
                logger.warning("Skipping empty dataframe for entry %r", name)
                continue

            # plots only from the point where recall reaches its maximum,
            # otherwise, we don't see a curve...
            max_recall = df.mean_recall.idxmax()
            pi = df.mean_precision[max_recall:]
            ri = df.mean_recall[max_recall:]

            # Row of the operating point for the requested threshold, clamped
            # to the valid range.  Lookups below are label-based, which
            # assumes the index is 0..len(df)-1 (label == position) -- TODO
            # confirm for all evaluator outputs.
            bins = len(df)
            index = int(round(bins * threshold))
            index = min(max(index, 0), bins - 1)

            label = f"{name} (F1={df.mean_f1_score[index]:.4f})"
            color = next(colorcycler)

            if bins == 1:
                # Single-entry ("second-annotator") set: a star marker at the
                # only available operating point.
                (marker,) = axes.plot(
                    df.mean_recall[index],
                    df.mean_precision[index],
                    marker="*",
                    markersize=6,
                    color=color,
                    alpha=0.8,
                    linestyle="None",
                )
                # Artist with no line and no marker; used only as the
                # second element of this entry's legend handle.
                (line,) = axes.plot(
                    df.mean_recall[index],
                    df.mean_precision[index],
                    linestyle="None",
                    color=color,
                    alpha=0.2,
                )
                legend.append(([marker, line], label))
            else:
                # line first, so marker gets on top
                style = next(linecycler)
                (line,) = axes.plot(
                    ri[pi > 0], pi[pi > 0], color=color, linestyle=style
                )
                (marker,) = axes.plot(
                    df.mean_recall[index],
                    df.mean_precision[index],
                    marker="o",
                    linestyle=style,
                    markersize=4,
                    color=color,
                    alpha=0.8,
                )
                legend.append(([marker, line], label))

        # Re-assert the requested bounds after all curves have been drawn.
        axes.set_xlim(limits[:2])
        axes.set_ylim(limits[2:])

        # A legend is only drawn when more than one entry was plotted.
        if len(legend) > 1:
            axes.legend(
                [tuple(k[0]) for k in legend],
                [k[1] for k in legend],
                loc="lower left",
                fancybox=True,
                framealpha=0.7,
            )

    return fig
254
255
def loss_curve(df):
    """Creates a loss curve in a Matplotlib figure.

    The median loss is drawn on the left y axis and the learning rate on a
    secondary (right) y axis, both against the ``epoch`` column.

    Parameters
    ----------

    df : :py:class:`pandas.DataFrame`
        A dataframe containing, at least, "epoch", "median-loss" and
        "learning-rate" columns, that will be plotted.

    Returns
    -------

    figure : matplotlib.figure.Figure
        A figure, that may be saved or displayed

    """

    ax1 = df.plot(x="epoch", y="median-loss", grid=True)
    ax1.set_ylabel("Median Loss")
    ax1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
    # Plot the learning rate on the same axes explicitly (not via pyplot's
    # "current axes") and against the same "epoch" column.  A bare
    # ``Series.plot`` would use the dataframe index as x instead.
    ax2 = df.plot(
        x="epoch",
        y="learning-rate",
        secondary_y=True,
        legend=True,
        grid=True,
        ax=ax1,
    )
    ax2.set_ylabel("Learning Rate")
    ax1.set_xlabel("Epoch")
    fig = ax1.get_figure()
    fig.tight_layout()
    return fig