Coverage for src/bob/bio/spear/extractor/speechbrain_embeddings.py: 39%
28 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-06 22:04 +0100
1import numpy as np
2import torch
4from sklearn.base import BaseEstimator
class SpeechbrainEmbeddings(BaseEstimator):
    """Extract speaker embeddings with a pretrained SpeechBrain ``EncoderClassifier``.

    The constructor loads the model once so its files are fetched before any
    distributed (e.g. dask) execution starts, then immediately drops the
    in-memory model. The model is re-created lazily by :meth:`load_model`
    (called from :meth:`transform`), and :meth:`__getstate__` strips it from
    pickles, so the estimator stays small when shipped to workers.

    Parameters
    ----------
    source : str
        Model identifier handed to ``EncoderClassifier.from_hparams``.
    savedir : str
        Local directory handed to ``EncoderClassifier.from_hparams`` as
        ``savedir``.
    **kwargs
        Forwarded unchanged to the parent constructor.
    """

    def __init__(
        self,
        *,
        source="speechbrain/spkrec-ecapa-voxceleb",
        savedir="pretrained_models/spkrec-ecapa-voxceleb",
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # Stored under the same names as the constructor arguments so that
        # scikit-learn's get_params()/clone() can round-trip them.
        self.source = source
        self.savedir = savedir

        # set model to None so the load_model call below actually loads it
        self.model = None
        # ensure the files are downloaded before dask execution
        self.load_model()
        # only load models when they are used. (Prevents model transfer over the network)
        self.model = None

    def load_model(self):
        """Instantiate the SpeechBrain encoder unless it is already loaded.

        Does nothing when ``self.model`` is already set.
        """
        if self.model is not None:
            return
        # Imported here rather than at module level, so speechbrain is only
        # needed once a model is actually loaded.
        try:
            from speechbrain.inference.speaker import EncoderClassifier
        except ImportError:  # speechbrain < 1.0 exposes it under `pretrained`
            from speechbrain.pretrained import EncoderClassifier

        self.model = EncoderClassifier.from_hparams(
            source=self.source,
            savedir=self.savedir,
        )

    def fit(self, X, y=None):
        """No-op: the encoder is pretrained. Returns ``self``."""
        return self

    def transform_one(self, audio_track):
        """Compute the normalized embedding of a single audio track.

        Parameters
        ----------
        audio_track : array-like
            Waveform samples. NOTE(review): assumed to be a mono signal at the
            sampling rate the pretrained model expects — confirm against the
            data loader.

        Returns
        -------
        numpy.ndarray
            Output of ``encode_batch(..., normalize=True)`` moved to CPU.
            NOTE(review): for a 1-D waveform this is presumably shaped
            (1, 1, embedding_dim) — confirm.
        """
        # Make the method safe to call on its own (fresh or unpickled instance).
        self.load_model()
        # torch.from_numpy needs a real ndarray with non-negative strides; cast to
        # float32 (torch's default dtype) so float64 waveforms do not clash with
        # the model weights.
        waveform = torch.from_numpy(
            np.ascontiguousarray(audio_track, dtype=np.float32)
        )
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            embedding = self.model.encode_batch(waveform, normalize=True)
        # .cpu() lets the conversion work even when the model runs on a GPU.
        return embedding.cpu().numpy()

    def transform(self, audio_tracks, y=None):
        """Embed every track in ``audio_tracks`` and stack the results.

        Parameters
        ----------
        audio_tracks : iterable of array-like
            The waveforms to embed.
        y : ignored
            Present for scikit-learn API compatibility.

        Returns
        -------
        numpy.ndarray
            ``np.vstack`` of the per-track outputs of :meth:`transform_one`.

        Raises
        ------
        ValueError
            If ``audio_tracks`` is empty (``np.vstack`` needs at least one array).
        """
        # actual load of the model (on the workers)
        self.load_model()
        embeddings = [
            self.transform_one(audio_track) for audio_track in audio_tracks
        ]

        return np.vstack(embeddings)

    def __getstate__(self):
        """Return picklable state with the (unpicklable) model removed.

        The model is re-created by :meth:`load_model` on first use after
        unpickling.
        """
        d = self.__dict__.copy()
        d["model"] = None
        return d

    def _more_tags(self):
        """scikit-learn estimator tags: this transformer needs no fitting."""
        return {"requires_fit": False}