1#!/usr/bin/env python
2# coding=utf-8
3
4import logging
5import os
6
7import click
8import pandas
9
10from tqdm import tqdm
11
12logger = logging.getLogger(__name__)
13
14
def base_compare(
    label_path,
    output_figure,
    output_table,
    threshold,
    plot_limits,
    detection,
    verbose,
    table_format="rst",
    **kwargs,
):
    """Compare multiple systems together.

    Parameters
    ----------

    label_path : tuple
        Flat sequence of alternating ``name, path`` entries; each ``path``
        points to a ``measures.csv`` style file for the system ``name``.

    output_figure : str, None
        If set, path where the comparison figure is saved (only produced
        for the segmentation task).

    output_table : str, None
        If set, path where the tabulated performance summary is saved.

    threshold : float, str, None
        Threshold selection strategy — see the nested ``_load`` helper for
        the full interpretation.

    plot_limits : object
        Passed through as ``limits`` to the plotting utility.

    detection : bool
        If ``True``, summarize with detection measures (IoU); otherwise
        with segmentation measures (F1-score).

    verbose : int
        Verbosity flag (not used inside this function).

    table_format : str
        Format identifier forwarded to the table generator.

    **kwargs
        Ignored (accepted for CLI option pass-through).
    """

    def _validate_threshold(t, dataset):
        """Validate the user threshold selection. Returns parsed threshold.

        Accepts ``None`` (automatic selection), a float in ``[0.0, 1.0]``,
        or a string naming one of the datasets in ``dataset``.
        """
        if t is None:
            return t

        try:
            # we try to convert it to float first
            t = float(t)
        except ValueError:
            # it is a bit of text - assert dataset with name is available
            if not isinstance(dataset, dict):
                raise ValueError(
                    "Threshold should be a floating-point number "
                    "if you provide only a single dataset for evaluation"
                )
            if t not in dataset:
                raise ValueError(
                    f"Text thresholds should match dataset names, "
                    f"but {t} is not available among the datasets provided "
                    f"({', '.join(dataset.keys())})"
                )
        else:
            # range check must live OUTSIDE the try block: previously the
            # ValueError raised here was swallowed by the except clause
            # above and misreported as a dataset-name mismatch
            if t < 0.0 or t > 1.0:
                raise ValueError(
                    "Float thresholds must be within range [0.0, 1.0]"
                )

        return t

    def _load(data, detection, threshold=None):
        """Load all measures files, resolving the threshold for each.

        Parameters
        ----------

        data : dict
            A dict in which keys are the names of the systems and the values are
            paths to ``measures.csv`` style files.

        detection : bool
            If ``True``, use the ``mean_iou`` column (IoU-score) to pick
            thresholds; otherwise use ``mean_f1_score`` (F1-score).

        threshold : :py:class:`float`, :py:class:`str`, Optional
            A value indicating which threshold to choose for selecting a score.
            If set to ``None``, then use the maximum F1-score on that measures file.
            If set to a floating-point value, then use the score that is
            obtained on that particular threshold. If set to a string, it should
            match one of the keys in ``data``. It then first calculate the
            threshold reaching the maximum score on that particular dataset and
            then applies that threshold to all other sets. Obs: If the task
            is segmentation, the score used is the F1-Score; for the detection
            task the score used is the Intersection Over Union (IoU).


        Returns
        -------

        data : dict
            A dict in which keys are the names of the systems and the values are
            dictionaries that contain two keys:

            * ``df``: A :py:class:`pandas.DataFrame` with the measures data loaded
              to
            * ``threshold``: A threshold to be used for summarization, depending on
              the ``threshold`` parameter set on the input

        """
        if detection:
            col_name = "mean_iou"
            score_name = "IoU-score"

        else:
            col_name = "mean_f1_score"
            score_name = "F1-score"

        if isinstance(threshold, str):
            # the reference dataset fixes one threshold applied to all sets
            logger.info(
                f"Calculating threshold from maximum {score_name} at "
                f"'{threshold}' dataset..."
            )
            measures_path = data[threshold]
            df = pandas.read_csv(measures_path)
            use_threshold = df.threshold[df[col_name].idxmax()]
            logger.info(f"Dataset '*': threshold = {use_threshold:.3f}")

        elif isinstance(threshold, float):
            use_threshold = threshold
            logger.info(f"Dataset '*': threshold = {use_threshold:.3f}")

        # loads all data
        retval = {}
        for name, measures_path in tqdm(data.items(), desc="sample"):

            logger.info(f"Loading measures from {measures_path}...")
            df = pandas.read_csv(measures_path)

            if threshold is None:
                # per-dataset threshold: prefer an a-priori threshold if
                # the measures file recorded one, else pick the threshold
                # maximising the score column
                if "threshold_a_priori" in df:
                    use_threshold = df.threshold[df.threshold_a_priori.idxmax()]
                    logger.info(
                        f"Dataset '{name}': threshold (a priori) = "
                        f"{use_threshold:.3f}"
                    )
                else:
                    use_threshold = df.threshold[df[col_name].idxmax()]
                    logger.info(
                        f"Dataset '{name}': threshold (a posteriori) = "
                        f"{use_threshold:.3f}"
                    )

            retval[name] = dict(df=df, threshold=use_threshold)

        return retval

    # hack to get a dictionary from arguments passed to input
    if len(label_path) % 2 != 0:
        raise click.ClickException(
            "Input label-paths should be doubles"
            " composed of name-path entries"
        )
    data = dict(zip(label_path[::2], label_path[1::2]))

    threshold = _validate_threshold(threshold, data)

    # load all data measures
    data = _load(data, detection=detection, threshold=threshold)

    if detection:
        from ..utils.table import (
            performance_table_detection as performance_table,
        )

        if output_figure is not None:
            # no plotting utility is imported for the detection task;
            # previously this path crashed with a NameError
            raise click.ClickException(
                "Plotting (--output-figure) is not supported for the "
                "detection task"
            )

    else:
        from ..utils.plot import precision_recall_f1iso
        from ..utils.table import performance_table

        if output_figure is not None:
            output_figure = os.path.realpath(output_figure)
            logger.info(f"Creating and saving plot at {output_figure}...")
            os.makedirs(os.path.dirname(output_figure), exist_ok=True)
            fig = precision_recall_f1iso(data, limits=plot_limits)
            fig.savefig(output_figure)
            fig.clear()

    logger.info("Tabulating performance summary...")
    table = performance_table(data, table_format)
    click.echo(table)
    if output_table is not None:
        output_table = os.path.realpath(output_table)
        logger.info(f"Saving table at {output_table}...")
        os.makedirs(os.path.dirname(output_table), exist_ok=True)
        with open(output_table, "wt") as f:
            f.write(table)