Coverage for src/bob/bio/base/transformers/extractor.py: 81%
26 statements
« prev ^ index » next coverage.py v7.6.5, created at 2024-11-14 21:41 +0100
« prev ^ index » next coverage.py v7.6.5, created at 2024-11-14 21:41 +0100
1#!/usr/bin/env python
2# vim: set fileencoding=utf-8 :
4from sklearn.base import BaseEstimator, TransformerMixin
6from bob.bio.base.extractor import Extractor
8from . import split_X_by_y
class ExtractorTransformer(TransformerMixin, BaseEstimator):
    """Expose a :py:class:`bob.bio.base.extractor.Extractor` as a scikit-learn transformer.

    Parameters
    ----------
    instance: object
        The wrapped extractor; it must be an instance of
        :py:class:`bob.bio.base.extractor.Extractor`.

    model_path: ``str``
        Path handed to ``instance.train`` as its second argument. It is
        mandatory (neither ``None`` nor ``""``) whenever
        ``instance.requires_training`` is true.

    **kwargs
        Forwarded untouched to the parent class constructor.

    Raises
    ------
    ValueError
        If ``instance`` is not an ``Extractor``, or if the extractor requires
        training and no ``model_path`` was given.
    """

    def __init__(
        self,
        instance,
        model_path=None,
        **kwargs,
    ):
        if not isinstance(instance, Extractor):
            raise ValueError(
                "`instance` should be an instance of `bob.bio.base.extractor.Extractor`"
            )

        # A trainable extractor is useless without a path for its model.
        if instance.requires_training:
            if model_path is None or model_path == "":
                raise ValueError(
                    f"`model_path` needs to be set if extractor {instance} requires training"
                )

        self.instance = instance
        self.model_path = model_path
        super().__init__(**kwargs)

    def fit(self, X, y=None):
        """Train the wrapped extractor when it requires training.

        Nothing happens for extractors that do not require training.
        Otherwise ``X`` is passed to ``instance.train`` together with
        ``model_path``; when ``instance.split_training_data_by_client`` is
        true, ``X`` is first regrouped with ``split_X_by_y(X, y)``.

        Returns
        -------
        self
        """
        extractor = self.instance
        if extractor.requires_training:
            if extractor.split_training_data_by_client:
                samples = split_X_by_y(X, y)
            else:
                samples = X
            extractor.train(samples, self.model_path)
        return self

    def transform(self, X, metadata=None):
        """Call the wrapped extractor on every element of ``X``.

        When ``metadata`` is given, it is iterated in lockstep with ``X`` and
        each entry is passed as the second argument of the extractor call.

        Returns
        -------
        list
            One extractor output per processed sample.
        """
        if metadata is not None:
            return [
                self.instance(sample, sample_metadata)
                for sample, sample_metadata in zip(X, metadata)
            ]
        return [self.instance(sample) for sample in X]

    def _more_tags(self):
        """Return the estimator tags derived from the wrapped extractor.

        ``requires_fit`` mirrors ``instance.requires_training``; the two
        ``bob_features_*`` entries hold the extractor's ``write_feature`` and
        ``read_feature`` methods.
        """
        extractor = self.instance
        tags = {"requires_fit": extractor.requires_training}
        tags["bob_features_save_fn"] = extractor.write_feature
        tags["bob_features_load_fn"] = extractor.read_feature
        return tags