Coverage for src/deepdraw/utils/table.py: 100%

18 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> 

2# 

3# SPDX-License-Identifier: GPL-3.0-or-later 

4 

5import tabulate 

6 

7from .measure import auc 

8 

9 

def performance_table(data, fmt):
    """Render a per-dataset performance summary as a formatted table.

    One row is produced for every entry of ``data``, with the columns
    ``Dataset``, ``T``, ``E(F1)``, ``CI(F1)``, ``AUC`` and ``CI(AUC)``.

    Parameters
    ----------

    data : dict
        A dictionary in which keys are the labels shown in the ``Dataset``
        column and values are dictionaries with two entries:

        * ``df``: :py:class:`pandas.DataFrame`

          A dataframe produced by our evaluator engine, indexed by integer
          "thresholds".  Among its columns, this function reads
          ``mean_f1_score``, ``lower_f1_score``, ``upper_f1_score``,
          ``mean_recall``, ``lower_recall``, ``upper_recall``,
          ``mean_precision``, ``lower_precision`` and ``upper_precision``.
          Any other column is ignored here.

        * ``threshold``: number

          The threshold assigned to this dataset.  It is reported as-is in
          the ``T`` column and, multiplied by the number of rows of ``df``
          and rounded, selects the row from which F1 figures are read
          (presumably a value in ``[0, 1]`` -- confirm against callers).

    fmt : str
        One of the formats supported by tabulate.


    Returns
    -------

    table : str
        The table, rendered in the requested format.
    """

    def _pr_auc(frame, recall_column, precision_column):
        # area under the curve drawn by the given recall/precision columns
        return auc(
            frame[recall_column].to_numpy(),
            frame[precision_column].to_numpy(),
        )

    column_titles = ["Dataset", "T", "E(F1)", "CI(F1)", "AUC", "CI(AUC)"]

    rows = []
    for label, item in data.items():
        threshold = item["threshold"]
        frame = item["df"]

        # F1 statistics are read at the "assigned" threshold (a priori, less
        # biased); the row number is capped at the last row so that a
        # threshold at the upper end does not index past the dataframe
        n_rows = len(frame)
        position = min(int(round(n_rows * threshold)), n_rows - 1)

        f1_expected = frame.mean_f1_score[position]
        f1_interval = (
            f"{frame.lower_f1_score[position]:.3f}-"
            f"{frame.upper_f1_score[position]:.3f}"
        )

        # area under the precision-recall curve, with its lower/upper bounds
        auc_expected = _pr_auc(frame, "mean_recall", "mean_precision")
        auc_lower = _pr_auc(frame, "lower_recall", "lower_precision")
        auc_upper = _pr_auc(frame, "upper_recall", "upper_precision")
        auc_interval = f"{auc_lower:.3f}-{auc_upper:.3f}"

        rows.append(
            [
                label,
                threshold,
                f1_expected,
                f1_interval,
                auc_expected,
                auc_interval,
            ]
        )

    return tabulate.tabulate(
        rows,
        column_titles,
        tablefmt=fmt,
        floatfmt=".3f",
        stralign="right",
    )

100 )