#!/usr/bin/env python
# coding=utf-8


import tabulate
import torch
from sklearn.metrics import (
    auc,
    precision_recall_curve as pr_curve,
    roc_curve as r_curve,
)

from ..engine.evaluator import posneg
from ..utils.measure import bayesian_measures, base_measures


def performance_table(data, fmt):
    """Tabulates a comparison of results in a given format


    Parameters
    ----------

    data : dict
        A dictionary in which keys are strings defining plot labels and
        values are dictionaries with two entries:

        * ``df``: :py:class:`pandas.DataFrame`

          A dataframe produced by our predictor engine, containing the
          following columns: ``filename``, ``likelihood`` and
          ``ground_truth``.

        * ``threshold``: :py:class:`float`

          A threshold at which to compute the measures.


    fmt : str
        One of the table formats supported by :py:mod:`tabulate`.


    Returns
    -------

    table : str
        A table formatted according to ``fmt``

    """

    headers = [
        "Dataset",
        "T",  # the assigned decision threshold
        "F1 (95% CI)",
        "Prec (95% CI)",
        "Recall/Sen (95% CI)",
        "Spec (95% CI)",
        "Acc (95% CI)",
        "AUC (PRC)",
        "AUC (ROC)",
    ]

    table = []
    for k, v in data.items():
        entry = [k, v["threshold"]]

        df = v["df"]

        gt = torch.tensor(df["ground_truth"].values)
        pred = torch.tensor(df["likelihood"].values)
        threshold = v["threshold"]

        tp_tensor, fp_tensor, tn_tensor, fn_tensor = posneg(pred, gt, threshold)
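        # ``posneg`` is assumed, from its use here, to binarize the
        # likelihoods at ``threshold`` and return four per-sample indicator
        # tensors flagging each prediction as a true positive, false
        # positive, true negative or false negative.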

        # reduces the per-sample indicator tensors to scalar counts
        tp_count = torch.sum(tp_tensor).item()
        fp_count = torch.sum(fp_tensor).item()
        tn_count = torch.sum(tn_tensor).item()
        fn_count = torch.sum(fn_tensor).item()

        base_m = base_measures(
            tp_count,
            fp_count,
            tn_count,
            fn_count,
        )

        bayes_m = bayesian_measures(
            tp_count,
            fp_count,
            tn_count,
            fn_count,
            lambda_=1,
            coverage=0.95,
        )
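        # assumed layout, inferred from the indexing below: indices 0, 1, 2,
        # 3 and 5 of ``base_m`` hold precision, recall, specificity, accuracy
        # and F1, and each ``bayes_m`` entry carries the 95% credible
        # interval bounds at positions 2 and 3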

        # statistics based on the "assigned" (a priori) threshold, which is
        # less biased than a threshold tuned on this same data
        entry.append("{:.2f} ({:.2f}, {:.2f})".format(base_m[5], bayes_m[5][2], bayes_m[5][3]))  # f1
        entry.append("{:.2f} ({:.2f}, {:.2f})".format(base_m[0], bayes_m[0][2], bayes_m[0][3]))  # precision
        entry.append("{:.2f} ({:.2f}, {:.2f})".format(base_m[1], bayes_m[1][2], bayes_m[1][3]))  # recall/sensitivity
        entry.append("{:.2f} ({:.2f}, {:.2f})".format(base_m[2], bayes_m[2][2], bayes_m[2][3]))  # specificity
        entry.append("{:.2f} ({:.2f}, {:.2f})".format(base_m[3], bayes_m[3][2], bayes_m[3][3]))  # accuracy

        # threshold-independent measures: areas under the precision-recall
        # and ROC curves
        prec, recall, _ = pr_curve(gt, pred)
        fpr, tpr, _ = r_curve(gt, pred)

        entry.append(auc(recall, prec))
        entry.append(auc(fpr, tpr))

        table.append(entry)

    return tabulate.tabulate(table, headers, tablefmt=fmt, floatfmt=".3f")
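

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): builds a tiny,
# made-up prediction dataframe and prints the resulting table.  Because of
# the relative imports above, run it as a module, e.g.:
#
#   $ python -m bob.med.tb.utils.table
#
# The dataset name, filenames, likelihoods and ground truths below are all
# illustrative values, not real predictions.

if __name__ == "__main__":
    import pandas

    df = pandas.DataFrame(
        {
            "filename": ["a.png", "b.png", "c.png", "d.png"],
            "likelihood": [0.9, 0.2, 0.7, 0.4],
            "ground_truth": [1.0, 0.0, 1.0, 0.0],
        }
    )
    data = {"demo": {"df": df, "threshold": 0.5}}
    print(performance_table(data, "rst"))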