Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4import logging 

5import os 

6 

7import click 

8import pandas 

9import tabulate 

10 

11from bob.extension.scripts.click_helper import verbosity_option 

12 

13from ..utils.plot import precision_recall_f1iso 

14from ..utils.table import performance_table 

15 

16logger = logging.getLogger(__name__) 

17 

18 

19def _validate_threshold(t, dataset): 

20 """Validates the user threshold selection. Returns parsed threshold.""" 

21 

22 if t is None: 

23 return t 

24 

25 try: 

26 # we try to convert it to float first 

27 t = float(t) 

28 if t < 0.0 or t > 1.0: 

29 raise ValueError("Float thresholds must be within range [0.0, 1.0]") 

30 except ValueError: 

31 # it is a bit of text - assert dataset with name is available 

32 if not isinstance(dataset, dict): 

33 raise ValueError( 

34 "Threshold should be a floating-point number " 

35 "if your provide only a single dataset for evaluation" 

36 ) 

37 if t not in dataset: 

38 raise ValueError( 

39 f"Text thresholds should match dataset names, " 

40 f"but {t} is not available among the datasets provided (" 

41 f"({', '.join(dataset.keys())})" 

42 ) 

43 

44 return t 

45 

46 

def _threshold_at_max_f1(df):
    """Returns the ``threshold`` value of the row with maximal ``mean_f1_score``"""
    return df.threshold[df.mean_f1_score.idxmax()]


def _load(data, threshold=None):
    """Loads the measures of all evaluated systems and selects a threshold

    Parameters
    ----------

    data : dict
        A dict in which keys are the names of the systems and the values are
        paths to ``measures.csv`` style files.

    threshold : :py:class:`float`, :py:class:`str`, Optional
        A value indicating which threshold to choose for selecting a "F1-score".
        If set to ``None``, then each measures file picks its own threshold:
        if it has a ``threshold_a_priori`` column, the threshold at the
        maximum of that column is used; otherwise the threshold at the
        maximum ``mean_f1_score`` is used.
        If set to a number, then that threshold is used for all sets.
        If set to a string, it should match one of the keys in ``data``.  The
        threshold reaching the maximum F1-score on that particular dataset is
        calculated first and then applied to all other sets.


    Returns
    -------

    data : dict
        A dict in which keys are the names of the systems and the values are
        dictionaries that contain two keys:

        * ``df``: A :py:class:`pandas.DataFrame` with the measures data loaded
          to
        * ``threshold``: A threshold to be used for summarization, depending on
          the ``threshold`` parameter set on the input

    """

    # measures already read while computing a shared threshold, keyed by name,
    # so the reference file is not parsed twice
    loaded = {}

    if isinstance(threshold, str):
        logger.info(
            f"Calculating threshold from maximum F1-score at "
            f"'{threshold}' dataset..."
        )
        measures_path = data[threshold]
        logger.info(f"Loading measures from {measures_path}...")
        loaded[threshold] = pandas.read_csv(measures_path)
        use_threshold = _threshold_at_max_f1(loaded[threshold])
        logger.info(f"Dataset '*': threshold = {use_threshold:.3f}")

    elif threshold is not None:
        # any real number (float, int, numpy scalar) is a fixed threshold;
        # previously non-float numbers left ``use_threshold`` unbound
        use_threshold = float(threshold)
        logger.info(f"Dataset '*': threshold = {use_threshold:.3f}")

    # loads all data
    retval = {}
    for name, measures_path in data.items():

        if name in loaded:
            df = loaded[name]
        else:
            logger.info(f"Loading measures from {measures_path}...")
            df = pandas.read_csv(measures_path)

        if threshold is None:

            if "threshold_a_priori" in df:
                use_threshold = df.threshold[df.threshold_a_priori.idxmax()]
                logger.info(
                    f"Dataset '{name}': threshold (a priori) = "
                    f"{use_threshold:.3f}"
                )
            else:
                use_threshold = _threshold_at_max_f1(df)
                logger.info(
                    f"Dataset '{name}': threshold (a posteriori) = "
                    f"{use_threshold:.3f}"
                )

        retval[name] = dict(df=df, threshold=use_threshold)

    return retval

120 

121 

@click.command(
    epilog="""Examples:

\b
    1. Compares system A and B, with their own pre-computed measure files:
\b
       $ bob binseg compare -vv A path/to/A/train.csv B path/to/B/test.csv
""",
)
@click.argument(
    "label_path",
    nargs=-1,
)
@click.option(
    "--output-figure",
    "-f",
    help="Path where to write the output figure (any extension supported by "
    "matplotlib is possible). If not provided, does not produce a figure.",
    required=False,
    default=None,
    type=click.Path(dir_okay=False, file_okay=True),
)
@click.option(
    "--table-format",
    "-T",
    help="The format to use for the comparison table",
    show_default=True,
    required=True,
    default="rst",
    type=click.Choice(tabulate.tabulate_formats),
)
@click.option(
    "--output-table",
    "-u",
    help="Path where to write the output table. If not provided, does not "
    "write a table to file, only to stdout.",
    required=False,
    default=None,
    type=click.Path(dir_okay=False, file_okay=True),
)
@click.option(
    "--threshold",
    "-t",
    help="This number is used to select which F1-score to use for "
    "representing a system performance. If not set, we report the maximum "
    "F1-score in the set, which is equivalent to threshold selection a "
    "posteriori (biased estimator), unless the performance file being "
    "considered already was pre-tuned, and contains a 'threshold_a_priori' "
    "column which we then use to pick a threshold for the dataset. "
    "You can override this behaviour by either setting this value to a "
    "floating-point number in the range [0.0, 1.0], or to a string, naming "
    "one of the systems which will be used to calculate the threshold "
    "leading to the maximum F1-score and then applied to all other sets.",
    default=None,
    show_default=False,
    required=False,
)
@verbosity_option()
def compare(
    label_path, output_figure, table_format, output_table, threshold, **kwargs
):
    """Compares multiple systems together"""

    # hack to get a dictionary from arguments passed to input: positional
    # arguments alternate between a system label and its measures file
    if not label_path:
        raise click.ClickException(
            "At least one name-path pair must be provided for comparison"
        )
    if len(label_path) % 2 != 0:
        raise click.ClickException(
            "Input label-paths should be doubles"
            " composed of name-path entries"
        )
    labels = label_path[::2]
    # dict(zip(...)) would silently keep only the last path of a repeated label
    duplicates = sorted({k for k in labels if labels.count(k) > 1})
    if duplicates:
        raise click.ClickException(
            f"System labels must be unique, but the following were repeated: "
            f"{', '.join(duplicates)}"
        )
    data = dict(zip(labels, label_path[1::2]))

    # report invalid thresholds as a usage error instead of a traceback
    try:
        threshold = _validate_threshold(threshold, data)
    except ValueError as e:
        raise click.BadParameter(str(e), param_hint="'--threshold'") from e

    # load all data measures
    data = _load(data, threshold=threshold)

    if output_figure is not None:
        output_figure = os.path.realpath(output_figure)
        logger.info(f"Creating and saving plot at {output_figure}...")
        os.makedirs(os.path.dirname(output_figure), exist_ok=True)
        fig = precision_recall_f1iso(data, credible=True)
        fig.savefig(output_figure)

    logger.info("Tabulating performance summary...")
    table = performance_table(data, table_format)
    click.echo(table)
    if output_table is not None:
        output_table = os.path.realpath(output_table)
        logger.info(f"Saving table at {output_table}...")
        os.makedirs(os.path.dirname(output_table), exist_ok=True)
        # explicit encoding: system labels may contain non-ASCII characters
        with open(output_table, "wt", encoding="utf-8") as f:
            f.write(table)