Coverage for src/bob/bio/base/transformers/extractor.py: 81%

26 statements  

« prev     ^ index     » next       coverage.py v7.6.5, created at 2024-11-14 21:41 +0100

1#!/usr/bin/env python 

2# vim: set fileencoding=utf-8 : 

3 

4from sklearn.base import BaseEstimator, TransformerMixin 

5 

6from bob.bio.base.extractor import Extractor 

7 

8from . import split_X_by_y 

9 

10 

class ExtractorTransformer(TransformerMixin, BaseEstimator):
    """Scikit-learn transformer for :py:class:`bob.bio.base.extractor.Extractor`.

    Parameters
    ----------
    instance : object
        An instance of :py:class:`bob.bio.base.extractor.Extractor`.

    model_path : ``str``
        Model path, required when ``instance.requires_training`` is ``True``.
        It is passed to ``instance.train`` during :py:meth:`fit`.

    **kwargs
        Forwarded unchanged to ``super().__init__``.

    Raises
    ------
    ValueError
        If ``instance`` is not an ``Extractor``, or if the extractor requires
        training and ``model_path`` is ``None`` or an empty string.
    """

    def __init__(
        self,
        instance,
        model_path=None,
        **kwargs,
    ):
        if not isinstance(instance, Extractor):
            raise ValueError(
                "`instance` should be an instance of `bob.bio.base.extractor.Extractor`"
            )

        # A trainable extractor needs a path to hand to ``instance.train``.
        if instance.requires_training and (
            model_path is None or model_path == ""
        ):
            raise ValueError(
                f"`model_path` needs to be set if extractor {instance} requires training"
            )

        self.instance = instance
        self.model_path = model_path
        super().__init__(**kwargs)

    def fit(self, X, y=None):
        """Train the wrapped extractor, if it requires training.

        Parameters
        ----------
        X
            Training samples.
        y : optional
            Labels aligned with ``X``. They are only used when
            ``instance.split_training_data_by_client`` is ``True``, in which
            case ``X`` is grouped with ``split_X_by_y(X, y)`` before training.

        Returns
        -------
        self
        """
        if not self.instance.requires_training:
            return self

        training_data = X
        if self.instance.split_training_data_by_client:
            # NOTE(review): presumably ``y`` must not be None on this path;
            # confirm how ``split_X_by_y`` reacts to a missing ``y``.
            training_data = split_X_by_y(X, y)

        self.instance.train(training_data, self.model_path)
        return self

    def transform(self, X, metadata=None):
        """Run the wrapped extractor on every sample in ``X``.

        Parameters
        ----------
        X
            Samples to process. Each one is passed to ``instance``.
        metadata : optional
            Per-sample metadata aligned with ``X``. When given, each sample is
            processed as ``instance(sample, sample_metadata)``; otherwise as
            ``instance(sample)``.

        Returns
        -------
        list
            The extractor's output for each processed sample, in input order.
        """
        if metadata is None:
            return [self.instance(data) for data in X]

        # Use a distinct loop name so the ``metadata`` parameter is not
        # shadowed. ``zip`` stops at the shorter of ``X`` and ``metadata``.
        return [
            self.instance(data, sample_metadata)
            for data, sample_metadata in zip(X, metadata)
        ]

    def _more_tags(self):
        """Return extra estimator tags for this transformer.

        ``requires_fit`` mirrors ``instance.requires_training``. The
        ``bob_features_save_fn`` and ``bob_features_load_fn`` entries expose
        the extractor's ``write_feature`` and ``read_feature`` methods.
        """
        return {
            "requires_fit": self.instance.requires_training,
            "bob_features_save_fn": self.instance.write_feature,
            "bob_features_load_fn": self.instance.read_feature,
        }