1#!/usr/bin/env python
2# coding=utf-8
3
4
5import tabulate
6
7from .measure import auc
8
9
def performance_table(data, fmt):
    """Tables result comparison in a given format


    Parameters
    ----------

    data : dict
        A dictionary in which keys are strings defining table row labels and
        values are dictionaries with two entries:

        * ``df``: :py:class:`pandas.DataFrame`

          A dataframe that is produced by our evaluator engine, indexed by
          integer "thresholds", containing the following columns:
          ``threshold``, ``tp``, ``fp``, ``tn``, ``fn``, ``mean_precision``,
          ``mode_precision``, ``lower_precision``, ``upper_precision``,
          ``mean_recall``, ``mode_recall``, ``lower_recall``, ``upper_recall``,
          ``mean_specificity``, ``mode_specificity``, ``lower_specificity``,
          ``upper_specificity``, ``mean_accuracy``, ``mode_accuracy``,
          ``lower_accuracy``, ``upper_accuracy``, ``mean_jaccard``,
          ``mode_jaccard``, ``lower_jaccard``, ``upper_jaccard``,
          ``mean_f1_score``, ``mode_f1_score``, ``lower_f1_score``,
          ``upper_f1_score``, ``frequentist_precision``,
          ``frequentist_recall``, ``frequentist_specificity``,
          ``frequentist_accuracy``, ``frequentist_jaccard``,
          ``frequentist_f1_score``.

        * ``threshold``: :py:class:`float`

          The (a priori assigned) threshold at which the F1-score statistics
          are reported, expected to lie in ``[0, 1]``.  It is mapped to a row
          *position* of ``df`` by scaling it by the number of rows and
          rounding; values that fall outside the table are clamped to the
          first/last row.


    fmt : str
        One of the formats supported by tabulate.


    Returns
    -------

    table : str
        A table in a specific format, with one row per entry in ``data``
        showing the threshold, the mean F1-score and its interval at that
        threshold, and the area under the precision-recall curve with its
        interval.

    """

    headers = [
        "Dataset",
        "T",
        "E(F1)",
        "CI(F1)",
        "AUC",
        "CI(AUC)",
    ]

    table = []
    for name, entry in data.items():
        df = entry["df"]
        threshold = entry["threshold"]

        # statistics based on the "assigned" threshold (a priori, less biased).
        # The threshold is converted to a row position and clamped on both
        # ends, so neither a threshold above 1 nor below 0 indexes outside the
        # table.
        index = int(round(len(df) * threshold))
        index = max(0, min(index, len(df) - 1))

        # positional access (.iloc): ``index`` is derived from the number of
        # rows, so it must not be interpreted as an index *label*
        f1_mean = df["mean_f1_score"].iloc[index]
        f1_interval = (
            f"{df['lower_f1_score'].iloc[index]:.3f}-"
            f"{df['upper_f1_score'].iloc[index]:.3f}"
        )

        # AUC of the PR curve, plus the areas under the lower and upper
        # precision/recall curves used as its interval
        auc_mean = auc(
            df["mean_recall"].to_numpy(),
            df["mean_precision"].to_numpy(),
        )
        auc_lower = auc(
            df["lower_recall"].to_numpy(),
            df["lower_precision"].to_numpy(),
        )
        auc_upper = auc(
            df["upper_recall"].to_numpy(),
            df["upper_precision"].to_numpy(),
        )

        table.append(
            [
                name,
                threshold,
                f1_mean,
                f1_interval,
                auc_mean,
                f"{auc_lower:.3f}-{auc_upper:.3f}",
            ]
        )

    return tabulate.tabulate(
        table, headers, tablefmt=fmt, floatfmt=".3f", stralign="right"
    )
103
104
def performance_table_detection(data, fmt):
    """Tables detection result comparison in a given format


    Parameters
    ----------

    data : dict
        A dictionary in which keys are strings defining table row labels and
        values are dictionaries with two entries:

        * ``df``: :py:class:`pandas.DataFrame`

          A dataframe that is produced by our evaluator engine, indexed by
          integer "thresholds", containing (at least) the following columns:
          ``mean_iou``, ``mean_intersection``,
          ``mean_intersection_extension_5%``,
          ``mean_intersection_extension_10%``.

        * ``threshold``: :py:class:`float`

          The (a priori assigned) threshold at which the statistics are
          reported, expected to lie in ``[0, 1]``.  It is mapped to a row
          *position* of ``df`` by scaling it by the number of rows and
          rounding; values that fall outside the table are clamped to the
          first/last row.


    fmt : str
        One of the formats supported by tabulate.


    Returns
    -------

    table : str
        A table in a specific format, with one row per entry in ``data``
        showing the threshold and the mean IoU / intersection metrics at
        that threshold.

    """

    headers = [
        "Dataset",
        "T",
        "E(IoU)",
        "E(Intersection)",
        "E(Intersection_Extension_5%)",
        "E(Intersection_Extension_10%)",
    ]

    # dataframe columns reported for each dataset, in header order (after
    # the "Dataset" and "T" columns)
    columns = (
        "mean_iou",
        "mean_intersection",
        "mean_intersection_extension_5%",
        "mean_intersection_extension_10%",
    )

    table = []
    for name, entry in data.items():
        df = entry["df"]
        threshold = entry["threshold"]

        # convert the threshold to a row position, clamped on both ends so
        # neither a threshold above 1 nor below 0 indexes outside the table
        index = int(round(len(df) * threshold))
        index = max(0, min(index, len(df) - 1))

        # positional access (.iloc): ``index`` is derived from the number of
        # rows, so it must not be interpreted as an index *label*
        table.append(
            [name, threshold] + [df[column].iloc[index] for column in columns]
        )

    return tabulate.tabulate(
        table, headers, tablefmt=fmt, floatfmt=".3f", stralign="right"
    )