Coverage for src/deepdraw/utils/table.py: 100%
18 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5import tabulate
7from .measure import auc
def performance_table(data, fmt):
    """Render a per-dataset performance summary as a text table.

    One row is produced per entry of ``data``, with the columns ``Dataset``,
    ``T``, ``E(F1)``, ``CI(F1)``, ``AUC`` and ``CI(AUC)``.

    Parameters
    ----------

    data : dict
        Maps a label (used verbatim in the ``Dataset`` column) to a
        dictionary with two entries:

        * ``df``: :py:class:`pandas.DataFrame`

          Evaluation results, one row per threshold step.  This function
          reads the columns ``mean_f1_score``, ``lower_f1_score``,
          ``upper_f1_score``, ``mean_recall``, ``mean_precision``,
          ``lower_recall``, ``lower_precision``, ``upper_recall`` and
          ``upper_precision``.  The F1 columns are looked up by index
          *label* (assumes the frame is indexed ``0..N-1`` -- TODO confirm
          against the evaluator).

        * ``threshold``: number

          The threshold assigned to this dataset.  It is multiplied by the
          number of rows in ``df`` and rounded to select the row from which
          the F1 figures are reported; it is also echoed in the ``T``
          column.

    fmt : str
        Passed to :py:func:`tabulate.tabulate` as ``tablefmt``.


    Returns
    -------

    table : str
        The formatted table, with floats rendered to three decimals and
        strings right-aligned.
    """

    column_titles = ["Dataset", "T", "E(F1)", "CI(F1)", "AUC", "CI(AUC)"]

    rows = []
    for label, spec in data.items():
        threshold = spec["threshold"]
        frame = spec["df"]

        # F1 statistics are taken at the "assigned" threshold (chosen a
        # priori, hence less biased).  The position is clamped to the last
        # row so that a threshold of 1.0 does not index past the end.
        n_rows = len(frame)
        position = min(int(round(n_rows * threshold)), n_rows - 1)

        f1_expected = frame.mean_f1_score[position]
        f1_interval = (
            f"{frame.lower_f1_score[position]:.3f}"
            f"-{frame.upper_f1_score[position]:.3f}"
        )

        # ``auc`` applied to (recall, precision): once on the mean columns
        # and once each on the lower/upper bound columns for the interval.
        area = auc(
            frame["mean_recall"].to_numpy(),
            frame["mean_precision"].to_numpy(),
        )
        area_low = auc(
            frame["lower_recall"].to_numpy(),
            frame["lower_precision"].to_numpy(),
        )
        area_high = auc(
            frame["upper_recall"].to_numpy(),
            frame["upper_precision"].to_numpy(),
        )
        area_interval = f"{area_low:.3f}-{area_high:.3f}"

        rows.append(
            [
                label,
                threshold,
                f1_expected,
                f1_interval,
                area,
                area_interval,
            ]
        )

    return tabulate.tabulate(
        rows,
        column_titles,
        tablefmt=fmt,
        floatfmt=".3f",
        stralign="right",
    )