Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1637571489937/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.8/site-packages/bob/med/tb/scripts/aggregpred.py: 88%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

34 statements  

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4import os 

5import click 

6 

7from bob.extension.scripts.click_helper import ( 

8 verbosity_option, 

9 AliasedGroup, 

10) 

11 

12import shutil 

13import torch 

14import re 

15import pandas 

16 

17import logging 

18logger = logging.getLogger(__name__) 

19 

def _parse_array_column(values):
    """Parse one CSV column of array-like cells into a flat float64 tensor.

    Each cell may hold a plain number or a (possibly nested) list of numbers.
    Lists may be written NumPy-style (``"[0.1 0.2]"``: whitespace separated,
    possibly padded inside the brackets or wrapped over several lines) or
    Python-style (``"[0.1, 0.2]"``). Cells that pandas already parsed as
    numbers are accepted as well.

    Parameters
    ----------
    values : iterable
        The cell values of one column, in row order.

    Returns
    -------
    torch.Tensor
        1-D ``float64`` tensor holding every parsed value, in row order.

    Raises
    ------
    ValueError, SyntaxError
        If a cell is not a numeric literal once separators are normalized.
    """
    import ast  # local import keeps the module-level import block untouched

    parsed = []
    for raw in values:
        # collapse every whitespace run (line breaks included) to one space
        text = re.sub(r"\s+", " ", str(raw)).strip()
        # drop padding just inside brackets, e.g. "[ 0.1 0.2 ]"
        text = re.sub(r"\[\s+", "[", text)
        text = re.sub(r"\s+\]", "]", text)
        # any remaining separator (spaces and/or a comma) becomes one comma
        text = re.sub(r"\s*,\s*|\s+", ",", text)
        # literal_eval only accepts literals; eval() would execute whatever
        # expression happens to be stored in the CSV file
        parsed.append(ast.literal_eval(text))
    # build float64 directly: torch.Tensor(...) would round through float32
    return torch.tensor(parsed, dtype=torch.float64).flatten()


@click.command(
    epilog="""Examples:

\b
    1. Aggregate multiple predictions csv files into one
\b
       $ bob tb aggregpred -vv path/to/train/predictions.csv path/to/test/predictions.csv
""",
)
@click.argument(
    "label_path",
    nargs=-1,
)
@click.option(
    "--output-folder",
    "-f",
    help="Path where to store the aggregated csv file (created if necessary)",
    required=False,
    default=None,
    type=click.Path(dir_okay=True, file_okay=False),
)
@verbosity_option()
def aggregpred(label_path, output_folder, **kwargs):
    """Aggregate multiple predictions csv files into one

    \f
    Each input CSV must contain ``likelihood`` and ``ground_truth`` columns
    whose cells are numbers or array-like strings. Both columns are parsed
    to float64 values, and all files are concatenated in the order given.
    The result is written to ``<output_folder>/aggregpred.csv``. A
    previously existing output file is kept as ``aggregpred.csv~``.

    Parameters
    ----------
    label_path : tuple of str
        Paths of the predictions CSV files to aggregate (at least one).
    output_folder : str or None
        Destination folder, created if necessary. ``None`` selects the
        current working directory.

    Raises
    ------
    click.UsageError
        If no input file is given.
    """
    # the "\f" above stops click from showing the details in --help

    if not label_path:
        # pandas.concat([]) would otherwise fail with an obscure ValueError
        raise click.UsageError("At least one predictions CSV file is required.")

    if output_folder is None:
        # os.makedirs(None) raises TypeError; the documented example above
        # omits --output-folder, so fall back to the working directory
        output_folder = os.curdir

    frames = []
    for predictions_path in label_path:
        logger.info(f"Loading predictions from {predictions_path}...")
        pred_data = pandas.read_csv(predictions_path)
        pred_data["likelihood"] = _parse_array_column(
            pred_data["likelihood"].values
        ).numpy()
        pred_data["ground_truth"] = _parse_array_column(
            pred_data["ground_truth"].values
        ).numpy()
        frames.append(pred_data)

    df = pandas.concat(frames)

    logger.info(f"Output folder: {output_folder}")
    os.makedirs(output_folder, exist_ok=True)

    output_file = os.path.join(output_folder, "aggregpred.csv")
    if os.path.exists(output_file):
        # keep exactly one previous version around as "<file>~"
        backup = output_file + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(output_file, backup)

    logger.info("Saving aggregated CSV file...")
    df.to_csv(output_file, index=False, header=True)