Coverage for src/bob/learn/em/gmm.py: 92%
390 statements
« prev ^ index » next coverage.py v7.0.5, created at 2023-06-16 14:34 +0200
« prev ^ index » next coverage.py v7.0.5, created at 2023-06-16 14:34 +0200
1#!/usr/bin/env python
2# @author: Yannick Dayer <yannick.dayer@idiap.ch>
3# @date: Fri 30 Jul 2021 10:06:47 UTC+02
5"""This module provides classes and functions for the training and usage of GMM."""
7import copy
8import functools
9import logging
10import operator
12from typing import Union
14import dask
15import dask.array as da
16import numpy as np
18from h5py import File as HDF5File
19from sklearn.base import BaseEstimator
21from .kmeans import KMeansMachine
22from .utils import array_to_delayed_list, check_and_persist_dask_input
logger = logging.getLogger(name=__name__)

# Machine epsilon of double-precision floats. Used in this module as the
# default floor for variances and for responsibility sums in the m-steps.
EPSILON = np.finfo(np.float64).eps
def logaddexp_reduce(array, axis=0, keepdims=False):
    """Reduce `array` along `axis` with ``np.logaddexp``.

    This computes ``log(sum(exp(array), axis=axis))`` without leaving log
    space, which avoids underflow for very negative log values. The reduction
    starts from ``-inf`` (the log of zero), so reducing an empty axis yields
    ``-inf`` instead of raising.
    """
    log_zero = -np.inf
    reducer = np.logaddexp.reduce
    return reducer(array, axis=axis, keepdims=keepdims, initial=log_zero)
def log_weighted_likelihood(data, machine):
    """Returns the weighted log likelihood for each Gaussian for a set of data.

    Parameters
    ----------
    data
        Data to compute the log likelihood on.
    machine
        The GMM machine. Its ``means``, ``variances``, ``g_norms`` and
        ``log_weights`` attributes are read.

    Returns
    -------
    array of shape (n_gaussians, n_samples)
        The weighted log likelihood of each sample of each Gaussian.
    """
    # One row per Gaussian: the variance-normalized squared distance of every
    # sample to that Gaussian's mean (diagonal covariance). The loop over
    # Gaussians keeps the intermediate at (n_samples, n_features).
    distances = np.vstack(
        [
            np.sum(
                (data - machine.means[g]) ** 2 / machine.variances[g],
                axis=-1,
            )
            for g in range(len(machine.means))
        ]
    )
    log_gaussian = -0.5 * (machine.g_norms[:, None] + distances)
    return machine.log_weights[:, None] + log_gaussian
def reduce_loglikelihood(log_weighted_likelihoods):
    """Collapse per-Gaussian weighted log likelihoods into one value per sample.

    The reduction is a log-sum-exp over axis 0 (the Gaussians axis), computed
    with ``logaddexp`` to prevent underflow. NumPy arrays are reduced
    directly. Any other input is assumed to be a dask array and is reduced
    chunk-wise with ``da.reduction``.
    """
    if not isinstance(log_weighted_likelihoods, np.ndarray):
        # Per-chunk partial reductions are combined with the same logaddexp
        # reducer, so the result matches the in-memory computation.
        return da.reduction(
            x=log_weighted_likelihoods,
            chunk=logaddexp_reduce,
            aggregate=logaddexp_reduce,
            axis=0,
            dtype=float,
            keepdims=False,
        )
    return logaddexp_reduce(log_weighted_likelihoods)
def log_likelihood(data, machine):
    """Returns the current log likelihood for a set of data in this Machine.

    Parameters
    ----------
    data
        Data to compute the log likelihood on. A single sample (1D) is
        promoted to a batch of one with ``np.atleast_2d``.
    machine
        The GMM machine.

    Returns
    -------
    array of shape (n_samples)
        The log likelihood of each sample.
    """
    samples = np.atleast_2d(data)
    # Shape (n_gaussians, n_samples), collapsed over the Gaussian axis.
    per_gaussian = log_weighted_likelihood(data=samples, machine=machine)
    return reduce_loglikelihood(per_gaussian)
def e_step(data, machine):
    """Expectation step of the e-m algorithm.

    Computes the responsibility of each Gaussian of `machine` for each sample
    of `data` and accumulates the resulting sufficient statistics.

    Parameters
    ----------
    data
        The samples. A single 1D sample is promoted to a batch of one.
    machine
        The GMM machine providing the current parameters.

    Returns
    -------
    GMMStats
        Statistics accumulated over all samples of `data`.
    """
    samples = np.atleast_2d(data)
    n_gaussians = len(machine.weights)

    # Fresh, zeroed accumulator built with the same array type as the data.
    stats = GMMStats(n_gaussians, samples.shape[-1], like=samples)

    # [array of shape (n_gaussians, n_samples)]
    weighted_ll = log_weighted_likelihood(data=samples, machine=machine)
    # [array of shape (n_samples,)]
    sample_ll = reduce_loglikelihood(weighted_ll)
    # Posterior of each Gaussian given each sample, normalized in log space.
    # [array of shape (n_gaussians, n_samples)]
    responsibility = np.exp(weighted_ll - sample_ll[None, :])

    stats.log_likelihood = sample_ll.sum()  # total over samples
    stats.t = samples.shape[0]  # number of samples
    stats.n = responsibility.sum(axis=-1)  # shape (n_gaussians,)

    first_order = []
    second_order = []
    for g in range(n_gaussians):
        # Responsibility-weighted samples [shape (n_samples, n_features)]
        weighted = responsibility[g, :, None] * samples
        first_order.append(np.sum(weighted, axis=0))
        second_order.append(np.sum(weighted * samples, axis=0))
    # Both stacked to shape (n_gaussians, n_features)
    stats.sum_px = np.vstack(first_order)
    stats.sum_pxx = np.vstack(second_order)
    return stats
def m_step(statistics, machine):
    """Maximization step of the e-m algorithm.

    Parameters
    ----------
    statistics
        An iterable of GMMStats (e.g. one per data chunk). They are summed
        with ``+=``, so the first element is modified in place.
    machine
        The GMMMachine to update in place. Its ``trainer`` attribute selects
        MAP (``"map"``) or ML (anything else) update rules.

    Returns
    -------
    tuple
        ``(machine, average_log_likelihood)``, where the second item is the
        accumulated log likelihood divided by the number of samples.
    """
    if machine.trainer == "map":
        update = map_gmm_m_step
    else:
        update = ml_gmm_m_step
    total = functools.reduce(operator.iadd, statistics)
    update(
        machine=machine,
        statistics=total,
        update_means=machine.update_means,
        update_variances=machine.update_variances,
        update_weights=machine.update_weights,
        mean_var_update_threshold=machine.mean_var_update_threshold,
        # Reynolds adaptation is enabled whenever a relevance factor is set
        reynolds_adaptation=machine.map_relevance_factor is not None,
        alpha=machine.map_alpha,
        relevance_factor=machine.map_relevance_factor,
    )
    return machine, float(total.log_likelihood / total.t)
class GMMStats:
    """Stores accumulated statistics of a GMM.

    Attributes
    ----------
    log_likelihood: float
        The sum of log_likelihood of each sample on a GMM.
    t: int
        The number of considered samples.
    n: array of shape (n_gaussians,)
        Sum of responsibility.
    sum_px: array of shape (n_gaussians, n_features)
        First order statistic
    sum_pxx: array of shape (n_gaussians, n_features)
        Second order statistic
    """

    def __init__(
        self, n_gaussians: int, n_features: int, like=None, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.n_gaussians = n_gaussians
        self.n_features = n_features
        self.log_likelihood = 0
        self.t = 0
        # create dask arrays if required: `like` is forwarded to np.zeros so
        # the statistics use the same array type as the reference array.
        kw = dict(like=like) if like is not None else {}
        self.n = np.zeros(shape=(self.n_gaussians,), dtype=float, **kw)
        self.sum_px = np.zeros(
            shape=(self.n_gaussians, self.n_features), dtype=float, **kw
        )
        self.sum_pxx = np.zeros(
            shape=(self.n_gaussians, self.n_features), dtype=float, **kw
        )

    def init_fields(
        self, log_likelihood=0.0, t=0, n=None, sum_px=None, sum_pxx=None
    ):
        """Initializes the statistics values to a defined value, or zero by default."""
        # The accumulated log likelihood of all samples
        self.log_likelihood = log_likelihood
        # The accumulated number of samples
        self.t = t
        # For each Gaussian, the accumulated sum of responsibilities, i.e. the sum of
        # P(gaussian_i|x)
        self.n = (
            np.zeros(shape=(self.n_gaussians,), dtype=float) if n is None else n
        )
        # For each Gaussian, the accumulated sum of responsibility times the sample
        self.sum_px = (
            np.zeros(shape=(self.n_gaussians, self.n_features), dtype=float)
            if sum_px is None
            else sum_px
        )
        # For each Gaussian, the accumulated sum of responsibility times the sample
        # squared
        self.sum_pxx = (
            np.zeros(shape=(self.n_gaussians, self.n_features), dtype=float)
            if sum_pxx is None
            else sum_pxx
        )

    def reset(self):
        """Sets all statistics to zero."""
        self.init_fields()

    @classmethod
    def from_hdf5(cls, hdf5):
        """Creates a new GMMStats object from an `HDF5File` object.

        Parameters
        ----------
        hdf5
            An open HDF5 file object, or a path to one. A path is opened
            read-only and closed again before this method returns.

        Returns
        -------
        GMMStats
            The statistics stored in the file (current or legacy format).
        """
        if isinstance(hdf5, str):
            # Close files we opened ourselves. Every value below is read into
            # memory, so nothing references the file after it is closed.
            with HDF5File(hdf5, "r") as hdf5_file:
                return cls.from_hdf5(hdf5_file)
        try:
            version_major, version_minor = hdf5.attrs["file_version"].split(".")
            logger.debug(
                f"Reading a GMMStats HDF5 file of version {version_major}.{version_minor}"
            )
        except (KeyError, RuntimeError):
            version_major, version_minor = 0, 0
        if int(version_major) >= 1:
            if hdf5.attrs["writer_class"] != str(cls):
                logger.warning(f"{hdf5.attrs['writer_class']} is not {cls}.")
            self = cls(
                n_gaussians=hdf5["n_gaussians"][()],
                n_features=hdf5["n_features"][()],
            )
            self.log_likelihood = hdf5["log_likelihood"][()]
            self.t = hdf5["T"][()]
            self.n = hdf5["n"][...]
            self.sum_px = hdf5["sumPx"][...]
            self.sum_pxx = hdf5["sumPxx"][...]
        else:  # Legacy file version
            logger.info("Loading a legacy HDF5 stats file.")
            self = cls(
                n_gaussians=int(hdf5["n_gaussians"][()]),
                n_features=int(hdf5["n_inputs"][()]),
            )
            # "log_liklihood" (sic) is the key used by the legacy format.
            self.log_likelihood = hdf5["log_liklihood"][()]
            self.t = int(hdf5["T"][()])
            # `[...]` reads each dataset eagerly into a NumPy array.
            self.n = np.reshape(hdf5["n"][...], (self.n_gaussians,))
            self.sum_px = np.reshape(hdf5["sumPx"][...], (self.shape))
            self.sum_pxx = np.reshape(hdf5["sumPxx"][...], (self.shape))
        return self

    def save(self, hdf5):
        """Saves the current statistics in an `HDF5File` object.

        Parameters
        ----------
        hdf5
            An open, writable HDF5 file object, or a path. A path is created
            (truncating any existing file) and closed once writing is done.
        """
        if isinstance(hdf5, str):
            # Closing the file we opened guarantees the data is flushed.
            with HDF5File(hdf5, "w") as hdf5_file:
                self.save(hdf5_file)
            return
        hdf5.attrs["file_version"] = "1.0"
        hdf5.attrs["writer_class"] = str(self.__class__)
        hdf5["n_gaussians"] = self.n_gaussians
        hdf5["n_features"] = self.n_features
        hdf5["log_likelihood"] = float(self.log_likelihood)
        hdf5["T"] = int(self.t)
        hdf5["n"] = np.array(self.n)
        hdf5["sumPx"] = np.array(self.sum_px)
        hdf5["sumPxx"] = np.array(self.sum_pxx)

    def load(self, hdf5):
        """Overwrites the current statistics with those in an `HDF5File` object."""
        new_self = self.from_hdf5(hdf5)
        if new_self.shape != self.shape:
            logger.warning("Loaded GMMStats from hdf5 with a different shape.")
            self.resize(*new_self.shape)
        self.init_fields(
            new_self.log_likelihood,
            new_self.t,
            new_self.n,
            new_self.sum_px,
            new_self.sum_pxx,
        )

    def __add__(self, other):
        """Returns new statistics equal to the sum of `self` and `other`.

        Raises
        ------
        ValueError
            If the two statistics do not have the same shape.
        """
        if (
            self.n_gaussians != other.n_gaussians
            or self.n_features != other.n_features
        ):
            raise ValueError(
                "Statistics could not be added together (shape mismatch)"
            )
        new_stats = GMMStats(self.n_gaussians, self.n_features)
        new_stats.log_likelihood = self.log_likelihood + other.log_likelihood
        new_stats.t = self.t + other.t
        new_stats.n = self.n + other.n
        new_stats.sum_px = self.sum_px + other.sum_px
        new_stats.sum_pxx = self.sum_pxx + other.sum_pxx
        return new_stats

    def __iadd__(self, other):
        """Accumulates `other` into `self` (arrays are updated in place).

        Raises
        ------
        ValueError
            If the two statistics do not have the same shape.
        """
        if (
            self.n_gaussians != other.n_gaussians
            or self.n_features != other.n_features
        ):
            raise ValueError(
                "Statistics could not be added together (shape mismatch)"
            )
        self.log_likelihood += other.log_likelihood
        self.t += other.t
        self.n += other.n
        self.sum_px += other.sum_px
        self.sum_pxx += other.sum_pxx
        return self

    def __eq__(self, other):
        return (
            self.log_likelihood == other.log_likelihood
            and self.t == other.t
            and np.array_equal(self.n, other.n)
            and np.array_equal(self.sum_px, other.sum_px)
            and np.array_equal(self.sum_pxx, other.sum_pxx)
        )

    def is_similar_to(self, other, rtol=1e-5, atol=1e-8):
        """Returns True if `other` has the same values (within a tolerance)."""
        return (
            np.isclose(
                self.log_likelihood, other.log_likelihood, rtol=rtol, atol=atol
            )
            and np.isclose(self.t, other.t, rtol=rtol, atol=atol)
            and np.allclose(self.n, other.n, rtol=rtol, atol=atol)
            and np.allclose(self.sum_px, other.sum_px, rtol=rtol, atol=atol)
            and np.allclose(self.sum_pxx, other.sum_pxx, rtol=rtol, atol=atol)
        )

    def resize(self, n_gaussians, n_features):
        """Reinitializes the statistics with new dimensions (all set to zero)."""
        self.n_gaussians = n_gaussians
        self.n_features = n_features
        self.init_fields()

    @property
    def shape(self):
        """The number of gaussians and their dimensionality."""
        return (self.n_gaussians, self.n_features)

    @property
    def nbytes(self):
        """The number of bytes used by the statistics n, sum_px, sum_pxx."""
        return self.n.nbytes + self.sum_px.nbytes + self.sum_pxx.nbytes
class GMMMachine(BaseEstimator):
    """Transformer that stores a Gaussian Mixture Model (GMM) parameters.

    This class implements the statistical model for multivariate diagonal mixture
    Gaussian distribution (GMM), as well as ways to train a model on data.

    A GMM is defined as
    :math:`\\sum_{c=1}^{C} \\omega_c \\mathcal{N}(x | \\mu_c, \\sigma_c)`, where
    :math:`C` is the number of Gaussian components, and :math:`\\mu_c`,
    :math:`\\sigma_c` and :math:`\\omega_c` are respectively the mean, variance and
    weight of each gaussian component :math:`c`.
    See Section 2.3.9 of Bishop, \"Pattern recognition and machine learning\", 2006

    Two types of training are available, MLE and MAP, chosen with `trainer`.

    * Maximum Likelihood Estimation (:ref:`MLE <mle>`, ML)

      The mixtures are initialized (with k-means by default). The means,
      variances, and weights of the mixtures are then trained on the data to
      increase the likelihood value. (:ref:`MLE <mle>`)

    * Maximum a Posteriori (:ref:`MAP <map>`)

      The MAP machine takes another GMM machine as prior, called Universal
      Background Model (UBM).
      The means, variances, and weights of the MAP mixtures are then trained on
      the data as adaptation of the UBM.

    Both training methods use an Expectation-Maximization (e-m) algorithm to
    iteratively train the GMM.

    Note
    ----
    When the means are set manually, the k-means initialization is skipped in
    `fit`.

    Attributes
    ----------
    means, variances, variance_thresholds
        Gaussians parameters.
    """

    def __init__(
        self,
        n_gaussians: int,
        trainer: str = "ml",
        ubm: "Union[GMMMachine, None]" = None,
        convergence_threshold: float = 1e-5,
        max_fitting_steps: Union[int, None] = 200,
        random_state: Union[int, np.random.RandomState] = 0,
        weights: "Union[np.ndarray[('n_gaussians',), float], None]" = None,  # noqa: F821
        k_means_trainer: Union[KMeansMachine, None] = None,
        update_means: bool = True,
        update_variances: bool = False,
        update_weights: bool = False,
        mean_var_update_threshold: float = EPSILON,
        map_alpha: float = 0.5,
        map_relevance_factor: Union[None, float] = 4,
        **kwargs,
    ):
        """
        Parameters
        ----------
        n_gaussians
            The number of gaussians to be represented by the machine.
        trainer
            `"ml"` for the maximum likelihood estimator method;
            `"map"` for the maximum a posteriori method. (MAP Requires `ubm`)
            Any other value falls back to `"ml"`.
        ubm: GMMMachine
            Universal Background Model. GMMMachine Required for the MAP method.
        convergence_threshold
            The threshold value of likelihood difference between e-m steps used for
            stopping the training iterations.
        max_fitting_steps
            The number of e-m iterations to fit the GMM. Stop the training even when
            the convergence threshold isn't met.
        random_state
            Specifies a RandomState or a seed for reproducibility.
        weights
            The weight of each Gaussian. (defaults to `1/n_gaussians`, or to the
            UBM weights when `ubm` is given)
        k_means_trainer
            Optional trainer for the k-means method, replacing the default one.
        update_means
            Update the Gaussians means at every m step.
        update_variances
            Update the Gaussians variances at every m step.
        update_weights
            Update the GMM weights at every m step.
        mean_var_update_threshold
            Threshold value used when updating the means and variances.
        map_alpha
            Ratio for MAP adaptation. (Used when `trainer == "map"` and
            `map_relevance_factor is None`)
        map_relevance_factor
            Factor for the computation of alpha with Reynolds adaptation. (Used when
            `trainer == "map"`)

        Raises
        ------
        ValueError
            If `trainer` is `"map"` without a `ubm`, if the `ubm` has a
            different number of gaussians, or if both `convergence_threshold`
            and `max_fitting_steps` are None.
        """

        super().__init__(**kwargs)

        self.n_gaussians = n_gaussians
        self.trainer = trainer if trainer in ["ml", "map"] else "ml"
        self.m_step_func = (
            map_gmm_m_step if self.trainer == "map" else ml_gmm_m_step
        )
        if self.trainer == "map" and ubm is None:
            raise ValueError("A UBM is required for MAP GMM.")
        if ubm is not None and ubm.n_gaussians != self.n_gaussians:
            raise ValueError(
                "The UBM machine is not compatible with this machine."
            )
        self.ubm = ubm
        if max_fitting_steps is None and convergence_threshold is None:
            raise ValueError(
                "Either or both convergence_threshold and max_fitting_steps must be set"
            )
        self.convergence_threshold = convergence_threshold
        self.max_fitting_steps = max_fitting_steps
        self.random_state = random_state
        self.k_means_trainer = k_means_trainer
        self.update_means = update_means
        self.update_variances = update_variances
        self.update_weights = update_weights
        self.mean_var_update_threshold = mean_var_update_threshold
        self._means = None
        self._variances = None
        self._variance_thresholds = mean_var_update_threshold
        self._g_norms = None

        if self.ubm is not None:
            self.means = copy.deepcopy(self.ubm.means)
            self.variances = copy.deepcopy(self.ubm.variances)
            self.variance_thresholds = copy.deepcopy(
                self.ubm.variance_thresholds
            )
            self.weights = copy.deepcopy(self.ubm.weights)
        else:
            self.weights = np.full(
                (self.n_gaussians,),
                fill_value=(1 / self.n_gaussians),
                dtype=float,
            )
        # Explicitly given weights take precedence over the defaults above.
        if weights is not None:
            self.weights = weights
        self.map_alpha = map_alpha
        self.map_relevance_factor = map_relevance_factor

    @property
    def weights(self):
        """The weights of each Gaussian mixture."""
        return self._weights

    @weights.setter
    def weights(
        self, weights: "np.ndarray[('n_gaussians',), float]"  # noqa: F821
    ):  # noqa: F821
        self._weights = weights
        # Cached so likelihood computations don't recompute the log each time.
        self._log_weights = np.log(self._weights)

    @property
    def means(self):
        """The means of each Gaussian."""
        if self._means is None:
            raise ValueError("GMMMachine means were never set.")
        return self._means

    @means.setter
    def means(
        self,
        means: "np.ndarray[('n_gaussians', 'n_features'), float]",  # noqa: F821
    ):  # noqa: F821
        self._means = means

    @property
    def variances(self):
        """The (diagonal) variances of the gaussians."""
        if self._variances is None:
            raise ValueError("GMMMachine variances were never set.")
        return self._variances

    @variances.setter
    def variances(
        self,
        variances: "np.ndarray[('n_gaussians', 'n_features'), float]",  # noqa: F821
    ):
        # Clamp from below by the variance thresholds.
        self._variances = np.maximum(self.variance_thresholds, variances)
        # Recompute g_norm for each gaussian [array of shape (n_gaussians,)]
        n_log_2pi = self._variances.shape[-1] * np.log(2 * np.pi)
        self._g_norms = n_log_2pi + np.log(self._variances).sum(axis=-1)

    @property
    def variance_thresholds(self):
        """Threshold below which variances are clamped to prevent precision losses."""
        if self._variance_thresholds is None:
            return EPSILON
        return self._variance_thresholds

    @variance_thresholds.setter
    def variance_thresholds(
        self,
        threshold: "Union[float, np.ndarray[('n_gaussians', 'n_features'), float]]",  # noqa: F821
    ):
        self._variance_thresholds = threshold
        # Re-apply the clamping to the already-set variances, if any.
        if self._variances is not None:
            self.variances = np.maximum(threshold, self._variances)

    @property
    def g_norms(self):
        """Precomputed g_norms (depends on variances and feature shape)."""
        if self._g_norms is None:
            # Recompute g_norm for each gaussian [array of shape (n_gaussians,)]
            n_log_2pi = self.variances.shape[-1] * np.log(2 * np.pi)
            self._g_norms = n_log_2pi + np.log(self._variances).sum(axis=-1)
        return self._g_norms

    @property
    def log_weights(self):
        """Retrieve the logarithm of the weights."""
        return self._log_weights

    @property
    def shape(self):
        """Shape of the gaussians in the GMM machine."""
        return self.means.shape

    @classmethod
    def from_hdf5(cls, hdf5: Union[str, HDF5File], ubm: "GMMMachine" = None):
        """Creates a new GMMMachine object from an `HDF5File` object.

        Parameters
        ----------
        hdf5
            An open HDF5 file object, or a path to one. A path is opened
            read-only and closed again before this method returns.
        ubm
            The Universal Background Model. Required when the stored machine
            uses the `"map"` trainer.

        Returns
        -------
        GMMMachine
            The machine stored in the file (current or legacy format).

        Raises
        ------
        ValueError
            If the stored machine uses the `"map"` trainer and `ubm` is None.
        """
        if isinstance(hdf5, str):
            # Close files we opened ourselves. Every value below is read into
            # memory, so nothing references the file after it is closed.
            with HDF5File(hdf5, "r") as hdf5_file:
                return cls.from_hdf5(hdf5_file, ubm=ubm)
        try:
            version_major, version_minor = hdf5.attrs["file_version"].split(".")
            logger.debug(
                f"Reading a GMMMachine HDF5 file of version {version_major}.{version_minor}"
            )
        except (KeyError, RuntimeError):
            version_major, version_minor = 0, 0
        if int(version_major) >= 1:
            if hdf5.attrs["writer_class"] != str(cls):
                logger.warning(f"{hdf5.attrs['writer_class']} is not {cls}.")
            trainer = hdf5["trainer"][()]
            # h5py >= 3 returns string datasets as bytes. Decode so the value
            # matches "ml"/"map" (otherwise MAP machines reload as ML).
            if isinstance(trainer, bytes):
                trainer = trainer.decode("utf-8")
            if trainer == "map" and ubm is None:
                raise ValueError(
                    "The UBM is needed when loading a MAP machine."
                )
            # `save` stores the convergence threshold; fall back to the
            # default for files that lack it.
            convergence_threshold = (
                hdf5["convergence_threshold"][()]
                if "convergence_threshold" in hdf5
                else 1e-5
            )
            self = cls(
                n_gaussians=hdf5["n_gaussians"][()],
                trainer=trainer,
                ubm=ubm,
                convergence_threshold=convergence_threshold,
                max_fitting_steps=hdf5["max_fitting_steps"][()],
                weights=hdf5["weights"][...],
                k_means_trainer=None,
                update_means=hdf5["update_means"][()],
                update_variances=hdf5["update_variances"][()],
                update_weights=hdf5["update_weights"][()],
            )
            gaussians_group = hdf5["gaussians"]
            self.means = gaussians_group["means"][...]
            self.variances = gaussians_group["variances"][...]
            self.variance_thresholds = gaussians_group["variance_thresholds"][
                ...
            ]
        else:  # Legacy file version
            logger.info("Loading a legacy HDF5 machine file.")
            n_gaussians = hdf5["m_n_gaussians"][()][0]
            g_means = []
            g_variances = []
            g_variance_thresholds = []
            for i in range(n_gaussians):
                gaussian_group = hdf5[f"m_gaussians{i}"]
                g_means.append(gaussian_group["m_mean"][...])
                g_variances.append(gaussian_group["m_variance"][...])
                g_variance_thresholds.append(
                    gaussian_group["m_variance_thresholds"][...]
                )
            # `[...]` reads the dataset eagerly into a NumPy array.
            weights = np.reshape(hdf5["m_weights"][...], (n_gaussians,))
            self = cls(n_gaussians=n_gaussians, ubm=ubm, weights=weights)
            self.means = np.array(g_means).reshape(n_gaussians, -1)
            self.variances = np.array(g_variances).reshape(n_gaussians, -1)
            self.variance_thresholds = np.array(g_variance_thresholds).reshape(
                n_gaussians, -1
            )
        return self

    def load(self, hdf5):
        """Overwrites the current state with those in an `HDF5File` object.

        The current `ubm` (if any) is forwarded so MAP machines can be loaded.
        """
        new_self = self.from_hdf5(hdf5, ubm=self.ubm)
        self.__dict__.update(new_self.__dict__)

    def save(self, hdf5):
        """Saves the machine parameters in an `HDF5File` object.

        Parameters
        ----------
        hdf5
            An open, writable HDF5 file object, or a path. A path is created
            (truncating any existing file) and closed once writing is done.
        """
        if isinstance(hdf5, str):
            # Closing the file we opened guarantees the data is flushed.
            with HDF5File(hdf5, "w") as hdf5_file:
                self.save(hdf5_file)
            return
        hdf5.attrs["file_version"] = "1.0"
        hdf5.attrs["writer_class"] = str(self.__class__)
        hdf5["n_gaussians"] = self.n_gaussians
        hdf5["trainer"] = self.trainer
        hdf5["convergence_threshold"] = self.convergence_threshold
        hdf5["max_fitting_steps"] = self.max_fitting_steps
        hdf5["weights"] = self.weights
        hdf5["update_means"] = self.update_means
        hdf5["update_variances"] = self.update_variances
        hdf5["update_weights"] = self.update_weights
        gaussians_group = hdf5.create_group("gaussians")
        gaussians_group["means"] = self.means
        gaussians_group["variances"] = self.variances
        gaussians_group["variance_thresholds"] = self.variance_thresholds

    def __eq__(self, other):
        if self._means is None:
            return False

        return (
            np.allclose(self.means, other.means)
            and np.allclose(self.variances, other.variances)
            and np.allclose(self.variance_thresholds, other.variance_thresholds)
            and np.allclose(self.weights, other.weights)
        )

    def is_similar_to(self, other, rtol=1e-5, atol=1e-8):
        """Returns True if `other` has the same gaussians (within a tolerance)."""
        return (
            np.allclose(self.means, other.means, rtol=rtol, atol=atol)
            and np.allclose(
                self.variances, other.variances, rtol=rtol, atol=atol
            )
            and np.allclose(
                self.variance_thresholds,
                other.variance_thresholds,
                rtol=rtol,
                atol=atol,
            )
            and np.allclose(self.weights, other.weights, rtol=rtol, atol=atol)
        )

    def initialize_gaussians(
        self,
        data: "Union[np.ndarray[('n_samples', 'n_features'), float], None]" = None,  # noqa: F821
    ):
        """Populates gaussians parameters with either k-means or the UBM values.

        Raises
        ------
        ValueError
            If k-means initialization is needed (non-MAP trainer) and `data`
            is None.
        """
        if self.trainer == "map":
            self.means = copy.deepcopy(self.ubm.means)
            self.variances = copy.deepcopy(self.ubm.variances)
            self.variance_thresholds = copy.deepcopy(
                self.ubm.variance_thresholds
            )
            self.weights = copy.deepcopy(self.ubm.weights)
        else:
            logger.debug("GMM means was never set. Initializing with k-means.")
            if data is None:
                raise ValueError("Data is required when training with k-means.")
            logger.info("Initializing GMM with k-means.")
            kmeans_machine = self.k_means_trainer or KMeansMachine(
                self.n_gaussians,
                random_state=self.random_state,
            )
            kmeans_machine = kmeans_machine.fit(data)

            # Set the GMM machine's gaussians with the results of k-means
            self.means = copy.deepcopy(kmeans_machine.centroids_)
            logger.debug(
                "Estimating the variance and weights of each gaussian from kmeans."
            )
            (
                self.variances,
                self.weights,
            ) = kmeans_machine.get_variances_and_weights_for_each_cluster(data)
            logger.debug("Done.")

    def log_weighted_likelihood(
        self,
        data: "np.ndarray[('n_samples', 'n_features'), float]",  # noqa: F821
    ):
        """Returns the weighted log likelihood for each Gaussian for a set of data.

        Parameters
        ----------
        data
            Data to compute the log likelihood on.

        Returns
        -------
        array of shape (n_gaussians, n_samples)
            The weighted log likelihood of each sample of each Gaussian.
        """
        return log_weighted_likelihood(
            data=data,
            machine=self,
        )

    def log_likelihood(
        self,
        data: "np.ndarray[('n_samples', 'n_features'), float]",  # noqa: F821
    ):
        """Returns the current log likelihood for a set of data in this Machine.

        Parameters
        ----------
        data
            Data to compute the log likelihood on.

        Returns
        -------
        array of shape (n_samples)
            The log likelihood of each sample.
        """
        return log_likelihood(
            data=data,
            machine=self,
        )

    def fit(self, X, y=None):
        """Trains the GMM on data until convergence or maximum step is reached."""

        input_is_dask, X = check_and_persist_dask_input(X)

        if self._means is None:
            self.initialize_gaussians(X)
        else:
            logger.debug("GMM means already set. Initialization was not run!")

        if self._variances is None:
            logger.warning(
                "Variances were not defined before fit. Using variance=1"
            )
            self.variances = np.ones_like(self.means)

        X = array_to_delayed_list(X, input_is_dask)

        average_output = 0
        logger.info("Training GMM...")
        step = 0
        while self.max_fitting_steps is None or step < self.max_fitting_steps:
            step += 1
            logger.info(
                f"Iteration {step:3d}"
                + (
                    f"/{self.max_fitting_steps:3d}"
                    if self.max_fitting_steps
                    else ""
                )
            )

            average_output_previous = average_output

            # compute the e-m steps
            if input_is_dask:
                stats = [
                    dask.delayed(e_step)(
                        data=xx,
                        machine=self,
                    )
                    for xx in X
                ]
                new_machine, average_output = dask.compute(
                    dask.delayed(m_step)(stats, self)
                )[0]
                # The delayed m-step works on a copy; bring the results back.
                for attr in ["weights", "means", "variances"]:
                    setattr(self, attr, getattr(new_machine, attr))
            else:
                stats = [
                    e_step(
                        data=X,
                        machine=self,
                    )
                ]
                _, average_output = m_step(stats, self)

            logger.debug(f"log likelihood = {average_output}")
            if step > 1:
                convergence_value = abs(
                    (average_output_previous - average_output)
                    / average_output_previous
                )
                logger.debug(
                    f"convergence val = {convergence_value} and threshold = {self.convergence_threshold}"
                )

                # Terminates if converged (and likelihood computation is set)
                if (
                    self.convergence_threshold is not None
                    and convergence_value <= self.convergence_threshold
                ):
                    logger.info(
                        "Reached convergence threshold. Training stopped."
                    )
                    break

        else:
            # Loop exhausted without `break`: convergence was never reached.
            logger.info(
                "Reached maximum step. Training stopped without convergence."
            )
        return self

    def acc_stats(self, X):
        """Returns the statistics for `X`."""
        # we need this because sometimes the transform function gets overridden
        return e_step(data=X, machine=self)

    def transform(self, X):
        """Returns the statistics for `X`."""
        return self.stats_per_sample(X)

    def stats_per_sample(self, X):
        """Returns a list with the statistics of each element of `X`."""
        return [e_step(data=xx, machine=self) for xx in X]
def ml_gmm_m_step(
    machine: GMMMachine,
    statistics: GMMStats,
    update_means=True,
    update_variances=False,
    update_weights=False,
    mean_var_update_threshold=EPSILON,
    **kwargs,
):
    """Updates a gmm machine parameter according to the e-step statistics.

    Maximum-likelihood update rules from Bishop, "Pattern recognition and
    machine learning", 2006 (equations 9.24-9.26). `machine` is modified in
    place.

    Parameters
    ----------
    machine
        The machine to update.
    statistics
        Accumulated e-step statistics.
    update_means, update_variances, update_weights
        Which parameters to update.
    mean_var_update_threshold
        Lower bound applied to the responsibility sums before dividing.
    **kwargs
        Ignored. This allows the same call to serve the MAP m-step.
    """
    logger.debug("ML GMM Trainer m-step")

    # Floor the responsibility sums so the divisions below never divide by 0
    n_floor = np.clip(statistics.n, mean_var_update_threshold, None)
    n_column = n_floor[:, None]

    if update_weights:  # Equation 9.26
        logger.debug("Update weights.")
        machine.weights = n_floor / statistics.t

    if update_means:  # Equation 9.24, with the floored n
        logger.debug("Update means.")
        machine.means = statistics.sum_px / n_column

    if update_variances:
        # Equation 9.25 via the computational formula
        # var = 1/n * sum(P * x^2) - mean^2 (uses the means as they are now)
        logger.debug("Update variances.")
        second_moment = statistics.sum_pxx / n_column
        machine.variances = second_moment - np.power(machine.means, 2)
def map_gmm_m_step(
    machine: GMMMachine,
    statistics: GMMStats,
    update_means=True,
    update_variances=False,
    update_weights=False,
    reynolds_adaptation=True,
    relevance_factor=4,
    alpha=0.5,
    mean_var_update_threshold=EPSILON,
):
    """Updates a GMMMachine parameters using statistics adapted from a UBM.

    Implements equations 11-13 of Reynolds et al., "Speaker Verification Using
    Adapted Gaussian Mixture Models", Digital Signal Processing, 2000.
    `machine` is modified in place.

    Parameters
    ----------
    machine
        The machine to adapt; its `ubm` provides the prior parameters.
    statistics
        Accumulated e-step statistics.
    update_means, update_variances, update_weights
        Which parameters to adapt.
    reynolds_adaptation
        If True, use the data-dependent coefficient
        ``alpha_i = n_i / (n_i + relevance_factor)``; otherwise use `alpha`.
    relevance_factor
        Relevance factor of the Reynolds adaptation.
    alpha
        Fixed adaptation ratio, either a scalar or one value per Gaussian.
        Used when `reynolds_adaptation` is False.
    mean_var_update_threshold
        Gaussians whose responsibility sum is below this value keep the prior
        (UBM-derived) means and variances.

    Raises
    ------
    ValueError
        If `machine` has no UBM.
    """
    if machine.ubm is None:
        raise ValueError("A machine used for MAP must have a UBM.")
    ubm = machine.ubm

    # Calculate the "data-dependent adaptation coefficient", alpha_i
    # [array of shape (n_gaussians, )]
    if reynolds_adaptation:
        alpha = statistics.n / (statistics.n + relevance_factor)
    elif np.ndim(alpha) == 0:
        # Scalars (including 0-d arrays) apply the same ratio to all Gaussians
        alpha = np.full((machine.n_gaussians,), alpha)

    # Gaussians with too little responsibility fall back to the prior. The
    # thresholded n keeps the divisions finite for them; their (discarded)
    # adapted values never produce divide-by-zero warnings.
    low_n = statistics.n < mean_var_update_threshold
    n_threshold = np.where(low_n, mean_var_update_threshold, statistics.n)

    # - Update weights if requested (equation 11)
    if update_weights:
        # Calculate the maximum likelihood weights [array of shape (n_gaussians,)]
        ml_weights = statistics.n / statistics.t
        new_weights = alpha * ml_weights + (1 - alpha) * ubm.weights
        # Apply the scale factor, gamma, to ensure the new weights sum to unity
        gamma = new_weights.sum()
        machine.weights = new_weights / gamma

    # - Update means if requested (equation 12)
    if update_means:
        new_means = np.multiply(
            alpha[:, None],
            (statistics.sum_px / n_threshold[:, None]),
        ) + np.multiply((1 - alpha[:, None]), ubm.means)
        machine.means = np.where(low_n[:, None], ubm.means, new_means)

    # - Update variance if requested (equation 13):
    #   var = alpha * E[x^2] + (1 - alpha) * (ubm_var + ubm_mean^2) - mean^2
    if update_variances:
        # Second moment of each UBM Gaussian: sigma^2 + mu^2
        prior_second_moment = ubm.variances + np.power(ubm.means, 2)
        squared_means = np.power(machine.means, 2)
        prior_norm_variances = prior_second_moment - squared_means
        new_variances = (
            alpha[:, None] * statistics.sum_pxx / n_threshold[:, None]
            + (1 - alpha[:, None]) * prior_second_moment
            - squared_means
        )
        machine.variances = np.where(
            low_n[:, None],
            prior_norm_variances,
            new_variances,
        )