Coverage for src/bob/bio/base/script/compare_samples.py: 89%
35 statements
« prev ^ index » next coverage.py v7.6.5, created at 2024-11-14 21:41 +0100
« prev ^ index » next coverage.py v7.6.5, created at 2024-11-14 21:41 +0100
1#!/usr/bin/env python
2# vim: set fileencoding=utf-8 :
3# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
6"""Executes biometric pipeline"""
8import functools
9import logging
11from typing import List
13import click
14import dask.distributed
16from clapper.click import ResourceOption, verbosity_option
17from tabulate import tabulate
19import bob.io.base
21from bob.bio.base.pipelines import PipelineSimple, dask_bio_pipeline
22from bob.pipelines import DelayedSample, Sample, SampleSet
# Module-level logger; it is handed to ``verbosity_option(logger=logger)`` on
# the ``compare_samples`` command.
24logger = logging.getLogger(__name__)
# Epilog text passed to ``click.command(epilog=EPILOG)``. It contains one
# example invocation of ``bob bio compare-samples`` followed by an example of
# the "All vs All comparison" table the command prints.
26EPILOG = """\n
30 Command line examples\n
31 -----------------------
33 >>> bob bio compare-samples ./imgs/1.png ./imgs/2.png -p inception_resnetv2_msceleb \n
34 \n
35 \n
37 All vs All comparison \n
38 ------------------- ------------------- \n
39 ./imgs/1.png ./imgs/2.png \n
40 0.0 -0.5430189337666903 \n
41 -0.5430189337666903 0.0 \n
42 ------------------- ------------------- \n
44"""
@click.command(epilog=EPILOG)
@click.argument("samples", nargs=-1)
@click.option(
    "--pipeline",
    "-p",
    required=True,
    cls=ResourceOption,
    entry_point_group="bob.bio.pipeline",
    help="Vanilla biometrics pipeline composed of a scikit-learn Pipeline and a BioAlgorithm",
)
@click.option(
    "--dask-client",
    "-l",
    required=False,
    cls=ResourceOption,
    entry_point_group="dask.client",
    help="Dask client for the execution of the pipeline.",
)
@verbosity_option(logger=logger)
def compare_samples(
    samples: List[Sample],
    pipeline: PipelineSimple,
    dask_client: dask.distributed.Client,
    verbose: int,
):
    """Compare several samples in a All vs All fashion."""
    # NOTE: the docstring above is kept short on purpose; click uses it as the
    # command's help text, so the detailed notes live in comments instead.
    #
    # Parameters
    #   samples:     the variadic ``samples`` command-line arguments
    #                (``nargs=-1``). Each value is handed to
    #                ``bob.io.base.load`` and also used, via ``str()``, as the
    #                sample key and as a column header of the printed table.
    #                NOTE(review): the ``List[Sample]`` annotation does not
    #                match that usage -- presumably these are file paths.
    #   pipeline:    object providing ``enroll_templates``, ``probe_templates``
    #                and ``compute_scores``.
    #   dask_client: optional; when given, the pipeline is wrapped with
    #                ``dask_bio_pipeline`` and the scores are computed with
    #                this client as scheduler. It is shut down before
    #                returning.
    #   verbose:     supplied by ``verbosity_option``; not read in this body.
    #
    # Raises
    #   ValueError: when fewer than two samples are given.
    try:
        # ``nargs=-1`` allows zero arguments, so reject both the empty and the
        # single-sample case: a comparison needs at least two samples.
        if len(samples) < 2:
            raise ValueError(
                "It's necessary to have at least two samples for the comparison"
            )

        # One single-sample SampleSet per input. Loading is deferred through
        # DelayedSample, so the data is only read when the pipeline needs it.
        sample_sets = [
            SampleSet(
                [
                    DelayedSample(
                        functools.partial(bob.io.base.load, s), key=str(s)
                    )
                ],
                key=str(s),
                biometric_id=str(i),
            )
            for i, s in enumerate(samples)
        ]

        if dask_client is not None:
            pipeline = dask_bio_pipeline(pipeline)

        # The first table row is the header: the samples as given on the
        # command line.
        table = [list(samples)]

        # The same sample sets are used for enrollment and for probing.
        enroll_templates = pipeline.enroll_templates(sample_sets)
        probe_templates = pipeline.probe_templates(sample_sets)
        scores = pipeline.compute_scores(
            probe_templates, enroll_templates, score_all_vs_all=True
        )
        if dask_client is not None:
            scores = scores.compute(scheduler=dask_client)

        # One table row per score set returned by the pipeline.
        table.extend([str(s.data) for s in sset] for sset in scores)

        print("All vs All comparison")
        print(tabulate(table))
    finally:
        # Release the Dask client on every exit path, including failures.
        if dask_client is not None:
            dask_client.shutdown()