Coverage for src/bob/bio/base/script/compare_samples.py: 89%

35 statements  

« prev     ^ index     » next       coverage.py v7.6.5, created at 2024-11-14 21:41 +0100

1#!/usr/bin/env python 

2# vim: set fileencoding=utf-8 : 

3# Tiago de Freitas Pereira <tiago.pereira@idiap.ch> 

4 

5 

"""Compare biometric samples against each other in an all-vs-all fashion."""

7 

import functools
import logging

from typing import List, Optional, Sequence

import click
import dask.distributed

from clapper.click import ResourceOption, verbosity_option
from tabulate import tabulate

import bob.io.base

from bob.bio.base.pipelines import PipelineSimple, dask_bio_pipeline
from bob.pipelines import DelayedSample, Sample, SampleSet

23 

# Module-level logger. It is also passed to ``verbosity_option`` on the
# command below, presumably so the command's verbosity flag can adjust its
# level (confirm against clapper's ``verbosity_option``).
logger = logging.getLogger(__name__)

# Text appended to the command's help output (passed as ``epilog=`` to
# ``click.command``). The literal ``\n`` escapes add an extra newline after
# each line; this is presumably to force line breaks, since click rewraps
# epilog paragraphs. NOTE(review): the exact whitespace (table alignment,
# leading spaces) should be checked against the original source file.
EPILOG = """\n



 Command line examples\n
 -----------------------

 >>> bob bio compare-samples ./imgs/1.png ./imgs/2.png -p inception_resnetv2_msceleb \n
 \n
 \n

 All vs All comparison \n
 ------------------- ------------------- \n
 ./imgs/1.png ./imgs/2.png \n
 0.0 -0.5430189337666903 \n
 -0.5430189337666903 0.0 \n
 ------------------- ------------------- \n

"""

45 

46 

@click.command(epilog=EPILOG)
@click.argument("samples", nargs=-1)
@click.option(
    "--pipeline",
    "-p",
    required=True,
    cls=ResourceOption,
    entry_point_group="bob.bio.pipeline",
    help="Vanilla biometrics pipeline composed of a scikit-learn Pipeline and a BioAlgorithm",
)
@click.option(
    "--dask-client",
    "-l",
    required=False,
    cls=ResourceOption,
    entry_point_group="dask.client",
    help="Dask client for the execution of the pipeline.",
)
@verbosity_option(logger=logger)
def compare_samples(
    samples: Sequence[str],
    pipeline: PipelineSimple,
    dask_client: Optional[dask.distributed.Client],
    verbose: int,
):
    """Compare several samples in a All vs All fashion.

    \f
    Each path in ``samples`` is loaded lazily with ``bob.io.base.load`` and
    wrapped in its own ``SampleSet`` (its ``biometric_id`` is the sample's
    position on the command line). Every sample is used both to enroll a
    template and as a probe, and the resulting scores are printed as a table
    whose first row holds the sample paths, followed by one row per element
    of the returned scores.

    Parameters
    ----------
    samples
        Paths of the samples to compare (click passes them as a tuple of
        strings).
    pipeline
        Biometric pipeline used to build templates and compute scores.
    dask_client
        Optional dask client. When given, the pipeline is wrapped with
        ``dask_bio_pipeline``, the scores are computed on this client, and
        the client is shut down before returning (also on error).
    verbose
        Verbosity level, consumed by ``verbosity_option``.

    Raises
    ------
    ValueError
        If fewer than two samples are given.
    """
    try:
        # With zero or one sample there is nothing to compare.
        if len(samples) < 2:
            raise ValueError(
                "It's necessary to have at least two samples for the comparison"
            )

        sample_sets = [
            SampleSet(
                [DelayedSample(functools.partial(bob.io.base.load, s), key=str(s))],
                key=str(s),
                biometric_id=str(i),
            )
            for i, s in enumerate(samples)
        ]
        if dask_client is not None:
            pipeline = dask_bio_pipeline(pipeline)

        # First row: the sample paths, acting as column labels.
        table = [list(samples)]

        # All vs all: the same sample sets are used for enrollment and probing.
        enroll_templates = pipeline.enroll_templates(sample_sets)
        probe_templates = pipeline.probe_templates(sample_sets)
        scores = pipeline.compute_scores(
            probe_templates, enroll_templates, score_all_vs_all=True
        )
        if dask_client is not None:
            # The dask-wrapped pipeline returns a lazy result; materialize it.
            scores = scores.compute(scheduler=dask_client)
        for sset in scores:
            table.append([str(s.data) for s in sset])

        print("All vs All comparison")
        print(tabulate(table))
    finally:
        # Release the dask client on every exit path, including errors raised
        # while loading data or computing scores.
        if dask_client is not None:
            dask_client.shutdown()