Coverage for /scratch/builds/bob/bob.ip.binseg/miniconda/conda-bld/bob.ip.binseg_1673966692152/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_p/lib/python3.10/site-packages/bob/ip/common/script/compare.py: 90%

72 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 15:03 +0000

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4import logging 

5import os 

6 

7import click 

8import pandas 

9 

10from tqdm import tqdm 

11 

12logger = logging.getLogger(__name__) 

13 

14 

def base_compare(
    label_path,
    output_figure,
    output_table,
    threshold,
    plot_limits,
    detection,
    verbose,
    table_format="rst",
    **kwargs,
):
    """Compare multiple systems together.

    Loads one ``measures.csv`` style file per system, selects a threshold
    for each of them, prints a performance table and, optionally, saves
    that table and a comparison plot to disk.

    Parameters
    ----------

    label_path : sequence of :py:class:`str`
        Flat sequence alternating system names and paths to their measures
        files, i.e. ``(name1, path1, name2, path2, ...)``.

    output_figure : :py:class:`str`, Optional
        Path where the comparison plot is saved.  If ``None``, no plot is
        created.  Must be ``None`` when ``detection`` is set, as no plotting
        function is imported in that mode.

    output_table : :py:class:`str`, Optional
        Path where the performance table is saved.  If ``None``, the table
        is only echoed to the terminal.

    threshold : :py:class:`float`, :py:class:`str`, Optional
        Threshold selection: ``None``, a number (or its text representation)
        within ``[0.0, 1.0]``, or the name of one of the systems in
        ``label_path``.  See ``_load`` below for the details.

    plot_limits
        Forwarded as ``limits`` to the plotting function.

    detection : bool
        If set, use the detection column (``mean_iou``) and the detection
        performance table; otherwise use ``mean_f1_score`` and the default
        performance table.

    verbose
        Not used by this function.

    table_format : :py:class:`str`, Optional
        Forwarded to the performance table generator.

    **kwargs
        Ignored.


    Raises
    ------

    click.ClickException
        If ``label_path`` has an odd number of entries, or if a figure is
        requested while ``detection`` is set.

    ValueError
        If ``threshold`` is a number outside ``[0.0, 1.0]`` or a text that
        does not match any of the system names.

    """

    def _validate_threshold(t, dataset):
        """Validate the user threshold selection. Returns parsed threshold."""
        if t is None:
            return t

        try:
            # we try to convert it to float first
            parsed = float(t)
        except ValueError:
            # it is a bit of text - assert dataset with name is available
            if not isinstance(dataset, dict):
                raise ValueError(
                    "Threshold should be a floating-point number "
                    "if you provide only a single dataset for evaluation"
                ) from None
            if t not in dataset:
                raise ValueError(
                    f"Text thresholds should match dataset names, "
                    f"but {t} is not available among the datasets provided "
                    f"({', '.join(dataset)})"
                ) from None
            return t

        # The range check must live outside of the ``try`` block above:
        # otherwise the error raised here would be caught by its own
        # ``except ValueError`` and reported as an unknown dataset name.
        # The chained comparison also rejects NaN, which compares false
        # against both bounds.
        if not 0.0 <= parsed <= 1.0:
            raise ValueError(
                "Float thresholds must be within range [0.0, 1.0]"
            )

        return parsed

    def _load(data, detection, threshold=None):
        """Load the measures of all systems and pick a threshold for each.

        Parameters
        ----------

        data : dict
            A dict in which keys are the names of the systems and the values
            are paths to ``measures.csv`` style files.

        detection : bool
            If set, the score column used for threshold selection is
            ``mean_iou``; otherwise it is ``mean_f1_score``.

        threshold : :py:class:`float`, :py:class:`str`, Optional
            A value indicating which threshold to choose for selecting a
            score.  If set to ``None``, then use, for each measures file, the
            threshold at the maximum of its ``threshold_a_priori`` column if
            that column exists, or else the threshold at its maximum score.
            If set to a floating-point value, then use that value for all
            systems.  If set to a string, it should match one of the keys in
            ``data``.  It then first calculates the threshold reaching the
            maximum score on that particular dataset and then applies that
            threshold to all other sets.  Obs: If the task is segmentation,
            the score used is the F1-Score; for the detection task the score
            used is the Intersection Over Union (IoU).


        Returns
        -------

        data : dict
            A dict in which keys are the names of the systems and the values
            are dictionaries that contain two keys:

            * ``df``: A :py:class:`pandas.DataFrame` with the measures data
              loaded
            * ``threshold``: A threshold to be used for summarization,
              depending on the ``threshold`` parameter set on the input

        """
        if detection:
            col_name = "mean_iou"
            score_name = "IoU-score"

        else:
            col_name = "mean_f1_score"
            score_name = "F1-score"

        if isinstance(threshold, str):
            logger.info(
                f"Calculating threshold from maximum {score_name} at "
                f"'{threshold}' dataset..."
            )
            measures_path = data[threshold]
            df = pandas.read_csv(measures_path)
            use_threshold = df.threshold[df[col_name].idxmax()]
            logger.info(f"Dataset '*': threshold = {use_threshold:.3f}'")

        elif isinstance(threshold, float):
            use_threshold = threshold
            logger.info(f"Dataset '*': threshold = {use_threshold:.3f}'")

        # loads all data
        retval = {}
        for name, measures_path in tqdm(data.items(), desc="sample"):

            logger.info(f"Loading measures from {measures_path}...")
            df = pandas.read_csv(measures_path)

            # without a global threshold, each system gets its own
            if threshold is None:

                if "threshold_a_priori" in df:
                    use_threshold = df.threshold[df.threshold_a_priori.idxmax()]
                    logger.info(
                        f"Dataset '{name}': threshold (a priori) = "
                        f"{use_threshold:.3f}'"
                    )
                else:
                    use_threshold = df.threshold[df[col_name].idxmax()]
                    logger.info(
                        f"Dataset '{name}': threshold (a posteriori) = "
                        f"{use_threshold:.3f}'"
                    )

            retval[name] = dict(df=df, threshold=use_threshold)

        return retval

    # hack to get a dictionary from arguments passed to input
    if len(label_path) % 2 != 0:
        raise click.ClickException(
            "Input label-paths should be doubles"
            " composed of name-path entries"
        )
    data = dict(zip(label_path[::2], label_path[1::2]))

    threshold = _validate_threshold(threshold, data)

    # The plotting function is only imported for the non-detection case
    # (see below), so fail early with a clear message instead of hitting an
    # unbound name after all the data has been loaded.
    if detection and output_figure is not None:
        raise click.ClickException(
            "Cannot create a figure in detection mode: "
            "no plotting function is available for detection measures"
        )

    # load all data measures
    data = _load(data, detection=detection, threshold=threshold)

    if detection:
        from ..utils.table import (
            performance_table_detection as performance_table,
        )

    else:
        from ..utils.plot import precision_recall_f1iso
        from ..utils.table import performance_table

    if output_figure is not None:
        output_figure = os.path.realpath(output_figure)
        logger.info(f"Creating and saving plot at {output_figure}...")
        os.makedirs(os.path.dirname(output_figure), exist_ok=True)
        fig = precision_recall_f1iso(data, limits=plot_limits)
        fig.savefig(output_figure)
        fig.clear()

    logger.info("Tabulating performance summary...")
    table = performance_table(data, table_format)
    click.echo(table)
    if output_table is not None:
        output_table = os.path.realpath(output_table)
        logger.info(f"Saving table at {output_table}...")
        os.makedirs(os.path.dirname(output_table), exist_ok=True)
        with open(output_table, "wt") as f:
            f.write(table)