Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4"""Defines functionality for the evaluation of predictions"""
6import logging
7import os
9import h5py
10import numpy
11import pandas
12import PIL
13import torch
14import torch.nn.functional
15import torchvision.transforms.functional as VF
17from tqdm import tqdm
19from ..utils.measure import base_measures, bayesian_measures
# Module-level logger, registered under this module's dotted name
# (``__name__``); used below for summary messages of evaluation runs.
logger = logging.getLogger(__name__)
24def _posneg(pred, gt, threshold):
25 """Calculates true and false positives and negatives
28 Parameters
29 ----------
31 pred : torch.Tensor
32 pixel-wise predictions
34 gt : torch.Tensor
35 ground-truth (annotations)
37 threshold : float
38 a particular threshold in which to calculate the performance
39 measures
42 Returns
43 -------
45 tp_tensor : torch.Tensor
46 boolean tensor with true positives, considering all observations
48 fp_tensor : torch.Tensor
49 boolean tensor with false positives, considering all observations
51 tn_tensor : torch.Tensor
52 boolean tensor with true negatives, considering all observations
54 fn_tensor : torch.Tensor
55 boolean tensor with false negatives, considering all observations
57 """
59 gt = gt.byte() # byte tensor
61 # threshold
62 binary_pred = torch.gt(pred, threshold).byte()
64 # equals and not-equals
65 equals = torch.eq(binary_pred, gt).type(torch.uint8) # tensor
66 notequals = torch.ne(binary_pred, gt).type(torch.uint8) # tensor
68 # true positives
69 tp_tensor = gt * binary_pred
71 # false positives
72 fp_tensor = torch.eq((binary_pred + tp_tensor), 1).byte()
74 # true negatives
75 tn_tensor = equals - tp_tensor
77 # false negatives
78 fn_tensor = notequals - fp_tensor
80 return tp_tensor, fp_tensor, tn_tensor, fn_tensor
def sample_measures_for_threshold(pred, gt, mask, threshold):
    """
    Counts TP, FP, TN and FN pixels of one sample at a single threshold


    Parameters
    ----------

    pred : torch.Tensor
        pixel-wise predictions

    gt : torch.Tensor
        ground-truth (annotations)

    mask : torch.Tensor
        region mask (used only if available). May be set to ``None``.  When
        given, only positions where ``mask > 0.5`` are counted.

    threshold : float
        predictions strictly above this value are considered positive


    Returns
    -------

    tp : int

    fp : int

    tn : int

    fn : int

    """

    outcomes = _posneg(pred, gt, threshold)

    if mask is not None:
        # discard every outcome lying outside the region of interest
        outside_roi = torch.le(mask, 0.5)
        for outcome in outcomes:
            outcome[outside_roi] = 0

    return tuple(torch.sum(outcome).item() for outcome in outcomes)
def _sample_measures(pred, gt, mask, steps):
    """
    Calculates measures on one single sample, for a range of thresholds


    Parameters
    ----------

    pred : torch.Tensor
        pixel-wise predictions

    gt : torch.Tensor
        ground-truth (annotations)

    mask : torch.Tensor
        region mask (used only if available). May be set to ``None``.

    steps : int
        number of thresholds to evaluate.  Exactly ``steps`` evenly spaced
        thresholds are used, approximately ``0, 1/steps, ..., (steps-1)/steps``
        (the step size is ``1.0/steps``).


    Returns
    -------

    measures : pandas.DataFrame
        A pandas dataframe indexed by ``index`` (the threshold ordinal, from
        0 to ``steps-1``) with the following columns:

        * threshold: float
        * tp: int
        * fp: int
        * tn: int
        * fn: int


    Raises
    ------

    ValueError
        if ``steps`` is smaller than 1

    """

    steps = int(steps)
    if steps < 1:
        raise ValueError(f"`steps` must be a positive integer, got {steps}")

    # ``numpy.arange(0.0, 1.0, 1.0/steps)`` uses a floating-point step and
    # may, due to rounding, return ``steps + 1`` values (see the
    # ``numpy.arange`` documentation).  Callers map thresholds back to row
    # indices assuming exactly ``steps`` rows, so use ``linspace`` without
    # the end point, which produces the same evenly spaced values
    # ``i * (1.0/steps)`` with a guaranteed length.
    thresholds = numpy.linspace(0.0, 1.0, steps, endpoint=False)

    data = [
        (index, threshold)
        + sample_measures_for_threshold(pred, gt, mask, threshold)
        for index, threshold in enumerate(thresholds)
    ]

    retval = pandas.DataFrame(
        data,
        columns=(
            "index",
            "threshold",
            "tp",
            "fp",
            "tn",
            "fn",
        ),
    )
    retval.set_index("index", inplace=True)
    return retval
def _sample_analysis(
    img,
    pred,
    gt,
    mask,
    threshold,
    tp_color=(0, 255, 0),  # green
    fp_color=(0, 0, 255),  # blue
    fn_color=(255, 0, 0),  # red
    overlay=True,
):
    """Visualizes true positives, false positives and false negatives


    Parameters
    ----------

    img : torch.Tensor
        original image

    pred : torch.Tensor
        pixel-wise predictions

    gt : torch.Tensor
        ground-truth (annotations)

    mask : torch.Tensor
        region mask (used only if available). May be set to ``None``.  When
        given, nothing is drawn where ``mask <= 0.5``.

    threshold : float
        The threshold to be used while analyzing this image's probability map

    tp_color : tuple
        RGB value for true positives

    fp_color : tuple
        RGB value for false positives

    fn_color : tuple
        RGB value for false negatives

    overlay : :py:class:`bool`, Optional
        If set to ``True`` (which is the default), then overlay annotations on
        top of the image. Otherwise, represent data on a black canvas.


    Returns
    -------

    figure : PIL.Image.Image
        An RGB PIL image that contains the overlayed analysis of
        true-positives (TP), false-positives (FP) and false negatives (FN).

    """

    # ``import PIL`` alone does not load the ``Image``/``ImageOps``
    # submodules; import them explicitly instead of relying on another
    # library having imported them as a side effect.
    from PIL import Image, ImageOps

    # true negatives are not drawn, so they are discarded here
    tp_tensor, fp_tensor, _, fn_tensor = _posneg(pred, gt, threshold)

    # if a mask is provided, consider only TP/FP/FN **within** the region of
    # interest defined by the mask
    if mask is not None:
        antimask = torch.le(mask, 0.5)
        tp_tensor[antimask] = 0
        fp_tensor[antimask] = 0
        fn_tensor[antimask] = 0

    # change to PIL representation: each 0/1 map becomes a grayscale image,
    # then colorized from black (0) to the requested color (1)
    tp_pil = VF.to_pil_image(tp_tensor.float())
    tp_pil_colored = ImageOps.colorize(tp_pil, (0, 0, 0), tp_color)

    fp_pil = VF.to_pil_image(fp_tensor.float())
    fp_pil_colored = ImageOps.colorize(fp_pil, (0, 0, 0), fp_color)

    fn_pil = VF.to_pil_image(fn_tensor.float())
    fn_pil_colored = ImageOps.colorize(fn_pil, (0, 0, 0), fn_color)

    # paste FP and FN colors onto the TP canvas, using each map as the mask
    tp_pil_colored.paste(fp_pil_colored, mask=fp_pil)
    tp_pil_colored.paste(fn_pil_colored, mask=fn_pil)

    if overlay:
        # ``Image.blend`` requires both images to share mode and size; the
        # colorized canvas is RGB, so promote (e.g. grayscale) inputs to RGB.
        img = VF.to_pil_image(img).convert("RGB")
        # using blend here, to fade original image being overlayed, or
        # its brightness may obfuscate colors from the vessel map
        tp_pil_colored = Image.blend(img, tp_pil_colored, 0.5)

    return tp_pil_colored
def _summarize(data):
    """Merges per-sample count dataframes and appends derived measures

    The TP/FP/TN/FN counts of all samples are summed per threshold ordinal
    (the ``index`` level).  The ``threshold`` column, which gets summed along
    the way, is divided back by the number of samples.  Each resulting row is
    then extended with the flattened output of ``bayesian_measures`` and the
    output of ``base_measures``.


    Parameters
    ----------

    data : dict
        maps sample stems to dataframes as produced by ``_sample_measures``


    Returns
    -------

    summary : pandas.DataFrame
        the summed counts, followed by one column per derived measure

    """

    measure_names = (
        "precision",
        "recall",
        "specificity",
        "accuracy",
        "jaccard",
        "f1_score",
    )
    # names of the derived columns, in the order the flattened values are
    # produced: four bayesian statistics per measure, then one frequentist
    # value per measure
    column_names = tuple(
        f"{statistic}_{measure}"
        for measure in measure_names
        for statistic in ("mean", "mode", "lower", "upper")
    ) + tuple(f"frequentist_{measure}" for measure in measure_names)

    def _row_summary(row):
        values = []
        for group in bayesian_measures(
            row.tp,
            row.fp,
            row.tn,
            row.fn,
            lambda_=0.5,
            coverage=0.95,
        ):
            values.extend(group)
        values.extend(base_measures(row.tp, row.fp, row.tn, row.fn))
        return pandas.Series(values, index=column_names)

    stacked = pandas.concat(data.values())
    totals = stacked.groupby("index").sum()
    totals["threshold"] /= len(data)

    derived = totals.apply(_row_summary, axis=1)

    merged = pandas.concat([totals, derived.reindex(totals.index)], axis=1)
    return merged.copy()
def run(
    dataset,
    name,
    predictions_folder,
    output_folder=None,
    overlayed_folder=None,
    threshold=None,
    steps=1000,
):
    """
    Runs inference and calculates measures


    Parameters
    ----------

    dataset : py:class:`torch.utils.data.Dataset`
        a dataset to iterate on.  Each sample is indexed as
        ``(stem, image, ground_truth[, mask])``.

    name : str
        the local name of this dataset (e.g. ``train``, or ``test``), to be
        used when saving measures files.

    predictions_folder : str
        folder where predictions for the dataset images have been previously
        stored, as ``<stem>.hdf5`` files holding an ``array`` dataset.  If a
        sub-folder called ``name`` exists inside it, predictions are read
        from that sub-folder instead.

    output_folder : :py:class:`str`, Optional
        folder where to store results. If not provided, then do not store any
        analysis (useful for quickly calculating overlay thresholds)

    overlayed_folder : :py:class:`str`, Optional
        if not ``None``, then it should be the name of a folder where to store
        overlayed versions of the images and ground-truths.  Requires
        ``threshold``.

    threshold : :py:class:`float`, Optional
        threshold (floating point) to apply to prediction maps to decide on
        positives and negatives for the overlaying analysis (graphical
        output).  Required if ``overlayed_folder`` is set.  This number should
        come from the training set or a separate validation set. Using a test
        set value may bias your analysis. This number is also used to print
        the a priori F1-score on the evaluated set (using the closest
        evaluated threshold).

    steps : :py:class:`int`, Optional
        number of threshold steps to consider when evaluating thresholds.


    Returns
    -------

    threshold : float
        Threshold achieving the highest ``mean_f1_score`` for this dataset


    Raises
    ------

    ValueError
        if ``overlayed_folder`` is set but ``threshold`` is ``None``

    RuntimeError
        if two samples of ``dataset`` share the same stem

    """

    if overlayed_folder is not None and threshold is None:
        raise ValueError(
            "a `threshold` must be provided when `overlayed_folder` is set"
        )

    # Collect overall measures
    data = {}

    use_predictions_folder = os.path.join(predictions_folder, name)
    if not os.path.exists(use_predictions_folder):
        use_predictions_folder = predictions_folder

    for sample in tqdm(dataset):
        stem = sample[0]
        image = sample[1]
        gt = sample[2]
        mask = None if len(sample) <= 3 else sample[3]
        pred_fullpath = os.path.join(use_predictions_folder, stem + ".hdf5")
        with h5py.File(pred_fullpath, "r") as f:
            pred = f["array"][:]
        pred = torch.from_numpy(pred)
        if stem in data:
            raise RuntimeError(
                f"{stem} entry already exists in data. Cannot overwrite."
            )
        data[stem] = _sample_measures(pred, gt, mask, steps)

        if output_folder is not None:
            fullpath = os.path.join(output_folder, name, f"{stem}.csv")
            tqdm.write(f"Saving {fullpath}...")
            os.makedirs(os.path.dirname(fullpath), exist_ok=True)
            data[stem].to_csv(fullpath)

        if overlayed_folder is not None:
            overlay_image = _sample_analysis(
                image, pred, gt, mask, threshold=threshold, overlay=True
            )
            fullpath = os.path.join(overlayed_folder, name, f"{stem}.png")
            tqdm.write(f"Saving {fullpath}...")
            os.makedirs(os.path.dirname(fullpath), exist_ok=True)
            overlay_image.save(fullpath)

    # Merges all dataframes together
    measures = _summarize(data)

    maxf1 = measures["mean_f1_score"].max()
    maxf1_index = measures["mean_f1_score"].idxmax()
    maxf1_threshold = measures.loc[maxf1_index, "threshold"]

    logger.info(
        f"Maximum F1-score of {maxf1:.5f}, achieved at "
        f"threshold {maxf1_threshold:.3f} (chosen *a posteriori*)"
    )

    if threshold is not None:

        # get the closest possible threshold we have; clamp to the evaluated
        # range, as thresholds at (or near) 1.0 round to ``steps``, a row
        # that does not exist
        index = int(round(steps * threshold))
        index = max(
            int(measures.index.min()), min(index, int(measures.index.max()))
        )
        f1_a_priori = measures.loc[index, "mean_f1_score"]
        actual_threshold = measures.loc[index, "threshold"]

        # mark (the single row of) the threshold a priori chosen on this
        # dataset; ``.loc`` is required here - ``measures[col, index]`` would
        # create a new, tuple-named column instead
        measures["threshold_a_priori"] = False
        measures.loc[index, "threshold_a_priori"] = True

        logger.info(
            f"F1-score of {f1_a_priori:.5f}, at threshold "
            f"{actual_threshold:.3f} (chosen *a priori*)"
        )

    if output_folder is not None:
        logger.info(f"Output folder: {output_folder}")
        os.makedirs(output_folder, exist_ok=True)
        measures_path = os.path.join(output_folder, f"{name}.csv")
        logger.info(
            f"Saving measures over all input images at {measures_path}..."
        )
        measures.to_csv(measures_path)

    return maxf1_threshold
def compare_annotators(
    baseline, other, name, output_folder, overlayed_folder=None
):
    """
    Compares annotations on the **same** dataset

    The annotations of ``other`` are treated as binary predictions and
    evaluated against those of ``baseline`` at a fixed threshold of 0.5.


    Parameters
    ----------

    baseline : py:class:`torch.utils.data.Dataset`
        a dataset to iterate on, containing the baseline annotations.  Each
        sample is indexed as ``(stem, image, ground_truth[, mask])``.

    other : py:class:`torch.utils.data.Dataset`
        a second dataset, with the same samples as ``baseline``, but annotated
        by a different annotator than in the first dataset. The key values
        (stems) must match between ``baseline`` and this dataset, in the same
        order.

    name : str
        the local name of this dataset (e.g. ``train-second-annotator``, or
        ``test-second-annotator``), to be used when saving measures files.

    output_folder : str
        folder where to store results (required)

    overlayed_folder : :py:class:`str`, Optional
        if not ``None``, then it should be the name of a folder where to store
        overlayed versions of the images and ground-truths


    Raises
    ------

    RuntimeError
        if the stems of ``baseline`` and ``other`` do not match sample by
        sample, or if a stem appears twice

    """

    logger.info(f"Output folder: {output_folder}")
    os.makedirs(output_folder, exist_ok=True)

    # Collect overall measures
    data = {}

    for baseline_sample, other_sample in tqdm(
        list(zip(baseline, other)), desc="samples", leave=False, disable=None
    ):
        # explicit check (not ``assert``), so it still runs under ``python -O``
        if baseline_sample[0] != other_sample[0]:
            raise RuntimeError(
                f"Mismatch between "
                f"datasets for second-annotator analysis "
                f"({baseline_sample[0]} != {other_sample[0]}). This "
                f"typically occurs when the second annotator (`other`) "
                f"comes from a different dataset than the `baseline` dataset"
            )

        stem = baseline_sample[0]
        image = baseline_sample[1]
        gt = baseline_sample[2]
        pred = other_sample[2]  # works as a prediction
        mask = None if len(baseline_sample) < 4 else baseline_sample[3]
        if stem in data:
            raise RuntimeError(
                f"{stem} entry already exists in data. Cannot overwrite."
            )
        # two steps -> thresholds 0.0 and 0.5; only 0.5 is kept below
        data[stem] = _sample_measures(pred, gt, mask, 2)

        fullpath = os.path.join(
            output_folder, "second-annotator", name, f"{stem}.csv"
        )
        tqdm.write(f"Saving {fullpath}...")
        os.makedirs(os.path.dirname(fullpath), exist_ok=True)
        data[stem].to_csv(fullpath)

        if overlayed_folder is not None:
            overlay_image = _sample_analysis(
                image, pred, gt, mask, threshold=0.5, overlay=True
            )
            fullpath = os.path.join(
                overlayed_folder, "second-annotator", name, f"{stem}.png"
            )
            tqdm.write(f"Saving {fullpath}...")
            os.makedirs(os.path.dirname(fullpath), exist_ok=True)
            overlay_image.save(fullpath)

    measures = _summarize(data)
    measures.drop(0, inplace=True)  # removes threshold == 0.0, keeps 0.5 only

    measures_path = os.path.join(
        output_folder, "second-annotator", f"{name}.csv"
    )
    os.makedirs(os.path.dirname(measures_path), exist_ok=True)
    logger.info(f"Saving summaries over all input images at {measures_path}...")
    measures.to_csv(measures_path)

    maxf1 = measures["mean_f1_score"].max()
    logger.info(f"F1-score of {maxf1:.5f} (second annotator; threshold=0.5)")