1#!/usr/bin/env python
2# coding=utf-8
3
4import os
5import click
6
7from bob.extension.scripts.click_helper import (
8 verbosity_option,
9 ConfigCommand,
10 ResourceOption,
11)
12
13from ..engine.evaluator import run
14
15import logging
16
17logger = logging.getLogger(__name__)
18
19
20def _validate_threshold(t, dataset):
21 """Validates the user threshold selection. Returns parsed threshold."""
22
23 if t is None:
24 return 0.5
25
26 try:
27 # we try to convert it to float first
28 t = float(t)
29 if t < 0.0 or t > 1.0:
30 raise ValueError("Float thresholds must be within range [0.0, 1.0]")
31 except ValueError:
32 # it is a bit of text - assert dataset with name is available
33 if not isinstance(dataset, dict):
34 raise ValueError(
35 "Threshold should be a floating-point number "
36 "if your provide only a single dataset for evaluation"
37 )
38 if t not in dataset:
39 raise ValueError(
40 f"Text thresholds should match dataset names, "
41 f"but {t} is not available among the datasets provided ("
42 f"({', '.join(dataset.keys())})"
43 )
44
45 return t
46
47
@click.command(
    entry_point_group="bob.med.tb.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Runs evaluation on an existing dataset configuration:
\b
       $ bob tb evaluate -vv montgomery --predictions-folder=path/to/predictions --output-folder=path/to/results
""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the analysis result (created if does not exist)",
    required=True,
    default="results",
    type=click.Path(),
    cls=ResourceOption,
)
@click.option(
    "--predictions-folder",
    "-p",
    help="Path where predictions are currently stored",
    required=True,
    type=click.Path(exists=True, file_okay=False, dir_okay=True),
    cls=ResourceOption,
)
@click.option(
    "--dataset",
    "-d",
    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
    "to be used for evaluation purposes, possibly including all pre-processing "
    "pipelines required or, optionally, a dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances. All keys that do not start "
    "with an underscore (_) will be processed.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--threshold",
    "-t",
    help="This number is used to define positives and negatives from "
    "probability maps, and report F1-scores (a priori). It "
    "should either come from the training set or a separate validation set "
    "to avoid biasing the analysis. Optionally, if you provide a multi-set "
    "dataset as input, this may also be the name of an existing set from "
    "which the threshold will be estimated (highest F1-score) and then "
    "applied to the subsequent sets. This number is also used to print "
    "the test set F1-score a priori performance",
    default=None,
    show_default=False,
    required=False,
    cls=ResourceOption,
)
@click.option(
    "--steps",
    "-S",
    help="This number is used to define the number of threshold steps to "
    "consider when evaluating the highest possible F1-score on test data.",
    default=1000,
    show_default=True,
    required=True,
    cls=ResourceOption,
)
@verbosity_option(cls=ResourceOption)
def evaluate(
    output_folder,
    predictions_folder,
    dataset,
    threshold,
    steps,
    **kwargs,
):
    """Evaluates a CNN on a tuberculosis prediction task.

    Note: batch size of 1 is required on the predictions.
    """

    # normalizes the threshold: returns 0.5 (default), a float in
    # [0.0, 1.0], or the name of a reference set in ``dataset``
    threshold = _validate_threshold(threshold, dataset)

    # wrap a single dataset so the loop below treats both cases uniformly
    if not isinstance(dataset, dict):
        dataset = {"test": dataset}

    if isinstance(threshold, str):
        # first run evaluation for reference dataset, to estimate the
        # thresholds that will be applied to all subsequent sets
        logger.info(f"Evaluating threshold on '{threshold}' set")
        f1_threshold, eer_threshold = run(
            dataset[threshold], threshold, predictions_folder, steps=steps
        )
        if f1_threshold is not None and eer_threshold is not None:
            logger.info(f"Set --f1_threshold={f1_threshold:.5f}")
            logger.info(f"Set --eer_threshold={eer_threshold:.5f}")
    else:
        # a numeric threshold was provided: use it directly as the a-priori
        # F1 and EER thresholds for every set (previously these variables
        # were left unbound on this path, causing a NameError below)
        f1_threshold = eer_threshold = threshold

    # now run evaluation on every set, skipping "private" ones (names
    # starting with an underscore)
    for k, v in dataset.items():
        if k.startswith("_"):
            logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
            continue
        logger.info(f"Analyzing '{k}' set...")
        run(
            v,
            k,
            predictions_folder,
            output_folder,
            f1_thresh=f1_threshold,
            eer_thresh=eer_threshold,
            steps=steps,
        )