#!/usr/bin/env python
# coding=utf-8
import logging
import os

import click
import pandas
import tabulate

from bob.extension.scripts.click_helper import verbosity_option

from ..utils.plot import precision_recall_f1iso
from ..utils.table import performance_table
# module-level logger, named after this module
logger = logging.getLogger(__name__)
19def _validate_threshold(t, dataset):
20 """Validates the user threshold selection. Returns parsed threshold."""
22 if t is None:
23 return t
25 try:
26 # we try to convert it to float first
27 t = float(t)
28 if t < 0.0 or t > 1.0:
29 raise ValueError("Float thresholds must be within range [0.0, 1.0]")
30 except ValueError:
31 # it is a bit of text - assert dataset with name is available
32 if not isinstance(dataset, dict):
33 raise ValueError(
34 "Threshold should be a floating-point number "
35 "if your provide only a single dataset for evaluation"
36 )
37 if t not in dataset:
38 raise ValueError(
39 f"Text thresholds should match dataset names, "
40 f"but {t} is not available among the datasets provided ("
41 f"({', '.join(dataset.keys())})"
42 )
44 return t
def _load(data, threshold=None):
    """Loads the measures of all evaluated models for comparison.

    Parameters
    ----------

    data : dict
        A dict in which keys are the names of the systems and the values are
        paths to ``measures.csv`` style files.

    threshold : :py:class:`float`, :py:class:`str`, Optional
        Selects which threshold to use for reporting an F1-score.  With
        ``None``, each measures file uses its own maximum F1-score (or the
        ``threshold_a_priori`` column, when available).  A float is applied
        as-is to every set.  A string must name a key of ``data``: the
        threshold maximizing the F1-score on that set is then applied to all
        other sets.

    Returns
    -------

    data : dict
        A dict in which keys are the names of the systems and the values are
        dictionaries with two keys:

        * ``df``: a :py:class:`pandas.DataFrame` with the loaded measures
        * ``threshold``: the threshold chosen for summarization, according
          to the ``threshold`` parameter above
    """

    def _a_posteriori(measures):
        # threshold at the row reaching the best mean F1-score
        return measures.threshold[measures.mean_f1_score.idxmax()]

    fixed = None
    if isinstance(threshold, str):
        # a single threshold, computed on the nominated set, is applied
        # to every other set
        logger.info(
            f"Calculating threshold from maximum F1-score at "
            f"'{threshold}' dataset..."
        )
        fixed = _a_posteriori(pandas.read_csv(data[threshold]))
        logger.info(f"Dataset '*': threshold = {fixed:.3f}'")

    elif isinstance(threshold, float):
        fixed = threshold
        logger.info(f"Dataset '*': threshold = {fixed:.3f}'")

    # loads all data
    loaded = {}
    for name, path in data.items():

        logger.info(f"Loading measures from {path}...")
        measures = pandas.read_csv(path)

        if threshold is None:
            # per-set selection: prefer a pre-tuned a-priori threshold
            # column when present, otherwise fall back to the (biased)
            # a-posteriori maximum-F1 choice
            if "threshold_a_priori" in measures:
                chosen = measures.threshold[
                    measures.threshold_a_priori.idxmax()
                ]
                logger.info(
                    f"Dataset '{name}': threshold (a priori) = "
                    f"{chosen:.3f}'"
                )
            else:
                chosen = _a_posteriori(measures)
                logger.info(
                    f"Dataset '{name}': threshold (a posteriori) = "
                    f"{chosen:.3f}'"
                )
        else:
            chosen = fixed

        loaded[name] = {"df": measures, "threshold": chosen}

    return loaded
@click.command(
    epilog="""Examples:

\b
    1. Compares system A and B, with their own pre-computed measure files:
\b
       $ bob binseg compare -vv A path/to/A/train.csv B path/to/B/test.csv
""",
)
@click.argument(
    "label_path",
    nargs=-1,
)
@click.option(
    "--output-figure",
    "-f",
    help="Path where write the output figure (any extension supported by "
    "matplotlib is possible). If not provided, does not produce a figure.",
    required=False,
    default=None,
    type=click.Path(dir_okay=False, file_okay=True),
)
@click.option(
    "--table-format",
    "-T",
    help="The format to use for the comparison table",
    show_default=True,
    required=True,
    default="rst",
    type=click.Choice(tabulate.tabulate_formats),
)
@click.option(
    "--output-table",
    "-u",
    help="Path where write the output table. If not provided, does not "
    "write a table to file, only to stdout.",
    required=False,
    default=None,
    type=click.Path(dir_okay=False, file_okay=True),
)
@click.option(
    "--threshold",
    "-t",
    help="This number is used to select which F1-score to use for "
    "representing a system performance. If not set, we report the maximum "
    "F1-score in the set, which is equivalent to threshold selection a "
    "posteriori (biased estimator), unless the performance file being "
    "considered already was pre-tuned, and contains a 'threshold_a_priori' "
    "column which we then use to pick a threshold for the dataset. "
    "You can override this behaviour by either setting this value to a "
    "floating-point number in the range [0.0, 1.0], or to a string, naming "
    "one of the systems which will be used to calculate the threshold "
    "leading to the maximum F1-score and then applied to all other sets.",
    default=None,
    show_default=False,
    required=False,
)
@verbosity_option()
def compare(
    label_path, output_figure, table_format, output_table, threshold, **kwargs
):
    """Compares multiple systems together.

    Positional arguments form ``(name, path)`` pairs: the name of each
    system followed by the path to its ``measures.csv`` style file.
    Produces a comparison table on stdout and, optionally, a
    precision-recall figure and/or a table file on disk.
    """

    # hack to get a dictionary from arguments passed to input
    if len(label_path) % 2 != 0:
        raise click.ClickException(
            "Input label-paths should be doubles"
            " composed of name-path entries"
        )
    data = dict(zip(label_path[::2], label_path[1::2]))

    threshold = _validate_threshold(threshold, data)

    # load all data measures
    data = _load(data, threshold=threshold)

    if output_figure is not None:
        # realpath first, so dirname is always a usable absolute directory
        output_figure = os.path.realpath(output_figure)
        logger.info(f"Creating and saving plot at {output_figure}...")
        os.makedirs(os.path.dirname(output_figure), exist_ok=True)
        fig = precision_recall_f1iso(data, credible=True)
        fig.savefig(output_figure)

    logger.info("Tabulating performance summary...")
    table = performance_table(data, table_format)
    click.echo(table)  # always echoed to stdout, even when saving to file
    if output_table is not None:
        output_table = os.path.realpath(output_table)
        logger.info(f"Saving table at {output_table}...")
        os.makedirs(os.path.dirname(output_table), exist_ok=True)
        with open(output_table, "wt") as f:
            f.write(table)