Coverage for src/bob/bio/spear/annotator/energy_2gauss.py: 95%

58 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-06 22:04 +0100

1#!/usr/bin/env python 

2# vim: set fileencoding=utf-8 : 

3# @author: Elie Khoury <Elie.Khoury@idiap.ch> 

4# @date: Sun 7 Jun 15:41:03 CEST 2015 

5# 

6# Copyright (C) 2012-2015 Idiap Research Institute, Martigny, Switzerland 

7# 

8# This program is free software: you can redistribute it and/or modify 

9# it under the terms of the GNU General Public License as published by 

10# the Free Software Foundation, version 3 of the License. 

11# 

12# This program is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with this program. If not, see <http://www.gnu.org/licenses/>. 

19 

20"""Energy-based voice activity detection for speaker recognition""" 

21 

22import logging 

23 

24import dask 

25import numpy as np 

26 

27from bob.bio.base.annotator import Annotator 

28from bob.learn.em import GMMMachine, KMeansMachine 

29 

30from .. import audio_processing as ap 

31from .. import utils 

32 

33logger = logging.getLogger(__name__) 

34 

35 

class Energy_2Gauss(Annotator):
    """Detects Voice Activity using the signal energy and a 2-Gaussian GMM.

    A two-component GMM is fitted on the per-frame energy of the signal; the
    component with the higher mean is taken to model speech frames and the
    other one silence.  The raw speech/non-speech decisions are then smoothed
    over ``smoothing_window`` frames.

    Parameters
    ----------
    max_iterations: int
        Maximum number of iterations used for both the k-means initialization
        and the GMM (ML) training.
    convergence_threshold: float
        Convergence criterion used for both k-means and GMM training.
    variance_threshold: float
        Lower bound applied to the GMM variances after construction.
    win_length_ms: float
        Length of the energy analysis window, in milliseconds.
    win_shift_ms: float
        Shift between consecutive analysis windows, in milliseconds.
    smoothing_window: int
        Width, in frames, of the smoothing applied to the raw labels.
    """

    def __init__(
        self,
        max_iterations=10,  # 10 iterations for the GMM trainer
        convergence_threshold=0.0005,
        variance_threshold=0.0005,
        win_length_ms=20.0,  # 20 ms
        win_shift_ms=10.0,  # 10 ms
        smoothing_window=10,  # 10 frames (i.e. 100 ms)
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.max_iterations = max_iterations
        self.convergence_threshold = convergence_threshold
        self.variance_threshold = variance_threshold
        self.win_length_ms = win_length_ms
        self.win_shift_ms = win_shift_ms
        self.smoothing_window = smoothing_window

    def _voice_activity_detection(self, energy_array: np.ndarray) -> np.ndarray:
        """Fits a 2-Gaussian GMM on the energy to split voice from silence.

        Parameters
        ----------
        energy_array: 1D array of per-frame energies.

        Returns
        -------
        1D integer array of per-frame labels (1 for speech, 0 for silence);
        all zeros when the energy is (nearly) constant or the GMM training
        produced NaN means.
        """
        n_samples = len(energy_array)
        # If the energy hardly varies, the signal is likely not audio.
        # dtype=int keeps this branch consistent with the NaN branch below.
        if np.std(energy_array) < 10e-5:
            return np.zeros(shape=n_samples, dtype=int)

        # Add an epsilon-small Gaussian noise to avoid numerical issues
        # (mainly due to artificial silence).
        energy_array = (1e-6 * np.random.randn(n_samples)) + energy_array

        # Normalize the energy array and reshape it to 1D samples.
        normalized_energy = utils.normalize_std_array(energy_array).reshape(
            (-1, 1)
        )

        # Note: self.max_iterations and self.convergence_threshold are used
        # for both k-means initialization and GMM training.
        kmeans_trainer = KMeansMachine(
            n_clusters=2,
            convergence_threshold=self.convergence_threshold,
            max_iter=self.max_iterations,
            init_max_iter=self.max_iterations,
        )
        ubm_gmm = GMMMachine(
            n_gaussians=2,
            trainer="ml",
            update_means=True,
            update_variances=True,
            update_weights=True,
            convergence_threshold=self.convergence_threshold,
            max_fitting_steps=self.max_iterations,
            k_means_trainer=kmeans_trainer,
        )
        ubm_gmm.variance_thresholds = self.variance_threshold

        ubm_gmm.fit(normalized_energy)

        if np.isnan(ubm_gmm.means).any():
            # logger.warn is deprecated (removed in Python 3.13); use warning.
            logger.warning("Annotation aborted: File contains NaN's")
            return np.zeros(shape=n_samples, dtype=int)

        # Classify each frame by its most likely Gaussian component.
        # The component with the higher mean models speech; flip the argmax/
        # argmin so that speech always ends up labeled 1.
        labels = ubm_gmm.log_weighted_likelihood(normalized_energy)
        if ubm_gmm.means.argmax() == 0:  # High energy in means[0]
            labels = labels.argmin(axis=0)
        else:  # High energy in means[1]
            labels = labels.argmax(axis=0)

        return labels

    def _compute_energy(
        self, audio_signal: np.ndarray, sample_rate: int
    ) -> np.ndarray:
        """Retrieves the speech / non-speech labels for ``audio_signal``.

        Computes per-frame energy, runs the 2-Gaussian VAD, then smooths the
        labels over ``self.smoothing_window`` frames.
        """
        energy_array = ap.energy(
            audio_signal,
            sample_rate,
            win_length_ms=self.win_length_ms,
            win_shift_ms=self.win_shift_ms,
        )
        labels = self._voice_activity_detection(energy_array)

        # Discard isolated speech shorter than smoothing_window frames.
        labels = utils.smoothing(labels, self.smoothing_window)
        logger.debug(
            "After 2 Gaussian Energy-based VAD there are %d frames remaining over %d",
            np.sum(labels),
            len(labels),
        )
        return labels

    def transform_one(self, audio_signal: np.ndarray, sample_rate: int) -> list:
        """Labels speech (1) and non-speech (0) parts of the given input wave file using 2 Gaussian-modeled Energy

        Parameters
        ----------
        audio_signal: array
            Audio signal to annotate
        sample_rate: int
            The sample rate in Hertz

        Returns
        -------
        A list of 0/1 frame labels, or ``None`` when no speech was detected.
        """
        labels = self._compute_energy(
            audio_signal=audio_signal, sample_rate=sample_rate
        )
        if (labels == 0).all():
            logger.warning(
                "Could not annotate: No audio was detected in the sample!"
            )
            return None
        return labels.tolist()

    def transform(
        self, audio_signals: "list[np.ndarray]", sample_rates: "list[int]"
    ):
        """Annotates a batch of signals, one list of labels per signal.

        Forces the dask threaded scheduler so nested dask calls inside the
        GMM training do not spawn extra processes.
        """
        with dask.config.set(scheduler="threads"):
            results = []
            for audio_signal, sample_rate in zip(audio_signals, sample_rates):
                results.append(self.transform_one(audio_signal, sample_rate))
            return results

    def fit(self, X, y=None, **fit_params):
        # Stateless transformer: nothing to fit.
        return self

    def _more_tags(self):
        # scikit-learn/bob pipeline tags: no fit needed; the "rate" sample
        # attribute is forwarded as the ``sample_rates`` argument and the
        # output is stored as "annotations".
        return {
            "requires_fit": False,
            "bob_transform_extra_input": (("sample_rates", "rate"),),
            "bob_output": "annotations",
        }