#!/usr/bin/env python
# @author: Yannick Dayer <yannick.dayer@idiap.ch>
# @date: Fri 06 May 2022 14:18:25 UTC+02

import copy
import logging
import operator

from typing import Any, Dict, List, Optional, Tuple, Union

import dask
import dask.bag
import numpy as np

from sklearn.base import BaseEstimator

from bob.learn.em import GMMMachine, GMMStats

logger = logging.getLogger(__name__)


class IVectorStats:
    """Stores I-Vector statistics. Can be used to accumulate multiple statistics.

    **Attributes:**

    nij_sigma_wij2: numpy.ndarray of shape (n_gaussians, dim_t, dim_t)
    fnorm_sigma_wij: numpy.ndarray of shape (n_gaussians, n_features, dim_t)
    snormij: numpy.ndarray of shape (n_gaussians, n_features)
    nij: numpy.ndarray of shape (n_gaussians,)
    """

    def __init__(self, dim_c, dim_d, dim_t):
        self.dim_c = dim_c
        self.dim_d = dim_d
        self.dim_t = dim_t

        # Accumulator storage variables

        # nij sigma wij2: shape = (c,t,t)
        self.nij_sigma_wij2 = np.zeros(
            shape=(self.dim_c, self.dim_t, self.dim_t), dtype=float
        )
        # fnorm sigma wij: shape = (c,d,t)
        self.fnorm_sigma_wij = np.zeros(
            shape=(self.dim_c, self.dim_d, self.dim_t), dtype=float
        )
        # Snormij (used only when updating sigma): shape = (c,d)
        self.snormij = np.zeros(shape=(self.dim_c, self.dim_d), dtype=float)
        # Nij (used only when updating sigma): shape = (c,)
        self.nij = np.zeros(shape=(self.dim_c,), dtype=float)

    @property
    def shape(self) -> Tuple[int, int, int]:
        return (self.dim_c, self.dim_d, self.dim_t)

    def __add__(self, other):
        if self.shape != other.shape:
            raise ValueError("Cannot add stats of different shapes")
        result = IVectorStats(self.dim_c, self.dim_d, self.dim_t)
        result.nij_sigma_wij2 = self.nij_sigma_wij2 + other.nij_sigma_wij2
        result.fnorm_sigma_wij = self.fnorm_sigma_wij + other.fnorm_sigma_wij
        result.snormij = self.snormij + other.snormij
        result.nij = self.nij + other.nij
        return result

    def __iadd__(self, other):
        if self.shape != other.shape:
            raise ValueError("Cannot add stats of different shapes")
        self.nij_sigma_wij2 += other.nij_sigma_wij2
        self.fnorm_sigma_wij += other.fnorm_sigma_wij
        self.snormij += other.snormij
        self.nij += other.nij
        return self
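
# Example (illustrative sketch, hypothetical dimensions): partial statistics
# with matching shapes can be combined with ``+`` or ``+=``, which is what the
# chunked training path in ``IVectorMachine.fit`` relies on.
#
#   part_a = IVectorStats(dim_c=2, dim_d=3, dim_t=4)
#   part_b = IVectorStats(dim_c=2, dim_d=3, dim_t=4)
#   total = part_a + part_b  # new IVectorStats with summed accumulators
#   part_a += part_b         # in-place accumulation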


def compute_tct_sigmac_inv(T: np.ndarray, sigma: np.ndarray) -> np.ndarray:
    """Computes :math:`T_{c}^{T} \\Sigma_{c}^{-1}`.

    Returns an array of shape (c,t,d).
    """
    # (c,t,d) = T.T (c,t,d) / sigma (c,1,d)
    Tct_sigmacInv = T.transpose(0, 2, 1) / sigma[:, None, :]
    return Tct_sigmacInv


def compute_tct_sigmac_inv_tc(T: np.ndarray, sigma: np.ndarray) -> np.ndarray:
    """Computes :math:`T_{c}^{T} \\Sigma_{c}^{-1} T_{c}`.

    Returns an array of shape (c,t,t).
    """
    tct_sigmac_inv = compute_tct_sigmac_inv(T, sigma)
    # (c,t,t) = (c,t,d) @ (c,d,t)
    Tct_sigmacInv_Tc = tct_sigmac_inv @ T
    return Tct_sigmacInv_Tc
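
# Shape sketch (illustrative, hypothetical dimensions C=2, D=3, T=4):
#
#   T = np.random.normal(size=(2, 3, 4))
#   sigma = np.ones((2, 3))
#   compute_tct_sigmac_inv(T, sigma).shape     # -> (2, 4, 3)
#   compute_tct_sigmac_inv_tc(T, sigma).shape  # -> (2, 4, 4)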


def compute_id_tt_sigma_inv_t(
    stats: GMMStats, T: np.ndarray, sigma: np.ndarray
) -> np.ndarray:
    """Computes :math:`Id + \\sum_{c=1}^{C} N_{i,j,c} T_{c}^{T} \\Sigma_{c}^{-1} T_{c}`.

    Returns an array of shape (t,t).
    """
    dim_t = T.shape[-1]
    tct_sigmac_inv_tc = compute_tct_sigmac_inv_tc(T, sigma)

    output = np.eye(dim_t, dim_t) + np.einsum(
        "c,ctu->tu", stats.n, tct_sigmac_inv_tc
    )
    return output
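
# Note: ``e_step`` and ``project`` below use this matrix as the posterior
# precision of the latent factor w_ij. With L = compute_id_tt_sigma_inv_t(...)
# and b = compute_tt_sigma_inv_fnorm(...), the standard i-vector posterior is
#
#   E[w_ij]        = L^{-1} b
#   E[w_ij w_ij^T] = L^{-1} + E[w_ij] E[w_ij]^T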


def compute_tt_sigma_inv_fnorm(
    ubm_means: np.ndarray, stats: GMMStats, T: np.ndarray, sigma: np.ndarray
) -> np.ndarray:
    """Computes :math:`\\sum_{c=1}^{C} T_{c}^{T} \\Sigma_{c}^{-1} F_{norm,c}`,
    where :math:`F_{norm,c} = F_{c} - N_{c} \\mu_{c}` are the centered
    first-order statistics.

    Returns an array of shape (t,).
    """
    tct_sigmac_inv = compute_tct_sigmac_inv(T, sigma)  # (c,t,d)
    fnorm = stats.sum_px - stats.n[:, None] * ubm_means  # (c,d)

    # (t,) = sum over c of (t,d) @ (d,)
    output = np.einsum("ctd,cd->t", tct_sigmac_inv, fnorm)
    return output


def e_step(machine: "IVectorMachine", data: List[GMMStats]) -> IVectorStats:
    """Computes the expectation step of the e-m algorithm."""
    stats = IVectorStats(machine.dim_c, machine.dim_d, machine.dim_t)

    for sample in data:
        Nij = sample.n
        Fij = sample.sum_px
        Sij = sample.sum_pxx

        # Estimate the latent variable posterior for this sample
        TtSigmaInv_Fnorm = compute_tt_sigma_inv_fnorm(
            machine.ubm.means, sample, machine.T, machine.sigma
        )  # shape: (t,)
        I_TtSigmaInvNT = compute_id_tt_sigma_inv_t(
            sample, machine.T, machine.sigma
        )  # shape: (t,t)

        # Latent variables: posterior mean and second moment of w_ij
        I_TtSigmaInvNT_inv = np.linalg.inv(I_TtSigmaInvNT)  # shape: (t,t)
        sigma_w_ij = np.dot(I_TtSigmaInvNT_inv, TtSigmaInv_Fnorm)  # shape: (t,)
        sigma_w_ij2 = I_TtSigmaInvNT_inv + np.outer(
            sigma_w_ij, sigma_w_ij
        )  # shape: (t,t)

        # Compute normalized first- and second-order statistics
        Fnorm = Fij - Nij[:, None] * machine.ubm.means
        Snorm = (
            Sij
            - (2 * Fij * machine.ubm.means)
            + (Nij[:, None] * machine.ubm.means * machine.ubm.means)
        )

        # Accumulate for each component
        stats.snormij = stats.snormij + Snorm  # shape: (c,d)

        # (c,t,t) += (c,1,1) * (1,t,t)
        stats.nij_sigma_wij2 = stats.nij_sigma_wij2 + (
            Nij[:, None, None] * sigma_w_ij2[None, :, :]
        )
        stats.nij = stats.nij + Nij
        # (c,d,t) += (c,d,1) @ (1,t)
        stats.fnorm_sigma_wij = stats.fnorm_sigma_wij + np.matmul(
            Fnorm[:, :, None], sigma_w_ij[None, :]
        )

    return stats
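
# Note: e_step is additive over samples (via IVectorStats.__add__), so
# ``IVectorMachine.fit`` can run it independently on each dask chunk and sum
# the partial statistics pairwise before the maximization step.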


def m_step(machine: "IVectorMachine", stats: IVectorStats) -> "IVectorMachine":
    """Updates the machine with the maximization step of the e-m algorithm."""
    logger.debug("Computing new machine parameters.")
    A = stats.nij_sigma_wij2.transpose((0, 2, 1))
    B = stats.fnorm_sigma_wij.transpose((0, 2, 1))

    # Default value of X if any of A[c] is 0
    X = np.zeros_like(B)
    # Solve A[c] @ X[c] = B[c] for all non-zero A[c]
    if any(mask := A.any(axis=(-2, -1))):  # Prevents solving with 0 matrices
        X[mask] = [
            np.linalg.solve(A[c], B[c]) for c in range(len(mask)) if A[c].any()
        ]

    # Update the total variability matrix
    machine.T = X.transpose((0, 2, 1))

    if machine.update_sigma:
        fnorm_sigma_wij_tt = np.diagonal(
            stats.fnorm_sigma_wij @ X, axis1=-2, axis2=-1
        )
        machine.sigma = (stats.snormij - fnorm_sigma_wij_tt) / stats.nij[:, None]
        machine.sigma[
            machine.sigma < machine.variance_floor
        ] = machine.variance_floor

    return machine
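
# Example (hedged sketch): one in-memory EM iteration is simply
#
#   stats = e_step(machine, data)  # data: List[GMMStats]
#   machine = m_step(machine, stats)
#
# ``IVectorMachine.fit`` repeats this ``max_iterations`` times, with an
# additional dask path that runs e_step per chunk and sums the partial stats.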


class IVectorMachine(BaseEstimator):
    """Trains and projects data using I-Vector.

    Dimensions:
        - dim_c: number of Gaussians
        - dim_d: number of features
        - dim_t: dimension of the i-vector

    **Attributes**

    T (c,d,t):
        The total variability matrix :math:`T`
    sigma (c,d):
        The diagonal covariance matrix :math:`\\Sigma`
    """

    def __init__(
        self,
        ubm: GMMMachine,
        dim_t: int = 2,
        convergence_threshold: Optional[float] = None,
        max_iterations: int = 25,
        update_sigma: bool = True,
        variance_floor: float = 1e-10,
        **kwargs,
    ) -> None:
        """Initializes the IVectorMachine object.

        **Parameters**

        ubm
            The Universal Background Model.
        dim_t
            The dimension of the i-vector.
        """

        super().__init__(**kwargs)
        self.ubm = ubm
        self.dim_t = dim_t
        self.convergence_threshold = convergence_threshold
        self.max_iterations = max_iterations
        self.update_sigma = update_sigma
        self.dim_c = None
        self.dim_d = None
        self.variance_floor = variance_floor

        self.T = None
        self.sigma = None

        if self.convergence_threshold:
            logger.info(
                "The convergence threshold is ignored by IVectorMachine."
            )

    def fit(
        self, X: Union[List[GMMStats], dask.bag.Bag], y=None
    ) -> "IVectorMachine":
        """Trains the IVectorMachine.

        Repeats the e-m steps until ``max_iterations`` is reached.
        """

        chunky = False
        if isinstance(X, dask.bag.Bag):
            chunky = True
            X = X.to_delayed()

        self.dim_c = self.ubm.n_gaussians
        self.dim_d = self.ubm.means.shape[-1]

        self.T = np.random.normal(
            loc=0.0,
            scale=1.0,
            size=(self.dim_c, self.dim_d, self.dim_t),
        )
        self.sigma = copy.deepcopy(self.ubm.variances)

        logger.info("Training I-Vector...")
        for step in range(self.max_iterations):
            logger.info(
                f"IVector step {step+1:{len(str(self.max_iterations))}d}/{self.max_iterations}."
            )
            if chunky:
                # Compute the IVectorStats of each chunk
                stats = [
                    dask.delayed(e_step)(machine=self, data=xx) for xx in X
                ]

                # Workaround to prevent memory issues at compute time with too
                # many chunks. This adds pairs of stats together instead of
                # sending all the stats to one worker.
                while (length := len(stats)) > 1:
                    last = stats[-1]
                    stats = [
                        dask.delayed(operator.add)(
                            stats[i], stats[length // 2 + i]
                        )
                        for i in range(length // 2)
                    ]
                    if length % 2 != 0:
                        stats.append(last)
                stats_sum = stats[0]

                # Update the machine parameters with the aggregated stats
                new_machine = dask.compute(
                    dask.delayed(m_step)(self, stats_sum)
                )[0]
                for attr in ["T", "sigma"]:
                    setattr(self, attr, getattr(new_machine, attr))
            else:  # Working directly on a list of GMMStats, not dask bags
                stats = e_step(machine=self, data=X)
                _ = m_step(self, stats)
        logger.info(f"Reached {step+1} steps.")
        return self
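
    # Example (hedged sketch): fitting on in-memory statistics. ``ubm`` and
    # ``training_stats`` are placeholders; ``training_stats`` must be a list
    # of GMMStats (one per training sample), or a dask.bag.Bag of them for
    # distributed training.
    #
    #   machine = IVectorMachine(ubm=ubm, dim_t=100, max_iterations=10)
    #   machine.fit(training_stats)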

    def project(self, stats: GMMStats) -> np.ndarray:
        """Projects the GMMStats on the IVectorMachine.

        This takes data already projected onto the UBM.

        **Returns:**

        The IVector of the input stats.
        """

        return np.linalg.solve(
            compute_id_tt_sigma_inv_t(stats, self.T, self.sigma),
            compute_tt_sigma_inv_fnorm(
                self.ubm.means, stats, self.T, self.sigma
            ),
        )
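
    # Example (hedged sketch): projecting one sample, where ``sample_stats``
    # is a placeholder GMMStats already computed against the UBM.
    #
    #   w = machine.project(sample_stats)  # np.ndarray of shape (dim_t,)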

    def transform(self, X: List[GMMStats]) -> List[np.ndarray]:
        """Transforms the data using the trained IVectorMachine.

        This takes GMMStats (feature data already projected onto the UBM) and
        computes the i-vector of each sample.

        **Parameters:**

        X
            The GMMStats of each sample.

        **Returns:**

        The IVector for each sample. Arrays of shape (dim_t,).
        """

        return [self.project(x) for x in X]

    def _more_tags(self) -> Dict[str, Any]:
        return {
            "requires_fit": True,
            "bob_fit_supports_dask_bag": True,
        }
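

# End-to-end sketch (illustrative; names and toy dimensions are hypothetical,
# and the exact GMMStats-accumulation call should be checked against the
# GMMMachine API):
#
#   import numpy as np
#   from bob.learn.em import GMMMachine
#
#   samples = [np.random.normal(size=(50, 3)) for _ in range(10)]
#   ubm = GMMMachine(n_gaussians=4).fit(np.concatenate(samples))
#   stats = [ubm.acc_stats(x) for x in samples]  # assumed API, one per sample
#
#   machine = IVectorMachine(ubm=ubm, dim_t=2, max_iterations=5)
#   machine.fit(stats)
#   ivectors = machine.transform(stats)  # list of (dim_t,) arrays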