#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
from sklearn.base import BaseEstimator, TransformerMixin
from bob.bio.base.extractor import Extractor
from . import split_X_by_y
class ExtractorTransformer(TransformerMixin, BaseEstimator):
    """Scikit-learn wrapper around a :py:class:`bob.bio.base.extractor.Extractor`.

    The extractor is validated when the wrapper is constructed and then kept,
    together with the optional model path, as plain attributes.

    Parameters
    ----------

    instance: object
        The :py:class:`bob.bio.base.extractor.Extractor` being wrapped.

    model_path: ``str``
        Path of the extractor model. It must be neither ``None`` nor an empty
        string whenever ``instance.requires_training`` is true.

    kwargs
        Any additional keyword arguments; they are forwarded unchanged to the
        parent class constructor.

    Raises
    ------

    ValueError
        If ``instance`` is not an ``Extractor``, or if the extractor requires
        training and no usable ``model_path`` was supplied.
    """

    def __init__(self, instance, model_path=None, **kwargs):
        # Reject anything that is not a genuine bob extractor up front.
        if not isinstance(instance, Extractor):
            message = (
                "`instance` should be an instance of `bob.bio.base.extractor.Extractor`"
            )
            raise ValueError(message)

        # The model path is only inspected for extractors that need training.
        if instance.requires_training:
            if model_path is None or model_path == "":
                raise ValueError(
                    f"`model_path` needs to be set if extractor {instance} requires training"
                )

        self.model_path = model_path
        self.instance = instance
        super().__init__(**kwargs)

    def _more_tags(self):
        """Return the estimator tags derived from the wrapped extractor.

        Returns
        -------

        dict
            ``requires_fit`` holds the extractor's ``requires_training`` flag;
            ``bob_features_save_fn`` and ``bob_features_load_fn`` hold the
            extractor's ``write_feature`` and ``read_feature`` callables.
        """
        extractor = self.instance
        tags = {"requires_fit": extractor.requires_training}
        tags["bob_features_save_fn"] = extractor.write_feature
        tags["bob_features_load_fn"] = extractor.read_feature
        return tags