Coverage for src/bob/bio/base/algorithm/jfa.py: 64%

22 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-12 22:34 +0200

1import logging 

2import pickle 

3 

4from bob.bio.base.pipelines import BioAlgorithm 

5from bob.learn.em import JFAMachine 

6 

# Module-level logger named after this module. It is not referenced by the
# JFA class defined in this file.
logger = logging.getLogger(__name__)

8 

9 

class JFA(JFAMachine, BioAlgorithm):
    """JFA transformer and bioalgorithm to be used in pipelines.

    Mixes :class:`bob.learn.em.JFAMachine` (which provides ``enroll`` and
    ``score``) with :class:`BioAlgorithm` (template creation / comparison),
    so one object can serve as both the transformer and the biometric
    algorithm of a pipeline.
    """

    def transform(self, X):
        """Passthrough: return ``X`` unchanged."""
        return X

    def create_templates(self, list_of_feature_sets, enroll):
        """Build enrollment or probe templates from feature sets.

        Parameters
        ----------
        list_of_feature_sets
            Iterable of feature sets, one per template to create.
        enroll : bool
            If true, each feature set is passed through ``self.enroll`` and
            the results are returned as a list. If false, the input is
            returned unchanged and used directly as probe templates.

        Returns
        -------
        list or the input object
            Enrolled models (``enroll=True``) or ``list_of_feature_sets``
            itself (``enroll=False``).
        """
        if enroll:
            return [
                self.enroll(feature_set)
                for feature_set in list_of_feature_sets
            ]
        # TODO: Precompute these parts of self.score here to make scoring
        # faster:
        #   x = self.estimate_x(data)
        #   Ux = self._U @ x
        return list_of_feature_sets

    def compare(self, enroll_templates, probe_templates):
        """Score every probe template against every enroll template.

        Parameters
        ----------
        enroll_templates
            Iterable of enrolled models.
        probe_templates
            Iterable of probe templates. It may be a one-shot iterator.

        Returns
        -------
        list of list
            ``scores[i][j] == self.score(enroll_i, probe_j)``. Rows follow
            the order of ``enroll_templates``, columns the order of
            ``probe_templates``.
        """
        # Materialize once. The probes are traversed again for every enroll
        # row, and a generator would otherwise be exhausted after the first
        # row, leaving every later row empty.
        probes = list(probe_templates)
        # TODO: The underlying score method actually supports batched scoring
        return [
            [self.score(enroll, probe) for probe in probes]
            for enroll in enroll_templates
        ]

    @classmethod
    def custom_enrolled_save_fn(cls, data, path):
        """Pickle ``data`` into the file at ``path``, overwriting it.

        Parameters
        ----------
        data
            Any picklable object (here, an enrolled template).
        path : str or path-like
            Destination file.
        """
        # The context manager flushes and closes the handle even if
        # pickling raises, so no file descriptor is leaked.
        with open(path, "wb") as f:
            pickle.dump(data, f)

    @classmethod
    def custom_enrolled_load_fn(cls, path):
        """Unpickle and return the object stored at ``path``.

        Security: unpickling can execute arbitrary code. Only load files
        produced by :meth:`custom_enrolled_save_fn` or another trusted
        source.
        """
        with open(path, "rb") as f:
            return pickle.load(f)

    def _more_tags(self):
        """Return the estimator tags this algorithm advertises.

        NOTE(review): these keys are presumably read by the bob pipeline
        wrappers through the scikit-learn-style tag mechanism. The
        per-tag meanings below are inferred from the names; confirm them
        against bob.bio.base.
        """
        return {
            # presumably: ``fit`` may receive a dask bag of samples
            "bob_fit_supports_dask_bag": True,
            # presumably: ``fit``'s ``y`` argument comes from each sample's
            # ``subject_id_int`` attribute
            "bob_fit_extra_input": [("y", "subject_id_int")],
            # functions used to persist / restore enrolled templates
            "bob_enrolled_save_fn": self.custom_enrolled_save_fn,
            "bob_enrolled_load_fn": self.custom_enrolled_load_fn,
            # presumably: do not checkpoint this step's ``transform`` output
            # (it is a passthrough)
            "bob_checkpoint_features": False,
        }