Coverage for src/bob/bio/base/algorithm/jfa.py: 64%
22 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-12 22:34 +0200
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-12 22:34 +0200
1import logging
2import pickle
4from bob.bio.base.pipelines import BioAlgorithm
5from bob.learn.em import JFAMachine
# Logger for this module, named after the module's import path.
logger = logging.getLogger(name=__name__)
class JFA(JFAMachine, BioAlgorithm):
    """JFA transformer and bioalgorithm to be used in pipelines.

    Enrollment and scoring are delegated to the inherited ``JFAMachine``
    methods (``enroll`` and ``score``). This class adapts them to the
    ``BioAlgorithm`` template/compare interface.
    """

    def transform(self, X):
        """Passthrough: return ``X`` unchanged."""
        return X

    def create_templates(self, list_of_feature_sets, enroll):
        """Build enrollment or probe templates.

        Parameters
        ----------
        list_of_feature_sets
            One feature set per template to create.
        enroll : bool
            When true, each feature set is turned into a model via
            ``self.enroll``. When false, the feature sets are returned
            unchanged and used directly as probe templates.

        Returns
        -------
        list or the input object
            A list of enrolled models when ``enroll`` is true. Otherwise
            ``list_of_feature_sets`` itself, returned as-is.
        """
        if enroll:
            return [
                self.enroll(feature_set) for feature_set in list_of_feature_sets
            ]
        # TODO: We should compute these parts of self.score:
        # x = self.estimate_x(data)
        # Ux = self._U @ x
        # here to make scoring faster
        return list_of_feature_sets

    def compare(self, enroll_templates, probe_templates):
        """Score every probe template against every enrollment template.

        Returns
        -------
        list of list
            ``scores[i][j]`` is ``self.score(enroll_templates[i],
            probe_templates[j])``: one row per enrollment template, one
            column per probe template.
        """
        # TODO: The underlying score method actually supports batched scoring
        return [
            [self.score(enroll, probe) for probe in probe_templates]
            for enroll in enroll_templates
        ]

    @classmethod
    def custom_enrolled_save_fn(cls, data, path):
        """Pickle the enrolled ``data`` to the file at ``path``."""
        # The context manager closes (and flushes) the file deterministically
        # instead of leaking the handle until garbage collection.
        with open(path, "wb") as f:
            pickle.dump(data, f)

    @classmethod
    def custom_enrolled_load_fn(cls, path):
        """Load and return the pickled enrolled data stored at ``path``.

        Security: unpickling can execute arbitrary code. Only load files
        written by ``custom_enrolled_save_fn`` from a trusted location.
        """
        with open(path, "rb") as f:
            return pickle.load(f)

    def _more_tags(self):
        """Return the estimator tags consumed by the bob pipelines.

        NOTE(review): the meanings below are inferred from the tag names and
        should be confirmed against the bob.pipelines documentation.
        """
        return {
            # presumably: ``fit`` can be given a dask bag of samples
            "bob_fit_supports_dask_bag": True,
            # presumably: ``fit`` receives ``y`` from each sample's
            # ``subject_id_int`` attribute
            "bob_fit_extra_input": [("y", "subject_id_int")],
            # enrolled templates are persisted with the pickle helpers above
            "bob_enrolled_save_fn": self.custom_enrolled_save_fn,
            "bob_enrolled_load_fn": self.custom_enrolled_load_fn,
            # ``transform`` is a passthrough, so presumably there is nothing
            # worth checkpointing
            "bob_checkpoint_features": False,
        }