Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1637571489937/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.8/site-packages/bob/med/tb/scripts/compare.py: 95%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

63 statements  

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4import os 

5import click 

6 

7from bob.extension.scripts.click_helper import ( 

8 verbosity_option, 

9 AliasedGroup, 

10) 

11 

12import torch 

13import re 

14import pandas 

15import tabulate 

16from matplotlib.backends.backend_pdf import PdfPages 

17 

18from ..utils.plot import precision_recall_f1iso 

19from ..utils.plot import roc_curve 

20from ..utils.table import performance_table 

21 

22import logging 

23logger = logging.getLogger(__name__) 

24 

25 

26def _validate_threshold(t, dataset): 

27 """Validates the user threshold selection. Returns parsed threshold.""" 

28 

29 if t is None: 

30 return t 

31 

32 # we try to convert it to float first 

33 t = float(t) 

34 if t < 0.0 or t > 1.0: 

35 raise ValueError("Thresholds must be within range [0.0, 1.0]") 

36 

37 return t 

38 

39 

40def _load(data, threshold): 

41 """Plots comparison chart of all evaluated models 

42 

43 Parameters 

44 ---------- 

45 

46 data : dict 

47 A dict in which keys are the names of the systems and the values are 

48 paths to ``predictions.csv`` style files. 

49 

50 threshold : :py:class:`float` 

51 A threshold for the final classification. 

52 

53 

54 Returns 

55 ------- 

56 

57 data : dict 

58 A dict in which keys are the names of the systems and the values are 

59 dictionaries that contain two keys: 

60 

61 * ``df``: A :py:class:`pandas.DataFrame` with the predictions data  

62 loaded to 

63 * ``threshold``: The ``threshold`` parameter set on the input 

64 

65 """ 

66 

67 use_threshold = threshold 

68 logger.info(f"Dataset '*': threshold = {use_threshold:.3f}'") 

69 

70 # loads all data 

71 retval = {} 

72 for name, predictions_path in data.items(): 

73 

74 # Load predictions 

75 logger.info(f"Loading predictions from {predictions_path}...") 

76 pred_data = pandas.read_csv(predictions_path) 

77 pred = torch.Tensor([eval(re.sub(' +', ' ', x.replace('\n', '')).replace(' ', ',')) if isinstance(x, str) else x for x in pred_data['likelihood'].values]).double().flatten() 

78 gt = torch.Tensor([eval(re.sub(' +', ' ', x.replace('\n', '')).replace(' ', ',')) if isinstance(x, str) else x for x in pred_data['ground_truth'].values]).double().flatten() 

79 

80 pred_data['likelihood'] = pred 

81 pred_data['ground_truth'] = gt 

82 

83 retval[name] = dict(df=pred_data, threshold=use_threshold) 

84 

85 return retval 

86 

87 

@click.command(
    epilog="""Examples:

\b
    1. Compares system A and B, with their own predictions files:
\b
       $ bob tb compare -vv A path/to/A/predictions.csv B path/to/B/predictions.csv
""",
)
@click.argument(
    "label_path",
    nargs=-1,
)
@click.option(
    "--output-figure",
    "-f",
    help="Path where write the output figure (any extension supported by "
    "matplotlib is possible). If not provided, does not produce a figure.",
    required=False,
    default=None,
    type=click.Path(dir_okay=False, file_okay=True),
)
@click.option(
    "--table-format",
    "-T",
    help="The format to use for the comparison table",
    show_default=True,
    required=True,
    default="rst",
    type=click.Choice(tabulate.tabulate_formats),
)
@click.option(
    "--output-table",
    "-u",
    # typo fix: help text previously read "does not write write a table"
    help="Path where write the output table. If not provided, does not "
    "write a table to file, only to stdout.",
    required=False,
    default=None,
    type=click.Path(dir_okay=False, file_okay=True),
)
@click.option(
    "--threshold",
    "-t",
    help="This number is used to separate positive and negative cases "
    "by thresholding their score.",
    default=None,
    show_default=False,
    required=False,
)
@verbosity_option()
def compare(label_path, output_figure, table_format, output_table,
        threshold, **kwargs):
    """Compares multiple systems together"""

    # hack to get a dictionary from arguments passed to input:
    # label_path is a flat sequence of (name, path) pairs
    if len(label_path) % 2 != 0:
        raise click.ClickException("Input label-paths should be doubles"
                " composed of name-path entries")
    data = dict(zip(label_path[::2], label_path[1::2]))

    threshold = _validate_threshold(threshold, data)

    # load all data measures
    data = _load(data, threshold=threshold)

    if output_figure is not None:
        output_figure = os.path.realpath(output_figure)
        logger.info(f"Creating and saving plot at {output_figure}...")
        os.makedirs(os.path.dirname(output_figure), exist_ok=True)
        pdf = PdfPages(output_figure)
        # ensure the PDF handle is closed (and the file flushed) even if
        # one of the plotting helpers raises
        try:
            pdf.savefig(precision_recall_f1iso(data))
            pdf.savefig(roc_curve(data))
        finally:
            pdf.close()

    logger.info("Tabulating performance summary...")
    table = performance_table(data, table_format)
    click.echo(table)
    if output_table is not None:
        output_table = os.path.realpath(output_table)
        logger.info(f"Saving table at {output_table}...")
        os.makedirs(os.path.dirname(output_table), exist_ok=True)
        with open(output_table, "wt") as f:
            f.write(table)