Coverage for src/deepdraw/script/compare.py: 91%
78 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5import os
7import click
8import pandas
9import tabulate
11from clapper.click import verbosity_option
12from clapper.logging import setup
13from tqdm import tqdm
# Module logger named after the top-level package (the part of ``__name__``
# before the first dot), configured through clapper's ``setup()`` with a
# "LEVEL: message" output format.
logger = setup(
    __name__.partition(".")[0],
    format="%(levelname)s: %(message)s",
)
@click.command(
    epilog="""Examples:

\b
    1. Compares system A and B, with their own pre-computed measure files:

       .. code:: sh

          $ deepdraw compare -vv A path/to/A/train.csv B path/to/B/test.csv
""",
)
@click.argument(
    "label_path",
    nargs=-1,
)
@click.option(
    "--output-figure",
    "-f",
    help="Path where to write the output figure (any extension supported by "
    "matplotlib is possible). If not provided, does not produce a figure.",
    required=False,
    default=None,
    type=click.Path(dir_okay=False, file_okay=True),
)
@click.option(
    "--table-format",
    "-T",
    help="The format to use for the comparison table",
    show_default=True,
    required=True,
    default="rst",
    type=click.Choice(tabulate.tabulate_formats),
)
@click.option(
    "--output-table",
    "-u",
    help="Path where to write the output table. If not provided, does not "
    "write a table to file, only to stdout.",
    required=False,
    default=None,
    type=click.Path(dir_okay=False, file_okay=True),
)
@click.option(
    "--threshold",
    "-t",
    help="This number is used to select which F1-score to use for "
    "representing a system performance. If not set, we report the maximum "
    "F1-score in the set, which is equivalent to threshold selection a "
    "posteriori (biased estimator), unless the performance file being "
    "considered already was pre-tuned, and contains a 'threshold_a_priori' "
    "column which we then use to pick a threshold for the dataset. "
    "You can override this behaviour by either setting this value to a "
    "floating-point number in the range [0.0, 1.0], or to a string, naming "
    "one of the systems which will be used to calculate the threshold "
    "leading to the maximum F1-score and then applied to all other sets.",
    default=None,
    show_default=False,
    required=False,
)
@click.option(
    "--plot-limits",
    "-L",
    help="""If set, must be a 4-tuple containing the bounds of the plot for
    the x and y axis respectively (format: [x_low, x_high, y_low,
    y_high]). If not set, use normal bounds ([0, 1, 0, 1]) for the
    performance curve.""",
    # immutable default: click hands the 4 floats back as a tuple anyway
    default=(0.0, 1.0, 0.0, 1.0),
    show_default=True,
    nargs=4,
    type=float,
)
@verbosity_option(
    logger=logger,
)
@click.pass_context
def compare(
    ctx,
    label_path,
    output_figure,
    table_format,
    output_table,
    threshold,
    plot_limits,
    verbose,
    **kwargs,
):
    """Compares multiple systems together."""

    def _validate_threshold(t, dataset):
        """Validate the user threshold selection.

        Parameters
        ----------

        t : :py:class:`str`, Optional
            Raw value of ``--threshold``, or ``None`` if it was not given.

        dataset : dict
            Mapping of system names to measure-file paths.

        Returns
        -------

        threshold : :py:class:`float`, :py:class:`str`, Optional
            ``None`` if ``t`` is ``None``; a float in ``[0.0, 1.0]`` if ``t``
            parses as a number; otherwise ``t`` itself, which is then
            guaranteed to be one of the keys of ``dataset``.

        Raises
        ------

        ValueError
            If a numeric threshold is outside ``[0.0, 1.0]`` (or NaN), or a
            textual threshold does not name one of the provided systems.
        """
        if t is None:
            return t

        try:
            # we try to convert it to float first
            value = float(t)
        except ValueError:
            # it is a bit of text - assert dataset with name is available
            if not isinstance(dataset, dict):
                raise ValueError(
                    "Threshold should be a floating-point number "
                    "if you provide only a single dataset for evaluation"
                ) from None
            if t not in dataset:
                raise ValueError(
                    f"Text thresholds should match dataset names, "
                    f"but {t} is not available among the datasets provided "
                    f"({', '.join(dataset.keys())})"
                ) from None
            return t

        # The range check lives outside the ``try`` on purpose: raising it
        # inside would be caught by the ``except`` above and misreported as an
        # unknown dataset name. The chained comparison also rejects NaN.
        if not 0.0 <= value <= 1.0:
            raise ValueError("Float thresholds must be within range [0.0, 1.0]")
        return value

    def _load(data, threshold=None):
        """Load the measure files of all systems and pick their thresholds.

        Parameters
        ----------

        data : dict
            A dict in which keys are the names of the systems and the values are
            paths to ``measures.csv`` style files.

        threshold : :py:class:`float`, :py:class:`str`, Optional
            A value indicating which threshold to choose for selecting a score.
            If set to ``None``, then use the threshold flagged by a
            ``threshold_a_priori`` column if the measures file has one,
            otherwise the threshold with the maximum F1-score on that file.
            If set to a floating-point value, then use the score that is
            obtained on that particular threshold. If set to a string, it should
            match one of the keys in ``data``. It then first calculates the
            threshold reaching the maximum score on that particular dataset and
            then applies that threshold to all other sets. Obs: If the task
            is segmentation, the score used is the F1-Score.


        Returns
        -------

        data : dict
            A dict in which keys are the names of the systems and the values are
            dictionaries that contain two keys:

            * ``df``: A :py:class:`pandas.DataFrame` with the measures data loaded
              to
            * ``threshold``: A threshold to be used for summarization, depending on
              the ``threshold`` parameter set on the input
        """

        col_name = "mean_f1_score"
        score_name = "F1-score"

        if isinstance(threshold, str):
            logger.info(
                f"Calculating threshold from maximum {score_name} at "
                f"'{threshold}' dataset..."
            )
            measures_path = data[threshold]
            df = pandas.read_csv(measures_path)
            use_threshold = df.threshold[df[col_name].idxmax()]
            logger.info(f"Dataset '*': threshold = {use_threshold:.3f}")

        elif isinstance(threshold, float):
            use_threshold = threshold
            logger.info(f"Dataset '*': threshold = {use_threshold:.3f}")

        # loads all data
        retval = {}
        for name, measures_path in tqdm(data.items(), desc="sample"):
            logger.info(f"Loading measures from {measures_path}...")
            df = pandas.read_csv(measures_path)

            # without a user-provided threshold, each system gets its own
            if threshold is None:
                if "threshold_a_priori" in df:
                    use_threshold = df.threshold[df.threshold_a_priori.idxmax()]
                    logger.info(
                        f"Dataset '{name}': threshold (a priori) = "
                        f"{use_threshold:.3f}"
                    )
                else:
                    use_threshold = df.threshold[df[col_name].idxmax()]
                    logger.info(
                        f"Dataset '{name}': threshold (a posteriori) = "
                        f"{use_threshold:.3f}"
                    )

            retval[name] = dict(df=df, threshold=use_threshold)

        return retval

    # hack to get a dictionary from arguments passed to input: positional
    # arguments arrive as a flat sequence of alternating ``name path`` items
    if not label_path or len(label_path) % 2 != 0:
        raise click.ClickException(
            "Input label-paths should be doubles"
            " composed of name-path entries"
        )
    names = label_path[::2]
    # dict(zip(...)) would silently keep only the last path of a repeated name
    duplicates = sorted({n for n in names if names.count(n) > 1})
    if duplicates:
        raise click.ClickException(
            "System names must be unique, but got repeated entries: "
            f"{', '.join(duplicates)}"
        )
    data = dict(zip(names, label_path[1::2]))

    threshold = _validate_threshold(threshold, data)

    # load all data measures
    data = _load(data, threshold=threshold)

    from ..utils.plot import precision_recall_f1iso
    from ..utils.table import performance_table

    if output_figure is not None:
        output_figure = os.path.realpath(output_figure)
        logger.info(f"Creating and saving plot at {output_figure}...")
        os.makedirs(os.path.dirname(output_figure), exist_ok=True)
        fig = precision_recall_f1iso(data, limits=plot_limits)
        fig.savefig(output_figure)
        # NOTE(review): clear() empties the figure but does not close it; if
        # precision_recall_f1iso creates it through pyplot, consider
        # matplotlib.pyplot.close(fig) to release it -- confirm.
        fig.clear()

    logger.info("Tabulating performance summary...")
    table = performance_table(data, table_format)
    click.echo(table)
    if output_table is not None:
        output_table = os.path.realpath(output_table)
        logger.info(f"Saving table at {output_table}...")
        os.makedirs(os.path.dirname(output_table), exist_ok=True)
        # explicit encoding: some tabulate formats emit non-ASCII box drawing
        with open(output_table, "w", encoding="utf-8") as f:
            f.write(table)