Coverage for src/bob/bio/spear/extractor/speechbrain_embeddings.py: 39%

28 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-06 22:04 +0100

1import numpy as np 

2import torch 

3 

4from sklearn.base import BaseEstimator 

5 

6 

class SpeechbrainEmbeddings(BaseEstimator):
    """Extract speaker embeddings with a pretrained SpeechBrain ``EncoderClassifier``.

    The pretrained model is fetched once in ``__init__`` so that its files are
    on disk before any dask execution, and is then dropped again. It is loaded
    lazily (via :meth:`load_model`) by the process that actually computes the
    embeddings, and it is excluded from the pickled state (see
    :meth:`__getstate__`) so the model itself is not transferred over the
    network.

    Parameters
    ----------
    source : str
        Passed as ``source`` to ``EncoderClassifier.from_hparams``.
        Defaults to ``"speechbrain/spkrec-ecapa-voxceleb"``.
    savedir : str
        Passed as ``savedir`` to ``EncoderClassifier.from_hparams``.
        Defaults to ``"pretrained_models/spkrec-ecapa-voxceleb"``.
    **kwargs
        Forwarded unchanged to the parent class constructor (kept for
        backward compatibility with the original signature).
    """

    def __init__(
        self,
        source="speechbrain/spkrec-ecapa-voxceleb",
        savedir="pretrained_models/spkrec-ecapa-voxceleb",
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        # Store constructor arguments under their own names so that sklearn's
        # get_params()/clone() can rebuild an equivalent estimator.
        self.source = source
        self.savedir = savedir

        # Start without a model so that load_model() does not return early.
        self.model = None
        # Ensure the files are downloaded before dask execution.
        self.load_model()
        # Only load models when they are used (prevents model transfer over
        # the network).
        self.model = None

    def load_model(self):
        """Load the pretrained encoder into ``self.model`` unless already loaded.

        Does nothing when ``self.model`` is not ``None``, so it is safe to call
        repeatedly.
        """
        if self.model is not None:
            return
        # Deferred import: speechbrain is only imported when a model is
        # actually loaded (presumably to keep module import light).
        from speechbrain.pretrained import EncoderClassifier

        self.model = EncoderClassifier.from_hparams(
            source=self.source,
            savedir=self.savedir,
        )

    def fit(self, X, y=None):
        """No-op: the encoder is pretrained. Returns ``self``."""
        return self

    def transform_one(self, audio_track):
        """Compute the embedding of a single audio track.

        Parameters
        ----------
        audio_track : numpy.ndarray
            Waveform samples, converted with ``torch.from_numpy`` (dtype is
            preserved). NOTE(review): presumably a 1-D float32 waveform at the
            sample rate the pretrained model expects — confirm with callers.

        Returns
        -------
        numpy.ndarray
            The output of ``encode_batch(..., normalize=True)`` as a NumPy
            array. NOTE(review): its exact shape is whatever SpeechBrain's
            ``encode_batch`` returns for one track — verify before relying on it.
        """
        # Make this method usable on its own, e.g. after unpickling, where
        # __getstate__ has replaced the model with None. No-op if loaded.
        self.load_model()
        embedding = self.model.encode_batch(
            torch.from_numpy(audio_track),
            normalize=True,
        )
        # detach()/cpu() are no-ops for CPU tensors without grad, but make the
        # NumPy conversion safe if the model runs elsewhere or tracks grads.
        return embedding.detach().cpu().numpy()

    def transform(self, audio_tracks, y=None):
        """Compute embeddings for every track in ``audio_tracks``.

        The model is loaded here, i.e. on the worker doing the computation.

        Parameters
        ----------
        audio_tracks : iterable of numpy.ndarray
            Tracks accepted by :meth:`transform_one`.
        y : ignored
            Present for sklearn API compatibility.

        Returns
        -------
        numpy.ndarray
            The per-track embeddings stacked with ``np.vstack``.

        Raises
        ------
        ValueError
            If ``audio_tracks`` is empty (``np.vstack`` needs at least one
            array).
        """
        # Actual load of the model (on the workers).
        self.load_model()
        embeddings = [self.transform_one(track) for track in audio_tracks]
        return np.vstack(embeddings)

    def __getstate__(self):
        """Return a picklable copy of the state with the model removed.

        The loaded model is replaced by ``None``; it is reloaded on demand by
        :meth:`load_model`. A copy of ``__dict__`` is used so the live
        instance keeps its model.
        """
        state = self.__dict__.copy()
        state["model"] = None
        return state

    def _more_tags(self):
        # Tell sklearn this estimator can be used without calling fit().
        return {"requires_fit": False}