Coverage for src/deepdraw/script/compare.py: 91%

78 statements  

coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import os

import click
import pandas
import tabulate

from clapper.click import verbosity_option
from clapper.logging import setup
from tqdm import tqdm

logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")


@click.command(
    epilog="""Examples:

\b
    1. Compares systems A and B, each with its own pre-computed measures file:

       .. code:: sh

          $ deepdraw compare -vv A path/to/A/train.csv B path/to/B/test.csv
""",
)
@click.argument(
    "label_path",
    nargs=-1,
)
@click.option(
    "--output-figure",
    "-f",
    help="Path where to write the output figure (any extension supported by "
    "matplotlib is possible). If not provided, does not produce a figure.",
    required=False,
    default=None,
    type=click.Path(dir_okay=False, file_okay=True),
)
@click.option(
    "--table-format",
    "-T",
    help="The format to use for the comparison table",
    show_default=True,
    required=True,
    default="rst",
    type=click.Choice(tabulate.tabulate_formats),
)
@click.option(
    "--output-table",
    "-u",
    help="Path where to write the output table. If not provided, the table is "
    "only written to stdout, not to a file.",
    required=False,
    default=None,
    type=click.Path(dir_okay=False, file_okay=True),
)
@click.option(
    "--threshold",
    "-t",
    help="This number is used to select which F1-score to use for "
    "representing a system's performance. If not set, we report the maximum "
    "F1-score in the set, which is equivalent to threshold selection a "
    "posteriori (biased estimator), unless the performance file being "
    "considered was already pre-tuned and contains a 'threshold_a_priori' "
    "column, which we then use to pick a threshold for the dataset. "
    "You can override this behaviour by either setting this value to a "
    "floating-point number in the range [0.0, 1.0], or to a string, naming "
    "one of the systems, which will be used to calculate the threshold "
    "leading to the maximum F1-score and then applied to all other sets.",
    default=None,
    show_default=False,
    required=False,
)
@click.option(
    "--plot-limits",
    "-L",
    help="""If set, must be a 4-tuple containing the bounds of the plot for
    the x and y axes respectively (format: x_low, x_high, y_low,
    y_high). If not set, use normal bounds ([0, 1, 0, 1]) for the
    performance curve.""",
    default=[0.0, 1.0, 0.0, 1.0],
    show_default=True,
    nargs=4,
    type=float,
)
@verbosity_option(
    logger=logger,
)
@click.pass_context
def compare(
    ctx,
    label_path,
    output_figure,
    table_format,
    output_table,
    threshold,
    plot_limits,
    verbose,
    **kwargs,
):
    def _validate_threshold(t, dataset):
        """Validate the user threshold selection.

        Returns the parsed threshold.
        """
        if t is None:
            return t

        try:
            # we try to convert it to float first
            t = float(t)
        except ValueError:
            # it is a bit of text - assert a dataset with that name is available
            if not isinstance(dataset, dict):
                raise ValueError(
                    "Threshold should be a floating-point number "
                    "if you provide only a single dataset for evaluation"
                )
            if t not in dataset:
                raise ValueError(
                    f"Text thresholds should match dataset names, "
                    f"but {t} is not available among the datasets provided "
                    f"({', '.join(dataset.keys())})"
                )
        else:
            # checking the range here (and not inside the ``try`` block) avoids
            # the out-of-range error being swallowed by the ``except
            # ValueError`` handler above
            if t < 0.0 or t > 1.0:
                raise ValueError(
                    "Float thresholds must be within range [0.0, 1.0]"
                )

        return t
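    # Illustration (not part of the original module): as described in the
    # ``--threshold`` help text above, values accepted by
    # ``_validate_threshold()`` include, e.g., ``0.5`` (a fixed operating
    # point) or ``"A"`` (pick the threshold maximizing the F1-score on system
    # A's measures file and apply it to all other systems).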

    def _load(data, threshold=None):
        """Load the measures of all evaluated systems and select their thresholds.

        Parameters
        ----------

        data : dict
            A dict in which keys are the names of the systems and the values are
            paths to ``measures.csv`` style files.

        threshold : :py:class:`float`, :py:class:`str`, Optional
            A value indicating which threshold to choose for selecting a score.
            If set to ``None``, then use the maximum F1-score on that measures
            file.  If set to a floating-point value, then use the score that is
            obtained at that particular threshold.  If set to a string, it
            should match one of the keys in ``data``.  It then first calculates
            the threshold reaching the maximum score on that particular dataset
            and then applies that threshold to all other sets.  Note: if the
            task is segmentation, the score used is the F1-score.


        Returns
        -------

        data : dict
            A dict in which keys are the names of the systems and the values are
            dictionaries that contain two keys:

            * ``df``: A :py:class:`pandas.DataFrame` with the measures data
              loaded
            * ``threshold``: A threshold to be used for summarization, depending
              on the ``threshold`` parameter set on the input
        """

        col_name = "mean_f1_score"
        score_name = "F1-score"

        if isinstance(threshold, str):
            logger.info(
                f"Calculating threshold from maximum {score_name} at "
                f"'{threshold}' dataset..."
            )
            measures_path = data[threshold]
            df = pandas.read_csv(measures_path)
            use_threshold = df.threshold[df[col_name].idxmax()]
            logger.info(f"Dataset '*': threshold = {use_threshold:.3f}")

        elif isinstance(threshold, float):
            use_threshold = threshold
            logger.info(f"Dataset '*': threshold = {use_threshold:.3f}")

        # loads all data
        retval = {}
        for name, measures_path in tqdm(data.items(), desc="sample"):
            logger.info(f"Loading measures from {measures_path}...")
            df = pandas.read_csv(measures_path)

            if threshold is None:
                if "threshold_a_priori" in df:
                    use_threshold = df.threshold[df.threshold_a_priori.idxmax()]
                    logger.info(
                        f"Dataset '{name}': threshold (a priori) = "
                        f"{use_threshold:.3f}"
                    )
                else:
                    use_threshold = df.threshold[df[col_name].idxmax()]
                    logger.info(
                        f"Dataset '{name}': threshold (a posteriori) = "
                        f"{use_threshold:.3f}"
                    )

            retval[name] = dict(df=df, threshold=use_threshold)

        return retval
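    # Illustration (an assumption, not taken from the original module): a
    # minimal ``measures.csv`` compatible with ``_load()`` above could look
    # like
    #
    #     threshold,mean_f1_score,threshold_a_priori
    #     0.40,0.78,False
    #     0.50,0.81,True
    #     0.60,0.79,False
    #
    # where ``threshold_a_priori`` is optional; when it is absent, the row
    # with the maximum ``mean_f1_score`` defines the operating threshold.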

    # hack to get a dictionary from the arguments passed to the input
    if len(label_path) % 2 != 0:
        raise click.ClickException(
            "Input label-paths should come in pairs"
            " composed of name-path entries"
        )
    data = dict(zip(label_path[::2], label_path[1::2]))
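    # For example (mirroring the epilog above), the positional arguments
    #   A path/to/A/train.csv B path/to/B/test.csv
    # become {"A": "path/to/A/train.csv", "B": "path/to/B/test.csv"}.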

    threshold = _validate_threshold(threshold, data)

    # load all data measures
    data = _load(data, threshold=threshold)

    from ..utils.plot import precision_recall_f1iso
    from ..utils.table import performance_table

    if output_figure is not None:
        output_figure = os.path.realpath(output_figure)
        logger.info(f"Creating and saving plot at {output_figure}...")
        os.makedirs(os.path.dirname(output_figure), exist_ok=True)
        fig = precision_recall_f1iso(data, limits=plot_limits)
        fig.savefig(output_figure)
        fig.clear()

    logger.info("Tabulating performance summary...")
    table = performance_table(data, table_format)
    click.echo(table)
    if output_table is not None:
        output_table = os.path.realpath(output_table)
        logger.info(f"Saving table at {output_table}...")
        os.makedirs(os.path.dirname(output_table), exist_ok=True)
        with open(output_table, "w") as f:
            f.write(table)
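For reference, a minimal sketch of exercising this command with click's test
runner. This is an assumption, not part of the module above: the import path
follows the file path in the header (src layout), and the CSV paths are
placeholders taken from the epilog example.

    from click.testing import CliRunner

    from deepdraw.script.compare import compare

    # Compare two hypothetical systems "A" and "B" from their measures files,
    # requesting an RST table (the default) and verbose logging.
    runner = CliRunner()
    result = runner.invoke(
        compare,
        [
            "A", "path/to/A/train.csv",
            "B", "path/to/B/test.csv",
            "--table-format", "rst",
            "-vv",
        ],
    )
    print(result.output)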