#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Defines functionality for the evaluation of predictions"""

import ast
import logging
import os
import re

import matplotlib.pyplot as plt
import numpy
import pandas as pd
import torch
from bob.measure import eer_threshold
from sklearn import metrics

from ..utils.measure import base_measures, get_centered_maxf1

logger = logging.getLogger(__name__)


def posneg(pred, gt, threshold):
    """Calculates true and false positives and negatives"""

    # binarize the predictions at the given threshold
    binary_pred = torch.gt(pred, threshold)

    # pixels where prediction and ground-truth agree/disagree
    equals = torch.eq(binary_pred, gt).type(torch.uint8)
    notequals = torch.ne(binary_pred, gt).type(torch.uint8)

    # true positives
    tp_tensor = (gt * binary_pred).type(torch.uint8)

    # false positives
    fp_tensor = torch.eq((binary_pred + tp_tensor), 1).type(torch.uint8)

    # true negatives
    tn_tensor = (equals - tp_tensor).type(torch.uint8)

    # false negatives
    fn_tensor = (notequals - fp_tensor).type(torch.uint8)

    return tp_tensor, fp_tensor, tn_tensor, fn_tensor
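
# Example (minimal sketch): for every pixel exactly one of the four maps
# returned by ``posneg`` is set, so the counts add up to the pixel total:
#
#     pred = torch.rand(8, 8)
#     gt = (torch.rand(8, 8) > 0.5).double()
#     tp, fp, tn, fn = posneg(pred, gt, 0.5)
#     assert (tp + fp + tn + fn).sum().item() == pred.numel()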


def sample_measures_for_threshold(pred, gt, threshold):
    """
    Calculates measures on one single sample, for a specific threshold


    Parameters
    ----------

    pred : torch.Tensor
        pixel-wise predictions

    gt : torch.Tensor
        ground-truth (annotations)

    threshold : float
        a particular threshold for which to calculate the performance
        measures


    Returns
    -------

    precision : float

    recall : float

    specificity : float

    accuracy : float

    jaccard : float

    f1_score : float

    """

    tp_tensor, fp_tensor, tn_tensor, fn_tensor = posneg(pred, gt, threshold)

    # reduce the pixel maps to scalar counts
    tp_count = torch.sum(tp_tensor).item()
    fp_count = torch.sum(fp_tensor).item()
    tn_count = torch.sum(tn_tensor).item()
    fn_count = torch.sum(fn_tensor).item()

    return base_measures(tp_count, fp_count, tn_count, fn_count)
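
# Example (minimal sketch, assuming ``base_measures`` returns the tuple
# documented above): at threshold 0.5, the 0.7 pixel below is a false
# positive, so precision is 0.5 while recall is 1.0:
#
#     pred = torch.tensor([[0.9, 0.2], [0.7, 0.1]])
#     gt = torch.tensor([[1.0, 0.0], [0.0, 0.0]])
#     precision, recall, specificity, accuracy, jaccard, f1 = \
#         sample_measures_for_threshold(pred, gt, 0.5)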


def run(
    dataset,
    name,
    predictions_folder,
    output_folder=None,
    f1_thresh=None,
    eer_thresh=None,
    steps=1000,
):
    """
    Runs inference and calculates measures


    Parameters
    ----------

    dataset : :py:class:`torch.utils.data.Dataset`
        a dataset to iterate on

    name : str
        the local name of this dataset (e.g. ``train``, or ``test``), to be
        used when saving measures files

    predictions_folder : str
        folder where predictions for the dataset images have been previously
        stored

    output_folder : :py:class:`str`, Optional
        folder where to store results

    f1_thresh : :py:class:`float`, Optional
        threshold at which to report the *a priori* F1-score on the evaluated
        set.  This number should come from the training set or a separate
        validation set.  Using a test set value may bias your analysis.

    eer_thresh : :py:class:`float`, Optional
        threshold at which to report the *a priori* equal error rate (EER).
        This number should come from the training set or a separate
        validation set.  Using a test set value may bias your analysis.

    steps : :py:class:`int`, Optional
        number of threshold steps to consider when evaluating thresholds


    Returns
    -------

    f1_threshold : float
        threshold achieving the highest possible F1-score for this dataset

    eer_threshold : float
        threshold achieving the equal error rate for this dataset

    """

    predictions_path = os.path.join(predictions_folder, name, "predictions.csv")
    if not os.path.exists(predictions_path):
        predictions_path = predictions_folder

    # load predictions: each CSV cell holds the string representation of an
    # array (e.g. "[0.1 0.9]"); normalize whitespace, turn it into a
    # comma-separated literal and parse it safely (instead of ``eval``)
    def _parse_cell(x):
        return ast.literal_eval(
            re.sub(" +", " ", x.replace("\n", "")).replace(" ", ",")
        )

    pred_data = pd.read_csv(predictions_path)
    pred = torch.Tensor(
        [_parse_cell(x) for x in pred_data["likelihood"].values]
    ).double()
    gt = torch.Tensor(
        [_parse_cell(x) for x in pred_data["ground_truth"].values]
    ).double()

    # single-output case: flatten (N, 1) tensors to (N,)
    if pred.shape[1] == 1 and gt.shape[1] == 1:
        pred = torch.flatten(pred)
        gt = torch.flatten(gt)

    pred_data["likelihood"] = pred
    pred_data["ground_truth"] = gt
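
    # For instance (illustrative values, single-output binary case), a
    # predictions file such as:
    #
    #     likelihood,ground_truth
    #     [0.87],[1.]
    #     [0.12],[0.]
    #
    # yields 1-D ``pred`` and ``gt`` tensors after the flattening above.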

    # multiclass case: only AUC is computed (F1 and EER are not implemented)
    if pred.ndim > 1:
        auc = metrics.roc_auc_score(gt, pred)
        logger.info("Evaluating multiclass classification")
        logger.info(f"AUC: {auc}")
        logger.info("F1 and EER are not implemented for multiclass")

        return None, None

    # generate measures for each threshold
    step_size = 1.0 / steps
    data = [
        (index, threshold) + sample_measures_for_threshold(pred, gt, threshold)
        for index, threshold in enumerate(numpy.arange(0.0, 1.0, step_size))
    ]

    data_df = pd.DataFrame(
        data,
        columns=(
            "index",
            "threshold",
            "precision",
            "recall",
            "specificity",
            "accuracy",
            "jaccard",
            "f1_score",
        ),
    )
    data_df = data_df.set_index("index")

    # save evaluation csv
    if output_folder is not None:
        fullpath = os.path.join(output_folder, f"{name}.csv")
        logger.info(f"Saving {fullpath}...")
        os.makedirs(os.path.dirname(fullpath), exist_ok=True)
        data_df.to_csv(fullpath)

    # find the maximum F1-score and the threshold that achieves it
    f1_scores = numpy.asarray(data_df["f1_score"])
    thresholds = numpy.asarray(data_df["threshold"])

    maxf1, maxf1_threshold = get_centered_maxf1(f1_scores, thresholds)

    logger.info(
        f"Maximum F1-score of {maxf1:.5f}, achieved at "
        f"threshold {maxf1_threshold:.3f} (chosen *a posteriori*)"
    )
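
    # ``get_centered_maxf1`` picks, when several consecutive thresholds tie
    # for the maximum F1-score, the threshold at the center of that plateau
    # rather than its first edge.  A rough sketch of the idea (not
    # necessarily the exact implementation):
    #
    #     maxf1 = f1_scores.max()
    #     tied = numpy.where(f1_scores == maxf1)[0]
    #     maxf1_threshold = thresholds[int(round(tied.mean()))]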

    # find the EER threshold from the negative and positive score sets
    neg_gt = pred_data.loc[pred_data.loc[:, "ground_truth"] == 0, :]
    pos_gt = pred_data.loc[pred_data.loc[:, "ground_truth"] == 1, :]
    post_eer_threshold = eer_threshold(neg_gt["likelihood"], pos_gt["likelihood"])

    logger.info(
        f"Equal error rate achieved at "
        f"threshold {post_eer_threshold:.3f} (chosen *a posteriori*)"
    )
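
    # The EER is the operating point where the false positive rate equals
    # the false negative rate.  A rough cross-check (minimal sketch, using
    # the already-imported sklearn API):
    #
    #     fpr, tpr, thr = metrics.roc_curve(
    #         pred_data["ground_truth"], pred_data["likelihood"]
    #     )
    #     approx = thr[numpy.argmin(numpy.abs(fpr - (1 - tpr)))]  # fnr == 1 - tpr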

    # save the score table (histogram of scores per class)
    if output_folder is not None:
        fig, axes = plt.subplots(1)
        fig.tight_layout(pad=3.0)

        # names and bounds
        axes.set_xlabel("Score")
        axes.set_ylabel("Normalized counts")
        axes.set_xlim(0.0, 1.0)

        # normalize the counts by the total number of samples
        neg_weights = numpy.ones_like(neg_gt["likelihood"]) / len(
            pred_data["likelihood"]
        )
        pos_weights = numpy.ones_like(pos_gt["likelihood"]) / len(
            pred_data["likelihood"]
        )

        axes.hist(
            [neg_gt["likelihood"], pos_gt["likelihood"]],
            weights=[neg_weights, pos_weights],
            bins=100,
            color=["tab:blue", "tab:orange"],
            label=["Negatives", "Positives"],
        )
        axes.legend(prop={"size": 10}, loc="upper center")
        axes.set_title(f"Score table for {name} subset")

        # hide the top and right spines, nudge the left spine outwards
        axes.spines["right"].set_visible(False)
        axes.spines["top"].set_visible(False)
        axes.spines["left"].set_position(("data", -0.015))

        fullpath = os.path.join(output_folder, f"{name}_score_table.pdf")
        fig.savefig(fullpath)

    if f1_thresh is not None and eer_thresh is not None:

        # get the closest threshold we actually evaluated (clamped to the
        # last available step)
        index = min(int(round(steps * f1_thresh)), steps - 1)
        f1_a_priori = data_df["f1_score"][index]
        actual_threshold = data_df["threshold"][index]

        logger.info(
            f"F1-score of {f1_a_priori:.5f}, at threshold "
            f"{actual_threshold:.3f} (chosen *a priori*)"
        )

        # print the *a priori* EER threshold
        logger.info(f"Equal error rate (chosen *a priori*) {eer_thresh:.3f}")

    return maxf1_threshold, post_eer_threshold

    # from matplotlib.backends.backend_pdf import PdfPages

    # fname = os.path.join(output_folder, name + ".pdf")
    # os.makedirs(os.path.dirname(fname), exist_ok=True)

    # with PdfPages(fname) as pdf:

    #     fig, axes = plt.subplots(2, 2, figsize=(12.8, 9.6))
    #     fig.suptitle(f"Subset: {name}", fontsize=16, fontweight='semibold')
    #     axes = axes.flatten()

    #     # Tight layout often produces nice results
    #     # but requires the title to be spaced accordingly
    #     fig.tight_layout(pad=3.0)
    #     fig.subplots_adjust(top=0.92)

    #     # ------------
    #     # Score table
    #     # ------------

    #     axes[0].set_xlim(0.0, 1.0)
    #     axes[0].hist(
    #         [neg_gt['likelihood'], pos_gt['likelihood']],
    #         bins=30, color=['tab:blue', 'tab:orange'],
    #         label=["Negatives", "Positives"])
    #     axes[0].legend(prop={'size': 10})
    #     axes[0].set_title("Score table")

    #     # ----------
    #     # ROC Curve
    #     # ----------

    #     # TPR = 1 - FNR
    #     (line,) = axes[1].plot(
    #         1 - data_df['specificity'],
    #         data_df['recall'],
    #         color="#1f77b4"
    #     )
    #     auc = roc_auc_score(neg_gt['likelihood'], pos_gt['likelihood'])
    #     axes[1].set(xlabel='1 - specificity', ylabel='Sensitivity',
    #                 title=f'ROC curve (AUC={auc:.4f})')
    #     # axes[1].plot([0, 1], [0, 1], color='tab:orange', linestyle='--')
    #     axes[1].grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
    #     axes[1].set_xlim([0.0, 1.0])
    #     axes[1].set_ylim([0.0, 1.0])

    #     # Equal Error Rate threshold
    #     EER = eer(neg_gt['likelihood'], pos_gt['likelihood'])
    #     threshold = eer_threshold(neg_gt['likelihood'], pos_gt['likelihood'])
    #     threshold_index = data_df['threshold'].sub(threshold).abs().idxmin()
    #     # hter_threshold = min_hter_threshold(neg_gt['likelihood'], pos_gt['likelihood'])

    #     # Plot EER
    #     (marker,) = axes[1].plot(
    #         1 - data_df["specificity"][threshold_index],
    #         data_df["recall"][threshold_index],
    #         marker="o",
    #         color="tab:blue",
    #         markersize=8
    #     )

    #     # We should see some of axes 1 axes
    #     axes[1].spines["right"].set_visible(False)
    #     axes[1].spines["top"].set_visible(False)
    #     axes[1].spines["left"].set_position(("data", -0.015))
    #     axes[1].spines["bottom"].set_position(("data", -0.015))

    #     # Legend
    #     label = f"{name} set (EER={EER:.4f})"
    #     axes[1].legend(
    #         [tuple([line, marker])],
    #         [label],
    #         loc="lower right",
    #         fancybox=True,
    #         framealpha=0.7,
    #     )

    #     # -----------------------
    #     # Precision-recall Curve
    #     # -----------------------

    #     (line,) = axes[2].plot(data_df['recall'], data_df['precision'])
    #     prc_auc = metrics.auc(data_df['recall'], data_df['precision'])
    #     axes[2].set(xlabel='Recall', ylabel='Precision',
    #                 title=f'Precision-recall curve (AUC={prc_auc:.4f})')
    #     axes[2].grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
    #     axes[2].set_xlim([0.0, 1.0])
    #     axes[2].set_ylim([0.0, 1.0])

    #     # Annotates plot with F1-score iso-lines
    #     axes_right = axes[2].twinx()
    #     f_scores_d = numpy.linspace(0.1, 0.9, num=9)
    #     tick_locs = []
    #     tick_labels = []
    #     for f in f_scores_d:
    #         x = numpy.linspace(0.01, 1)
    #         y = f * x / (2 * x - f)
    #         (l,) = axes_right.plot(x[y >= 0], y[y >= 0], color="green", alpha=0.1)
    #         tick_locs.append(y[-1])
    #         tick_labels.append("%.1f" % f)
    #     axes_right.tick_params(axis="y", which="both", pad=0, right=False, left=False)
    #     axes_right.set_ylabel("iso-F", color="green", alpha=0.3)
    #     axes_right.set_ylim([0.0, 1.0])
    #     axes_right.yaxis.set_label_coords(1.015, 0.97)
    #     axes_right.set_yticks(tick_locs)  # notice these are invisible
    #     for k in axes_right.set_yticklabels(tick_labels):
    #         k.set_color("green")
    #         k.set_alpha(0.3)
    #         k.set_size(8)

    #     # We shouldn't see any of axes_right axes
    #     axes_right.spines["right"].set_visible(False)
    #     axes_right.spines["top"].set_visible(False)
    #     axes_right.spines["left"].set_visible(False)
    #     axes_right.spines["bottom"].set_visible(False)

    #     # Plot F1 score
    #     (marker,) = axes[2].plot(
    #         data_df["recall"][maxf1_index],
    #         data_df["precision"][maxf1_index],
    #         marker="o",
    #         color="tab:blue",
    #         markersize=8
    #     )

    #     # We should see some of axes 2 axes
    #     axes[2].spines["right"].set_visible(False)
    #     axes[2].spines["top"].set_visible(False)
    #     axes[2].spines["left"].set_position(("data", -0.015))
    #     axes[2].spines["bottom"].set_position(("data", -0.015))

    #     # Legend
    #     label = f"{name} set (F1={data_df['f1_score'][maxf1_index]:.4f})"
    #     axes[2].legend(
    #         [tuple([line, marker])],
    #         [label],
    #         loc="lower left",
    #         fancybox=True,
    #         framealpha=0.7,
    #     )

    #     # Mean square error given optimal threshold (computed on train set)
    #     ground_truth = pred_data['ground_truth']
    #     likelihood = pred_data['likelihood']
    #     mse_res = mse(likelihood, ground_truth)
    #     text_mse = f"MSE with a threshold of {threshold:.3f}: {mse_res:.3f}"
    #     axes[3].text(0.5, 0.5, text_mse, horizontalalignment="center",
    #                  verticalalignment="center")
    #     axes[3].axis('off')

    #     pdf.savefig()
    #     plt.close(fig)

    # f1_score = f_score(neg_gt['likelihood'], pos_gt['likelihood'], threshold)

    # logger.info(
    #     f"Maximum F1-score of {f1_score:.5f}, achieved at "
    #     f"threshold {threshold:.3f} (chosen *a priori*)"
    # )

    # return threshold