Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4import contextlib
5import logging
7from itertools import cycle
9import matplotlib
10import matplotlib.pyplot as plt
11import numpy
13matplotlib.use("agg")
14logger = logging.getLogger(__name__)
17def _concave_hull(x, y, lx, ux, ly, uy):
18 """Calculates a approximate (concave) hull from arc centers and sizes
20 Each ellipse is approximated as a number of discrete points distributed
21 over the ellipse border following an homogeneous angle distribution.
24 Parameters
25 ----------
27 x : numpy.ndarray
28 1D array with x coordinates of ellipse centers
30 y : numpy.ndarray
31 1D array with y coordinates of ellipse centers
33 lx, ux, ly, uy : numpy.ndarray
34 1D array(s) with upper and lower widths and heights for your deformed
35 ellipse
38 Returns
39 -------
41 points : numpy.ndarray
42 2D array containing the ``(x, y)`` coordinates of the concave hull
43 encompassing all defined arcs.
45 """
47 def _irregular_ellipse_points(_x, _y, _lx, _ux, _ly, _uy, steps=100):
48 """Generates border points for an irregular ellipse
50 This functions distributes points according to a rotation angle rather
51 than uniformily with respect to a particular axis. The result is a
52 more homogeneous border representation for the ellipse.
53 """
54 up = _uy - _y
55 down = _y - _ly
56 left = _x - _lx
57 right = _ux - _x
59 angles = numpy.arange(0, numpy.pi / 2, step=2 * numpy.pi / steps)
60 points = numpy.ndarray((0, 2))
62 # upper left part (90 -> 180 degrees)
63 px = 2 * left * numpy.cos(angles)
64 py = (up / left) * numpy.sqrt(numpy.square(2 * left) - numpy.square(px))
65 # order: x and y increase
66 points = numpy.vstack((points, numpy.array([_x - px, _y + py]).T))
68 # upper right part (0 -> 90 degrees)
69 px = 2 * right * numpy.cos(angles)
70 py = (up / right) * numpy.sqrt(
71 numpy.square(2 * right) - numpy.square(px)
72 )
73 # order: x increases and y decreases
74 points = numpy.vstack(
75 (points, numpy.flipud(numpy.array([_x + px, _y + py]).T))
76 )
78 # lower right part (180 -> 270 degrees)
79 px = 2 * right * numpy.cos(angles)
80 py = (down / right) * numpy.sqrt(
81 numpy.square(2 * right) - numpy.square(px)
82 )
83 # order: x increases and y decreases
84 points = numpy.vstack((points, numpy.array([_x + px, _y - py]).T))
86 # lower left part (180 -> 270 degrees)
87 px = 2 * left * numpy.cos(angles)
88 py = (down / left) * numpy.sqrt(
89 numpy.square(2 * left) - numpy.square(px)
90 )
91 # order: x decreases and y increases
92 points = numpy.vstack(
93 (points, numpy.flipud(numpy.array([_x - px, _y - py]).T))
94 )
96 return points
98 retval = numpy.ndarray((0, 2))
99 for (k, l, m, n, o, p) in zip(x, y, lx, ux, ly, uy):
100 retval = numpy.vstack(
101 (
102 retval,
103 [numpy.nan, numpy.nan],
104 _irregular_ellipse_points(k, l, m, n, o, p),
105 )
106 )
107 return retval
@contextlib.contextmanager
def _precision_recall_canvas(title=None):
    """Generates a canvas to draw precision-recall curves

    Works like a context manager, yielding a figure and an axes set in which
    the precision-recall curves should be added to.  The figure already
    contains F1-ISO lines and is preset to a 0-1 square region.  Once the
    context is finished, ``fig.tight_layout()`` is called.  If the managed
    block raises, the layout is not tightened and the exception propagates.


    Parameters
    ----------

    title : :py:class:`str`, Optional
        Optional title to add to this plot


    Yields
    ------

    figure : matplotlib.figure.Figure
        The figure that should be finally returned to the user

    axes : matplotlib.axes.Axes
        An axis set where the precision-recall plots should be added to

    """

    fig, axes1 = plt.subplots(1)

    # Names and bounds
    axes1.set_xlabel("Recall")
    axes1.set_ylabel("Precision")
    axes1.set_xlim([0.0, 1.0])
    axes1.set_ylim([0.0, 1.0])

    if title is not None:
        axes1.set_title(title)

    axes1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
    axes2 = axes1.twinx()

    # Annotates plot with F1-score iso-lines.  These are explicitly drawn on
    # the twin axes, instead of relying on pyplot's "current axes".
    x = numpy.linspace(0.01, 1)
    tick_locs = []
    tick_labels = []
    for f_score in numpy.linspace(0.1, 0.9, num=9):
        y = f_score * x / (2 * x - f_score)
        visible = y >= 0  # drops the negative branch of the hyperbola
        axes2.plot(x[visible], y[visible], color="green", alpha=0.1)
        tick_locs.append(y[-1])  # where the iso-line meets recall == 1
        tick_labels.append(f"{f_score:.1f}")
    axes2.tick_params(axis="y", which="both", pad=0, right=False, left=False)
    axes2.set_ylabel("iso-F", color="green", alpha=0.3)
    axes2.set_ylim([0.0, 1.0])
    axes2.yaxis.set_label_coords(1.015, 0.97)
    axes2.set_yticks(tick_locs)  # notice these are invisible
    for tick_label in axes2.set_yticklabels(tick_labels):
        tick_label.set_color("green")
        tick_label.set_alpha(0.3)
        tick_label.set_size(8)

    # we should see some of axes 1 axes
    axes1.spines["right"].set_visible(False)
    axes1.spines["top"].set_visible(False)
    axes1.spines["left"].set_position(("data", -0.015))
    axes1.spines["bottom"].set_position(("data", -0.015))

    # we shouldn't see any of axes 2 axes
    for side in ("right", "top", "left", "bottom"):
        axes2.spines[side].set_visible(False)

    # yield execution, lets user draw precision-recall plots, and the legend
    # before tightening the layout
    yield fig, axes1

    # acts on the figure created here, even if the caller made another
    # figure current in the meantime
    fig.tight_layout()
def precision_recall_f1iso(data, credible=True):
    """Creates a precision-recall plot with credible intervals

    This function creates and returns a Matplotlib figure with a
    precision-recall plot containing shaded credible intervals (on the
    precision-recall measurements).  The plot will be annotated with F1-score
    iso-lines (in which the F1-score maintains the same value).

    This function specially supports "second-annotator" entries by plotting a
    single marker showing the comparison between the default annotator being
    analyzed and a second "opinion".  Second annotator dataframes contain a
    single entry (threshold=0.5), given the nature of the binary map
    comparisons.


    Parameters
    ----------

    data : dict
        A dictionary in which keys are strings defining plot labels and values
        are dictionaries with two entries:

        * ``df``: :py:class:`pandas.DataFrame`

          A dataframe that is produced by our evaluator engine, indexed by
          integer "thresholds", containing the following columns:
          ``threshold``, ``tp``, ``fp``, ``tn``, ``fn``, ``mean_precision``,
          ``mode_precision``, ``lower_precision``, ``upper_precision``,
          ``mean_recall``, ``mode_recall``, ``lower_recall``, ``upper_recall``,
          ``mean_specificity``, ``mode_specificity``, ``lower_specificity``,
          ``upper_specificity``, ``mean_accuracy``, ``mode_accuracy``,
          ``lower_accuracy``, ``upper_accuracy``, ``mean_jaccard``,
          ``mode_jaccard``, ``lower_jaccard``, ``upper_jaccard``,
          ``mean_f1_score``, ``mode_f1_score``, ``lower_f1_score``,
          ``upper_f1_score``, ``frequentist_precision``,
          ``frequentist_recall``, ``frequentist_specificity``,
          ``frequentist_accuracy``, ``frequentist_jaccard``,
          ``frequentist_f1_score``.

        * ``threshold``: :py:class:`float`

          The threshold to graph with a dot for each set.  It is multiplied
          by the number of rows in ``df`` and rounded to select the row to
          be highlighted (the last row is used if the result falls beyond
          the end of the dataframe).  Specific threshold values do not affect
          "second-annotator" (single row) dataframes.

    credible : :py:class:`bool`, Optional
        If set, draw credible intervals for each line, using ``upper_*`` and
        ``lower_*`` entries.


    Returns
    -------

    figure : matplotlib.figure.Figure
        A matplotlib figure you can save or display (uses an ``agg`` backend)

    """

    lines = ["-", "--", "-.", ":"]
    colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    colorcycler = cycle(colors)
    linecycler = cycle(lines)

    with _precision_recall_canvas(title=None) as (fig, axes):

        # list of ([artists], label) pairs, one per entry in ``data``
        legend = []

        for name, value in data.items():

            df = value["df"]
            threshold = value["threshold"]

            # plots only from the point where recall reaches its maximum,
            # otherwise, we don't see a curve...
            max_recall = df.mean_recall.idxmax()
            pi = df.mean_precision[max_recall:]
            ri = df.mean_recall[max_recall:]

            # optimal point along the curve
            bins = len(df)
            index = int(round(bins * threshold))
            index = min(index, len(df) - 1)  # avoids out of range indexing

            # plots Recall/Precision as threshold changes
            label = f"{name} (F1={df.mean_f1_score[index]:.4f})"
            color = next(colorcycler)

            if len(df) == 1:
                # single measurement (e.g. second annotator): plots a star
                # for the F1-score at the selected threshold
                (marker,) = axes.plot(
                    df.mean_recall[index],
                    df.mean_precision[index],
                    marker="*",
                    markersize=6,
                    color=color,
                    alpha=0.8,
                    linestyle="None",
                )
                # invisible companion, so the legend entry has the same
                # composition as the one of regular curves
                (line,) = axes.plot(
                    df.mean_recall[index],
                    df.mean_precision[index],
                    linestyle="None",
                    color=color,
                    alpha=0.2,
                )
            else:
                # line first, so marker gets on top
                style = next(linecycler)
                valid = pi > 0
                (line,) = axes.plot(
                    ri[valid], pi[valid], color=color, linestyle=style
                )
                (marker,) = axes.plot(
                    df.mean_recall[index],
                    df.mean_precision[index],
                    marker="o",
                    linestyle=style,
                    markersize=4,
                    color=color,
                    alpha=0.8,
                )

            legend.append(([marker, line], label))

            if credible:

                hull = _concave_hull(
                    df.mean_recall,
                    df.mean_precision,
                    df.lower_recall,
                    df.upper_recall,
                    df.lower_precision,
                    df.upper_precision,
                )
                p = plt.Polygon(
                    hull,
                    facecolor=color,
                    alpha=0.2,
                    edgecolor="none",
                    lw=0.2,
                    closed=True,
                )
                axes.add_patch(p)
                legend[-1][0].append(p)

        # only builds a legend if something was plotted (``data`` may be
        # empty, in which case a bare canvas is returned)
        if legend:
            axes.legend(
                [tuple(artists) for artists, _ in legend],
                [text for _, text in legend],
                loc="lower left",
                fancybox=True,
                framealpha=0.7,
            )

    return fig
def loss_curve(df):
    """Creates a loss curve in a Matplotlib figure.

    The median loss is drawn against the left y-axis and the learning rate
    against a secondary (right) y-axis of the same plot.

    Parameters
    ----------

    df : :py:class:`pandas.DataFrame`
        A dataframe containing, at least, "epoch", "median-loss" and
        "learning-rate" columns, that will be plotted.

    Returns
    -------

    figure : matplotlib.figure.Figure
        A figure, that may be saved or displayed

    """

    ax1 = df.plot(x="epoch", y="median-loss", grid=True)
    fig = ax1.get_figure()
    ax1.set_ylabel("Median Loss")
    ax1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)

    # The target axes is given explicitly, instead of relying on pyplot's
    # "current axes" to overlay the second curve on the first one.
    # NOTE(review): this series is plotted against the dataframe index, not
    # against the "epoch" column -- confirm both are aligned.
    ax2 = df["learning-rate"].plot(
        ax=ax1,
        secondary_y=True,
        legend=True,
        grid=True,
    )
    ax2.set_ylabel("Learning Rate")
    ax1.set_xlabel("Epoch")

    # tightens the layout of this very figure, whatever the current one is
    fig.tight_layout()
    return fig