Coverage for src/bob/bio/base/pipelines/abstract_classes.py: 89%
145 statements
« prev ^ index » next coverage.py v7.6.5, created at 2024-11-14 21:41 +0100
1#!/usr/bin/env python
2# vim: set fileencoding=utf-8 :
5import logging
6import os
8from abc import ABCMeta, abstractmethod
9from typing import Any, Callable, Optional, Union
11import numpy as np
13from sklearn.base import BaseEstimator
15from bob.pipelines import Sample, SampleBatch, SampleSet
16from bob.pipelines.wrappers import _frmt
18logger = logging.getLogger(__name__)
def reduce_scores(
    scores: np.ndarray,
    axis: int,
    fn: Union[str, Callable[[np.ndarray, int], np.ndarray]] = "max",
):
    """
    Reduce scores along one axis with a reduction function.

    Parameters:
    -----------
    scores
        Scores to reduce.

    axis
        Axis along which the reduction is applied.

    fn
        Function to use for reduction. You can also provide a string like
        ``max`` to use the corresponding function from numpy. Some possible
        values are: ``max``, ``min``, ``mean``, ``median``, ``sum``.

    Returns:
    --------
    Reduced scores.
    """
    # Resolve a string name to the matching numpy reduction (e.g. "max" -> np.max).
    reducer = getattr(np, fn) if isinstance(fn, str) else fn
    return reducer(scores, axis=axis)
48def _data_valid(data: Any) -> bool:
49 """Check if data is valid.
51 Parameters:
52 -----------
53 data
54 Data to check.
56 Returns:
57 --------
58 True if data is valid, False otherwise.
59 """
60 if data is None:
61 return False
62 if isinstance(data, np.ndarray):
63 return data.size > 0
64 # we also have to check for [[]]
65 if isinstance(data, list) and len(data) > 0:
66 if isinstance(data[0], (list, tuple)):
67 return len(data[0]) > 0
68 return bool(data)
class BioAlgorithm(BaseEstimator, metaclass=ABCMeta):
    """Describes a base biometric comparator for the PipelineSimple
    :ref:`bob.bio.base.biometric_algorithm`.

    A biometric algorithm converts each SampleSet (which is a list of
    samples/features) into a single template. Template creation is done for both
    enroll and probe samples but the format of the templates can be different
    between enrollment and probe samples. After the creation of the templates,
    the algorithm computes one similarity score for comparison of an enroll
    template with a probe template.

    Examples
    --------
    >>> import numpy as np
    >>> from bob.bio.base.pipelines import BioAlgorithm
    >>> class MyAlgorithm(BioAlgorithm):
    ...
    ...     def create_templates(self, list_of_feature_sets, enroll):
    ...         # you cannot call np.mean(list_of_feature_sets, axis=1) because the
    ...         # number of features in each feature set may vary.
    ...         return [np.mean(feature_set, axis=0) for feature_set in list_of_feature_sets]
    ...
    ...     def compare(self, enroll_templates, probe_templates):
    ...         scores = []
    ...         for enroll_template in enroll_templates:
    ...             scores.append([])
    ...             for probe_template in probe_templates:
    ...                 similarity = 1 / np.linalg.norm(enroll_template - probe_template)
    ...                 scores[-1].append(similarity)
    ...         scores = np.array(scores, dtype=float)
    ...         return scores
    """

    def __init__(
        self,
        probes_score_fusion: Union[
            str, Callable[[list[np.ndarray], int], np.ndarray]
        ] = "max",
        enrolls_score_fusion: Union[
            str, Callable[[list[np.ndarray], int], np.ndarray]
        ] = "max",
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        probes_score_fusion
            How to fuse the scores of several probe samples into one score.
            Either a numpy reduction name (``max``, ``mean``, ...) or a
            callable accepting ``(scores, axis)``.
        enrolls_score_fusion
            Same as ``probes_score_fusion``, for enroll samples.
        """
        super().__init__(**kwargs)
        self.probes_score_fusion = probes_score_fusion
        self.enrolls_score_fusion = enrolls_score_fusion

    def fuse_probe_scores(self, scores, axis):
        """Fuses scores along ``axis`` using ``self.probes_score_fusion``."""
        return reduce_scores(scores, axis, self.probes_score_fusion)

    def fuse_enroll_scores(self, scores, axis):
        """Fuses scores along ``axis`` using ``self.enrolls_score_fusion``."""
        return reduce_scores(scores, axis, self.enrolls_score_fusion)

    @abstractmethod
    def create_templates(
        self, list_of_feature_sets: list[Any], enroll: bool
    ) -> list[Sample]:
        """Creates enroll or probe templates from multiple sets of features.

        The enroll template format can be different from the probe templates.

        Parameters
        ----------
        list_of_feature_sets
            A list of list of features with the shape of Nx?xD. N templates
            should be computed. Note that you cannot call
            np.array(list_of_feature_sets) because the number of features per
            set can be different depending on the database.
        enroll
            If True, the features are for enrollment. If False, the features are
            for probe.

        Returns
        -------
        templates
            A list of templates which has the same length as
            ``list_of_feature_sets``.
        """
        pass

    @abstractmethod
    def compare(
        self, enroll_templates: list[Sample], probe_templates: list[Sample]
    ) -> np.ndarray:
        """Computes the similarity score between all enrollment and probe templates.

        Parameters
        ----------
        enroll_templates
            A list (length N) of enrollment templates.

        probe_templates
            A list (length M) of probe templates.

        Returns
        -------
        scores
            A matrix of shape (N, M) containing the similarity scores.
        """
        pass

    def create_templates_from_samplesets(
        self, list_of_samplesets: list[SampleSet], enroll: bool
    ) -> list[Sample]:
        """Creates enroll or probe templates from multiple SampleSets.

        Parameters
        ----------
        list_of_samplesets
            A list (length N) of SampleSets.

        enroll
            If True, the SampleSets are for enrollment. If False, the SampleSets
            are for probe.

        Returns
        -------
        templates
            A list of Samples which has the same length as ``list_of_samplesets``.
            Each Sample contains a template.
        """
        logger.debug(
            f"{_frmt(self)}.create_templates_from_samplesets(... enroll={enroll})"
        )
        # create templates from .data attribute of samples inside sample_sets
        list_of_feature_sets = []
        for sampleset in list_of_samplesets:
            data = [s.data for s in sampleset.samples]
            # drop samples whose data is None (e.g. failed feature extraction)
            valid_data = [d for d in data if d is not None]
            if len(data) != len(valid_data):
                # fix: the message previously always said "enrollment" even on
                # the probe path; report the actual kind of samples removed.
                logger.warning(
                    "Removed %d invalid %s samples.",
                    len(data) - len(valid_data),
                    "enrollment" if enroll else "probe",
                )
            if not valid_data and enroll:
                # we do not support failure to enroll cases currently
                raise NotImplementedError(
                    f"None of the enrollment samples were valid for {sampleset}."
                )
            list_of_feature_sets.append(valid_data)

        templates = self.create_templates(list_of_feature_sets, enroll)
        expected_size = len(list_of_samplesets)
        assert len(templates) == expected_size, (
            "The number of (%s) templates (%d) created by the algorithm does not match "
            "the number of sample sets (%d)"
            % (
                "enroll" if enroll else "probe",
                len(templates),
                expected_size,
            )
        )
        # return a list of Samples (one per template)
        templates = [
            Sample(t, parent=sampleset)
            for t, sampleset in zip(templates, list_of_samplesets)
        ]
        return templates

    def score_sample_templates(
        self,
        probe_samples: list[Sample],
        enroll_samples: list[Sample],
        score_all_vs_all: bool,
    ) -> list[SampleSet]:
        """Computes the similarity score between all probe and enroll templates.

        Parameters
        ----------
        probe_samples
            A list (length N) of Samples containing probe templates.

        enroll_samples
            A list (length M) of Samples containing enroll templates.

        score_all_vs_all
            If True, the similarity scores between all probe and enroll templates
            are computed. If False, the similarity scores between the probes and
            their associated enroll templates are computed.

        Returns
        -------
        score_samplesets
            A list of N SampleSets each containing a list of M score Samples if score_all_vs_all
            is True. Otherwise, a list of N SampleSets each containing a list of <=M score Samples
            depending on the database.
        """
        logger.debug(
            f"{_frmt(self)}.score_sample_templates(... score_all_vs_all={score_all_vs_all})"
        )
        # Returns a list of SampleSets where a Sampleset for each probe
        # SampleSet where each Sample inside the SampleSets contains the score
        # for one enroll SampleSet
        score_samplesets = []
        if score_all_vs_all:
            probe_data = [s.data for s in probe_samples]
            # only valid probes are passed to compare(); invalid ones get
            # None scores injected afterwards to keep the output shape.
            valid_probe_indices = [
                i for i, d in enumerate(probe_data) if _data_valid(d)
            ]
            valid_probe_data = [probe_data[i] for i in valid_probe_indices]
            scores = self.compare(SampleBatch(enroll_samples), valid_probe_data)
            scores = np.asarray(scores, dtype=float)

            if len(valid_probe_indices) != len(probe_data):
                # inject None scores for invalid probe samples
                scores: list = scores.T.tolist()
                for i in range(len(probe_data)):
                    if i not in valid_probe_indices:
                        scores.insert(i, [None] * len(enroll_samples))
                # transpose back to original shape
                scores = np.array(scores, dtype=float).T

            expected_shape = (len(enroll_samples), len(probe_samples))
            assert scores.shape == expected_shape, (
                "The shape of the similarity scores (%s) does not match the expected shape (%s)"
                % (scores.shape, expected_shape)
            )
            for j, probe in enumerate(probe_samples):
                samples = []
                for i, enroll in enumerate(enroll_samples):
                    samples.append(Sample(scores[i, j], parent=enroll))
                score_samplesets.append(SampleSet(samples, parent=probe))
        else:
            for probe in probe_samples:
                references = [str(ref) for ref in probe.references]
                # get the indices of references for enroll samplesets
                indices = [
                    i
                    for i, enroll in enumerate(enroll_samples)
                    if str(enroll.template_id) in references
                ]
                if not indices:
                    raise ValueError(
                        f"No enroll sampleset found for probe {probe} and its required references {references}. "
                        "Did you mean to set score_all_vs_all=True?"
                    )
                if not _data_valid(probe.data):
                    # invalid probe: emit None scores against each reference
                    scores = [[None]] * len(indices)
                else:
                    scores = self.compare(
                        SampleBatch([enroll_samples[i] for i in indices]),
                        SampleBatch([probe]),
                    )
                scores = np.asarray(scores, dtype=float)
                expected_shape = (len(indices), 1)
                assert scores.shape == expected_shape, (
                    "The shape of the similarity scores (%s) does not match the expected shape (%s)"
                    % (scores.shape, expected_shape)
                )
                samples = []
                for i, j in enumerate(indices):
                    samples.append(
                        Sample(scores[i, 0], parent=enroll_samples[j])
                    )
                score_samplesets.append(SampleSet(samples, parent=probe))

        return score_samplesets
class Database(metaclass=ABCMeta):
    """Base class for PipelineSimple databases"""

    def __init__(
        self,
        protocol: Optional[str] = None,
        score_all_vs_all: bool = False,
        annotation_type: Optional[str] = None,
        fixed_positions: Optional[str] = None,
        memory_demanding: bool = False,
        **kwargs,
    ):
        """
        Parameters
        ----------
        protocol
            Name of the database protocol to use.
        score_all_vs_all
            Whether to allow scoring of all the probes against all the references, or to
            provide a list ``references`` provided with each probes to indicate against
            which references it needs to be compared.
        annotation_type
            The type of annotation passed to the annotation loading function.
        fixed_positions
            The constant eyes positions passed to the annotation loading function.
            TODO why keep this face-related name here? Which one is it, too (position
            when annotations are missing, or ending position in the result image)?
            --> move this when the FaceCrop annotator is correctly implemented.
        memory_demanding
            Flag to indicate that this should not be loaded locally.
            TODO Where is it used?
        """
        super().__init__(**kwargs)
        # a subclass may have already set its own protocol; do not overwrite it
        if not hasattr(self, "protocol"):
            self.protocol = protocol
        self.score_all_vs_all = score_all_vs_all
        self.annotation_type = annotation_type
        self.fixed_positions = fixed_positions
        self.memory_demanding = memory_demanding

    def __str__(self) -> str:
        # show all public attributes as key=value pairs
        public_attrs = (
            f"{key}={value}"
            for key, value in self.__dict__.items()
            if not key.startswith("_")
        )
        return f"{self.__class__.__name__}({', '.join(public_attrs)})"

    @abstractmethod
    def background_model_samples(self) -> list[Sample]:
        """Returns :any:`Sample`\ s to train a background model

        Returns
        -------
        samples
            List of samples for background model training.
        """  # noqa: W605
        pass

    @abstractmethod
    def references(self, group: str = "dev") -> list[SampleSet]:
        """Returns references to enroll biometric references

        Parameters
        ----------
        group
            Limits samples to this group

        Returns
        -------
        references
            List of samples for the creation of biometric references.
        """
        pass

    @abstractmethod
    def probes(self, group: str = "dev") -> list[SampleSet]:
        """Returns probes to score against enrolled biometric references

        Parameters
        ----------
        group
            Limits samples to this group

        Returns
        -------
        probes
            List of samples for the creation of biometric probes.
        """
        pass

    @abstractmethod
    def all_samples(self, groups: Optional[str] = None) -> list[Sample]:
        """Returns all the samples of the dataset

        Parameters
        ----------
        groups
            List of groups to consider (like 'dev' or 'eval'). If `None`, will
            return samples from all the groups.

        Returns
        -------
        samples
            List of all the samples of the dataset.
        """
        pass

    @abstractmethod
    def groups(self) -> list[str]:
        """Returns all the possible groups for the current protocol."""
        pass

    @abstractmethod
    def protocols(self) -> list[str]:
        """Returns all the possible protocols of the database."""
        pass

    def template_ids(self, group: str) -> list[Any]:
        """Returns the ``template_id`` attribute of each reference."""
        return [reference.template_id for reference in self.references(group=group)]
class ScoreWriter(metaclass=ABCMeta):
    """
    Defines base methods to read, write scores and concatenate scores
    for :any:`bob.bio.base.pipelines.BioAlgorithm`
    """

    def __init__(self, path, extension=".txt", **kwargs):
        super().__init__(**kwargs)
        # base directory and file extension used by subclasses when writing
        self.path = path
        self.extension = extension

    @abstractmethod
    def write(self, sampleset, path):
        """Writes the scores of one sampleset to ``path``."""
        pass

    def post_process(self, score_paths, filename):
        """Concatenates the individual score files into a single file.

        Supports both a plain iterable of paths and a dask bag; in the
        latter case a delayed task is returned instead of the filename.
        """

        def _concatenate(paths, out_name):
            # one-line purpose: append every score file's lines into out_name
            os.makedirs(os.path.dirname(out_name), exist_ok=True)
            with open(out_name, "w") as out:
                for score_file in paths:
                    with open(score_file) as src:
                        out.writelines(src.readlines())
            return out_name

        # imported lazily so dask stays an optional dependency
        import dask
        import dask.bag

        if isinstance(score_paths, dask.bag.Bag):
            delayed_paths = dask.delayed(list)(score_paths)
            return dask.delayed(_concatenate)(delayed_paths, filename)
        return _concatenate(score_paths, filename)