Coverage for src/deepdraw/utils/plot.py: 89%
89 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5import contextlib
6import logging
8from itertools import cycle
10import matplotlib
11import matplotlib.pyplot as plt
12import numpy
# Select the non-interactive "agg" backend at import time, so figures can be
# rendered and saved without a display.
matplotlib.use(backend="agg")

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
def _iso_f1_precision(f_score, recall):
    """Returns the precision values along an F1-score iso-line.

    Solving ``F1 = 2 * p * r / (p + r)`` for the precision ``p`` gives
    ``p = F1 * r / (2 * r - F1)``.  Entries where ``2 * r < F1`` come out
    negative (no precision reaches that F1-score at such a recall) and must be
    discarded by the caller, as must non-finite entries (``2 * r == F1``).


    Parameters
    ----------

    f_score : float
        The F1-score represented by the iso-line

    recall : numpy.ndarray
        Recall values at which to evaluate the iso-line


    Returns
    -------

    precision : numpy.ndarray
        Precision values matching ``recall`` (unfiltered)
    """
    recall = numpy.asarray(recall, dtype=float)
    # ``recall == f_score / 2`` is a pole of the curve: let it silently turn
    # into inf/nan, the caller filters those out
    with numpy.errstate(divide="ignore", invalid="ignore"):
        return f_score * recall / (2 * recall - f_score)


@contextlib.contextmanager
def _precision_recall_canvas(title=None, limits=(0.0, 1.0, 0.0, 1.0)):
    """Generates a canvas to draw precision-recall curves.

    Works like a context manager, yielding a figure and an axes set to which
    the precision-recall curves should be added.  The figure already
    contains F1-score iso-lines and is preset to the region given by
    ``limits``.  Once the context finishes without error,
    ``fig.tight_layout()`` is called.  If the body of the context raises, the
    figure is closed (so pyplot does not keep it alive) and the exception is
    propagated.


    Parameters
    ----------

    title : :py:class:`str`, Optional
        Optional title to add to this plot

    limits : :py:class:`tuple`, Optional
        A 4-element sequence containing the bounds of the plot for the x and
        y axis respectively (format: ``[x_low, x_high, y_low, y_high]``).  If
        ``None``, use normal bounds (``[0, 1, 0, 1]``).


    Yields
    ------

    figure : matplotlib.figure.Figure
        The figure that should be finally returned to the user

    axes : matplotlib.axes.Axes
        An axis set to which precision-recall plots should be added


    Raises
    ------

    ValueError
        If ``limits`` does not contain exactly 4 values
    """
    if limits is None:
        limits = (0.0, 1.0, 0.0, 1.0)
    limits = tuple(limits)
    if len(limits) != 4:
        raise ValueError(
            f"limits must be [x_low, x_high, y_low, y_high], got {limits!r}"
        )

    fig, axes1 = plt.subplots(1)

    # Names and bounds
    axes1.set_xlabel("Recall")
    axes1.set_ylabel("Precision")
    axes1.set_xlim(limits[:2])
    axes1.set_ylim(limits[2:])

    if title is not None:
        axes1.set_title(title)

    axes1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
    axes2 = axes1.twinx()

    # Annotates the plot with F1-score iso-lines.  They are drawn on the twin
    # axes, whose tick labels (on the right) name each iso-line.
    recall = numpy.linspace(0.01, 1)
    tick_locs = []
    tick_labels = []
    for f_score in numpy.linspace(0.1, 0.9, num=9):
        precision = _iso_f1_precision(f_score, recall)
        keep = numpy.isfinite(precision) & (precision >= 0)
        axes2.plot(recall[keep], precision[keep], color="green", alpha=0.1)
        # label each iso-line where it meets recall == 1 (right border)
        tick_locs.append(precision[-1])
        tick_labels.append(f"{f_score:.1f}")
    axes2.tick_params(axis="y", which="both", pad=0, right=False, left=False)
    axes2.set_ylabel("iso-F", color="green", alpha=0.3)
    axes2.yaxis.set_label_coords(1.015, 0.97)
    axes2.set_yticks(tick_locs)  # tick marks are disabled above, labels stay
    for k in axes2.set_yticklabels(tick_labels):
        k.set_color("green")
        k.set_alpha(0.3)
        k.set_size(8)
    # The iso-lines are in precision units, so the twin axes must span the
    # same y-range as the precision axis.  This is set *after* set_yticks(),
    # which would otherwise widen the view to include every tick location.
    axes2.set_ylim(limits[2:])

    # we should see some of axes 1 axes
    axes1.spines["right"].set_visible(False)
    axes1.spines["top"].set_visible(False)
    axes1.spines["left"].set_position(
        ("data", limits[0] - (0.015 * (limits[1] - limits[0])))
    )
    axes1.spines["bottom"].set_position(
        ("data", limits[2] - (0.015 * (limits[3] - limits[2])))
    )

    # we shouldn't see any of axes 2 axes
    axes2.spines["right"].set_visible(False)
    axes2.spines["top"].set_visible(False)
    axes2.spines["left"].set_visible(False)
    axes2.spines["bottom"].set_visible(False)

    # yield execution, lets user draw precision-recall plots, and the legend
    # before tightening the layout
    try:
        yield fig, axes1
    except BaseException:
        # do not leave a half-drawn figure registered with pyplot
        plt.close(fig)
        raise

    fig.tight_layout()
def _threshold_index(size, threshold):
    """Maps a threshold in ``[0, 1]`` to a row position of a dataframe.

    The position is ``round(size * threshold)`` clamped to ``[0, size - 1]``,
    so thresholds at (or beyond) the ends of the range never index out of
    bounds.


    Parameters
    ----------

    size : int
        Number of rows available (must be at least 1)

    threshold : float
        The threshold to map


    Returns
    -------

    index : int
        A valid row position in ``[0, size - 1]``


    Raises
    ------

    ValueError
        If ``size`` is smaller than 1
    """
    if size < 1:
        raise ValueError(f"cannot pick a threshold row among {size} rows")
    return min(max(int(round(size * threshold)), 0), size - 1)


def precision_recall_f1iso(data, limits):
    """Creates a precision-recall plot.

    This function creates and returns a Matplotlib figure with a
    precision-recall plot.  The plot will be annotated with F1-score iso-lines
    (in which the F1-score maintains the same value).

    This function specially supports "second-annotator" entries by plotting a
    line showing the comparison between the default annotator being analyzed
    and a second "opinion".  Second annotator dataframes contain a single entry
    (threshold=0.5), given the nature of the binary map comparisons.

    Rows are addressed by position, so the dataframes' index labels do not
    matter.  A legend is only drawn when more than one entry is plotted.


    Parameters
    ----------

    data : dict
        A dictionary in which keys are strings defining plot labels and values
        are dictionaries with two entries:

        * ``df``: :py:class:`pandas.DataFrame`

          A dataframe that is produced by our evaluator engine, with one row
          per threshold, containing the following columns:
          ``threshold``, ``tp``, ``fp``, ``tn``, ``fn``, ``mean_precision``,
          ``mode_precision``, ``lower_precision``, ``upper_precision``,
          ``mean_recall``, ``mode_recall``, ``lower_recall``, ``upper_recall``,
          ``mean_specificity``, ``mode_specificity``, ``lower_specificity``,
          ``upper_specificity``, ``mean_accuracy``, ``mode_accuracy``,
          ``lower_accuracy``, ``upper_accuracy``, ``mean_jaccard``,
          ``mode_jaccard``, ``lower_jaccard``, ``upper_jaccard``,
          ``mean_f1_score``, ``mode_f1_score``, ``lower_f1_score``,
          ``upper_f1_score``, ``frequentist_precision``,
          ``frequentist_recall``, ``frequentist_specificity``,
          ``frequentist_accuracy``, ``frequentist_jaccard``,
          ``frequentist_f1_score``.

        * ``threshold``: :py:class:`float`

          The threshold (in ``[0, 1]``) to mark with a dot on each curve.  It
          is mapped to row ``round(len(df) * threshold)``, clamped to the
          available rows.  Specific threshold values do not affect
          "second-annotator" (single-row) dataframes.

    limits : tuple
        A 4-tuple containing the bounds of the plot for the x and y axis
        respectively (format: ``[x_low, x_high, y_low, y_high]``).  If not
        set, use normal bounds (``[0, 1, 0, 1]``).


    Returns
    -------

    figure : matplotlib.figure.Figure
        A matplotlib figure you can save or display (uses an ``agg`` backend)
    """
    lines = ["-", "--", "-.", ":"]
    colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    colorcycler = cycle(colors)
    linecycler = cycle(lines)

    with _precision_recall_canvas(title=None, limits=limits) as (fig, axes):
        legend = []

        for name, value in data.items():
            df = value["df"]
            threshold = value["threshold"]

            # plots only from the point where recall reaches its maximum,
            # otherwise, we don't see a curve...  (positional, like the
            # threshold row below, so the dataframe index is irrelevant)
            max_recall = int(df.mean_recall.argmax())
            pi = df.mean_precision.iloc[max_recall:]
            ri = df.mean_recall.iloc[max_recall:]

            # operating point along the curve at the requested threshold
            index = _threshold_index(len(df), threshold)
            recall_at = df.mean_recall.iloc[index]
            precision_at = df.mean_precision.iloc[index]

            label = f"{name} (F1={df.mean_f1_score.iloc[index]:.4f})"
            color = next(colorcycler)

            if len(df) == 1:
                # second-annotator entry: a single star marker, plus a line
                # artist that draws nothing (no linestyle, no marker) so the
                # legend handle has the same (marker, line) shape as curves
                (marker,) = axes.plot(
                    recall_at,
                    precision_at,
                    marker="*",
                    markersize=6,
                    color=color,
                    alpha=0.8,
                    linestyle="None",
                )
                (line,) = axes.plot(
                    recall_at,
                    precision_at,
                    linestyle="None",
                    color=color,
                    alpha=0.2,
                )
            else:
                # line first, so marker gets on top
                style = next(linecycler)
                positive = pi > 0
                (line,) = axes.plot(
                    ri[positive], pi[positive], color=color, linestyle=style
                )
                (marker,) = axes.plot(
                    recall_at,
                    precision_at,
                    marker="o",
                    linestyle=style,
                    markersize=4,
                    color=color,
                    alpha=0.8,
                )
            legend.append(([marker, line], label))

        if limits is not None:
            axes.set_xlim(limits[:2])
            axes.set_ylim(limits[2:])

        if len(legend) > 1:
            axes.legend(
                [tuple(k[0]) for k in legend],
                [k[1] for k in legend],
                loc="lower left",
                fancybox=True,
                framealpha=0.7,
            )

    return fig
def loss_curve(df):
    """Creates a loss curve in a Matplotlib figure.

    The median loss is drawn against the left y-axis and the learning rate
    against a secondary (right) y-axis, both as a function of the ``epoch``
    column.


    Parameters
    ----------

    df : :py:class:`pandas.DataFrame`
        A dataframe containing, at least, "epoch", "median-loss" and
        "learning-rate" columns, that will be plotted.


    Returns
    -------

    figure : matplotlib.figure.Figure
        A figure, that may be saved or displayed
    """
    ax1 = df.plot(x="epoch", y="median-loss", grid=True)
    ax1.set_ylabel("Median Loss")
    ax1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
    # Plot on ``ax1`` explicitly (not on pyplot's implicit current axes) and
    # against the same ``epoch`` column as the loss, so both curves share one
    # x-axis even when the dataframe index differs from the epoch numbers.
    # With ``secondary_y=True`` pandas returns the right-hand (twin) axes.
    ax2 = df.plot(
        x="epoch",
        y="learning-rate",
        ax=ax1,
        secondary_y=True,
        legend=True,
        grid=True,
    )
    ax2.set_ylabel("Learning Rate")
    # set last, so the x-label pandas derives from the column name is replaced
    ax1.set_xlabel("Epoch")
    fig = ax1.get_figure()
    fig.tight_layout()
    return fig