Coverage for src/bob/bio/spear/annotator/energy_2gauss.py: 95%
58 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-06 22:04 +0100
1#!/usr/bin/env python
2# vim: set fileencoding=utf-8 :
3# @author: Elie Khoury <Elie.Khoury@idiap.ch>
4# @date: Sun 7 Jun 15:41:03 CEST 2015
5#
6# Copyright (C) 2012-2015 Idiap Research Institute, Martigny, Switzerland
7#
8# This program is free software: you can redistribute it and/or modify
9# it under the terms of the GNU General Public License as published by
10# the Free Software Foundation, version 3 of the License.
11#
12# This program is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with this program. If not, see <http://www.gnu.org/licenses/>.
20"""Energy-based voice activity detection for speaker recognition"""
22import logging
24import dask
25import numpy as np
27from bob.bio.base.annotator import Annotator
28from bob.learn.em import GMMMachine, KMeansMachine
30from .. import audio_processing as ap
31from .. import utils
33logger = logging.getLogger(__name__)
class Energy_2Gauss(Annotator):
    """Detects the Voice Activity using the Energy of the signal and 2 Gaussian GMM.

    A per-frame energy is computed from the audio signal, then a two-component
    GMM is fitted on the (normalized) energies. Frames assigned to the
    high-energy component are labeled speech (1); the rest are non-speech (0).

    Parameters
    ----------
    max_iterations: int
        Maximum number of iterations, used for both the k-means initialization
        and the ML GMM training.
    convergence_threshold: float
        Convergence criterion, shared by k-means and GMM training.
    variance_threshold: float
        Lower bound applied to the GMM variances.
    win_length_ms: float
        Energy analysis window length, in milliseconds.
    win_shift_ms: float
        Energy analysis window shift, in milliseconds.
    smoothing_window: int
        Number of frames of the label smoothing window.
    """

    def __init__(
        self,
        max_iterations=10,  # 10 iterations for the GMM trainer
        convergence_threshold=0.0005,
        variance_threshold=0.0005,
        win_length_ms=20.0,  # 20 ms
        win_shift_ms=10.0,  # 10 ms
        smoothing_window=10,  # 10 frames (i.e. 100 ms)
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.max_iterations = max_iterations
        self.convergence_threshold = convergence_threshold
        self.variance_threshold = variance_threshold
        self.win_length_ms = win_length_ms
        self.win_shift_ms = win_shift_ms
        self.smoothing_window = smoothing_window

    def _voice_activity_detection(self, energy_array: np.ndarray) -> np.ndarray:
        """Fits a 2 Gaussian GMM on the energy that splits between voice and silence.

        Parameters
        ----------
        energy_array: 1D :any:`numpy.ndarray`
            Per-frame energy values.

        Returns
        -------
        1D :any:`numpy.ndarray` of int
            Per-frame labels: 1 for speech, 0 for non-speech. All zeros when
            the energy is (almost) constant or the GMM training produced NaNs.
        """
        n_samples = len(energy_array)
        # If the energy barely varies, the sample probably contains no usable
        # audio: label everything as non-speech.
        # (dtype=int keeps the return type consistent with the other branches.)
        if np.std(energy_array) < 10e-5:
            return np.zeros(shape=n_samples, dtype=int)

        # Add an epsilon small Gaussian noise to avoid numerical issues (mainly due to artificial silence).
        energy_array = (1e-6 * np.random.randn(n_samples)) + energy_array

        # Normalize the energy array, make it an array of 1D samples
        normalized_energy = utils.normalize_std_array(energy_array).reshape(
            (-1, 1)
        )

        # Note: self.max_iterations and self.convergence_threshold are used for both
        # k-means and GMM training.
        kmeans_trainer = KMeansMachine(
            n_clusters=2,
            convergence_threshold=self.convergence_threshold,
            max_iter=self.max_iterations,
            init_max_iter=self.max_iterations,
        )
        ubm_gmm = GMMMachine(
            n_gaussians=2,
            trainer="ml",
            update_means=True,
            update_variances=True,
            update_weights=True,
            convergence_threshold=self.convergence_threshold,
            max_fitting_steps=self.max_iterations,
            k_means_trainer=kmeans_trainer,
        )
        ubm_gmm.variance_thresholds = self.variance_threshold

        ubm_gmm.fit(normalized_energy)

        if np.isnan(ubm_gmm.means).any():
            # logger.warning replaces the deprecated logger.warn alias.
            logger.warning("Annotation aborted: File contains NaN's")
            return np.zeros(shape=n_samples, dtype=int)

        # Classify

        # Different behavior dep on which mean represents high energy (higher value)
        labels = ubm_gmm.log_weighted_likelihood(normalized_energy)
        if ubm_gmm.means.argmax() == 0:  # High energy in means[0]
            # Speech frames score highest under component 0, so taking the
            # argmin maps them to label 1 (and silence to 0).
            labels = labels.argmin(axis=0)
        else:  # High energy in means[1]
            labels = labels.argmax(axis=0)

        return labels

    def _compute_energy(
        self, audio_signal: np.ndarray, sample_rate: int
    ) -> np.ndarray:
        """Retrieves the speech / non speech labels for the speech sample in ``audio_signal``

        Computes the per-frame energy, classifies it with
        :any:`_voice_activity_detection`, then smooths the labels to discard
        isolated speech frames.
        """

        energy_array = ap.energy(
            audio_signal,
            sample_rate,
            win_length_ms=self.win_length_ms,
            win_shift_ms=self.win_shift_ms,
        )
        labels = self._voice_activity_detection(energy_array)

        # discard isolated speech a number of frames defined in smoothing_window
        labels = utils.smoothing(labels, self.smoothing_window)
        logger.debug(
            "After 2 Gaussian Energy-based VAD there are %d frames remaining over %d",
            np.sum(labels),
            len(labels),
        )
        return labels

    def transform_one(self, audio_signal: np.ndarray, sample_rate: int) -> list:
        """labels speech (1) and non-speech (0) parts of the given input wave file using 2 Gaussian-modeled Energy

        Parameters
        ----------
        audio_signal: array
            Audio signal to annotate
        sample_rate: int
            The sample rate in Hertz

        Returns
        -------
        list or None
            Per-frame 0/1 labels, or ``None`` when no speech was detected.
        """
        labels = self._compute_energy(
            audio_signal=audio_signal, sample_rate=sample_rate
        )
        if (labels == 0).all():
            logger.warning(
                "Could not annotate: No audio was detected in the sample!"
            )
            return None
        return labels.tolist()

    def transform(
        self, audio_signals: "list[np.ndarray]", sample_rates: "list[int]"
    ):
        """Annotates a batch of signals, pairing each with its sample rate.

        NOTE(review): the threaded-scheduler override presumably prevents
        nested dask parallelism during annotation — confirm against the
        pipeline that calls this.
        """
        with dask.config.set(scheduler="threads"):
            results = []
            for audio_signal, sample_rate in zip(audio_signals, sample_rates):
                results.append(self.transform_one(audio_signal, sample_rate))
        return results

    def fit(self, X, y=None, **fit_params):
        # Stateless transformer: nothing to fit.
        return self

    def _more_tags(self):
        return {
            "requires_fit": False,
            "bob_transform_extra_input": (("sample_rates", "rate"),),
            "bob_output": "annotations",
        }