1#!/usr/bin/env python
2# coding=utf-8
3
4import os
5import click
6
7from bob.extension.scripts.click_helper import (
8 verbosity_option,
9 AliasedGroup,
10)
11
12import torch
13import re
14import pandas
15import tabulate
16from matplotlib.backends.backend_pdf import PdfPages
17
18from ..utils.plot import precision_recall_f1iso
19from ..utils.plot import roc_curve
20from ..utils.table import performance_table
21
22import logging
23logger = logging.getLogger(__name__)
24
25
26def _validate_threshold(t, dataset):
27 """Validates the user threshold selection. Returns parsed threshold."""
28
29 if t is None:
30 return t
31
32 # we try to convert it to float first
33 t = float(t)
34 if t < 0.0 or t > 1.0:
35 raise ValueError("Thresholds must be within range [0.0, 1.0]")
36
37 return t
38
39
40def _load(data, threshold):
41 """Plots comparison chart of all evaluated models
42
43 Parameters
44 ----------
45
46 data : dict
47 A dict in which keys are the names of the systems and the values are
48 paths to ``predictions.csv`` style files.
49
50 threshold : :py:class:`float`
51 A threshold for the final classification.
52
53
54 Returns
55 -------
56
57 data : dict
58 A dict in which keys are the names of the systems and the values are
59 dictionaries that contain two keys:
60
61 * ``df``: A :py:class:`pandas.DataFrame` with the predictions data
62 loaded to
63 * ``threshold``: The ``threshold`` parameter set on the input
64
65 """
66
67 use_threshold = threshold
68 logger.info(f"Dataset '*': threshold = {use_threshold:.3f}'")
69
70 # loads all data
71 retval = {}
72 for name, predictions_path in data.items():
73
74 # Load predictions
75 logger.info(f"Loading predictions from {predictions_path}...")
76 pred_data = pandas.read_csv(predictions_path)
77 pred = torch.Tensor([eval(re.sub(' +', ' ', x.replace('\n', '')).replace(' ', ',')) if isinstance(x, str) else x for x in pred_data['likelihood'].values]).double().flatten()
78 gt = torch.Tensor([eval(re.sub(' +', ' ', x.replace('\n', '')).replace(' ', ',')) if isinstance(x, str) else x for x in pred_data['ground_truth'].values]).double().flatten()
79
80 pred_data['likelihood'] = pred
81 pred_data['ground_truth'] = gt
82
83 retval[name] = dict(df=pred_data, threshold=use_threshold)
84
85 return retval
86
87
@click.command(
    epilog="""Examples:

\b
 1. Compares system A and B, with their own predictions files:
\b
    $ bob tb compare -vv A path/to/A/predictions.csv B path/to/B/predictions.csv
""",
)
@click.argument(
    'label_path',
    nargs=-1,
)
@click.option(
    "--output-figure",
    "-f",
    help="Path where write the output figure (any extension supported by "
    "matplotlib is possible). If not provided, does not produce a figure.",
    required=False,
    default=None,
    type=click.Path(dir_okay=False, file_okay=True),
)
@click.option(
    "--table-format",
    "-T",
    help="The format to use for the comparison table",
    show_default=True,
    required=True,
    default="rst",
    type=click.Choice(tabulate.tabulate_formats),
)
@click.option(
    "--output-table",
    "-u",
    # bugfix: help text previously read "does not write write a table"
    help="Path where write the output table. If not provided, does not "
    "write a table to file, only to stdout.",
    required=False,
    default=None,
    type=click.Path(dir_okay=False, file_okay=True),
)
@click.option(
    "--threshold",
    "-t",
    help="This number is used to separate positive and negative cases "
    "by thresholding their score.",
    default=None,
    show_default=False,
    required=False,
)
@verbosity_option()
def compare(label_path, output_figure, table_format, output_table,
        threshold, **kwargs):
    """Compares multiple systems together"""

    # hack to get a dictionary from arguments passed to input: LABEL_PATH is
    # a flat sequence of alternating (name, predictions-path) entries
    if len(label_path) % 2 != 0:
        raise click.ClickException("Input label-paths should be doubles"
            " composed of name-path entries")
    data = dict(zip(label_path[::2], label_path[1::2]))

    # validates the threshold (may be None) before any heavy loading
    threshold = _validate_threshold(threshold, data)

    # load all data measures
    data = _load(data, threshold=threshold)

    if output_figure is not None:
        output_figure = os.path.realpath(output_figure)
        logger.info(f"Creating and saving plot at {output_figure}...")
        # realpath above guarantees dirname() is non-empty
        os.makedirs(os.path.dirname(output_figure), exist_ok=True)
        # one multi-page PDF: precision-recall first, then ROC
        pdf = PdfPages(output_figure)
        pdf.savefig(precision_recall_f1iso(data))
        pdf.savefig(roc_curve(data))
        pdf.close()

    logger.info("Tabulating performance summary...")
    table = performance_table(data, table_format)
    # the table always goes to stdout; optionally also to a file
    click.echo(table)
    if output_table is not None:
        output_table = os.path.realpath(output_table)
        logger.info(f"Saving table at {output_table}...")
        os.makedirs(os.path.dirname(output_table), exist_ok=True)
        with open(output_table, "wt") as f:
            f.write(table)