Coverage for src/bob/learn/em/gmm.py: 92% (390 statements)

#!/usr/bin/env python
# @author: Yannick Dayer <yannick.dayer@idiap.ch>
# @date: Fri 30 Jul 2021 10:06:47 UTC+02

"""This module provides classes and functions for the training and usage of GMM."""

import copy
import functools
import logging
import operator

from typing import Union

import dask
import dask.array as da
import numpy as np

from h5py import File as HDF5File
from sklearn.base import BaseEstimator

from .kmeans import KMeansMachine
from .utils import array_to_delayed_list, check_and_persist_dask_input

logger = logging.getLogger(__name__)

EPSILON = np.finfo(float).eps


def logaddexp_reduce(array, axis=0, keepdims=False):
    return np.logaddexp.reduce(
        array, axis=axis, keepdims=keepdims, initial=-np.inf
    )
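
# A minimal numerical check of the log-sum-exp reduction above (values chosen
# for illustration): summing likelihoods in the log domain avoids the
# underflow that would occur if the tiny probabilities were exponentiated
# first.
#
#   >>> logaddexp_reduce(np.log([1e-300, 1e-300]))  # log(2e-300)
#   -690.08...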

def log_weighted_likelihood(data, machine):
    """Returns the weighted log likelihood for each Gaussian for a set of data.

    Parameters
    ----------
    data
        Data to compute the log likelihood on.
    machine
        The GMM machine.

    Returns
    -------
    array of shape (n_gaussians, n_samples)
        The weighted log likelihood of each sample of each Gaussian.
    """
    # Compute the likelihood for each data point on each Gaussian
    n_gaussians = len(machine.means)
    z = []
    for i in range(n_gaussians):
        temp = np.sum(
            (data - machine.means[i]) ** 2 / machine.variances[i], axis=-1
        )
        z.append(temp)
    z = np.vstack(z)
    ll = -0.5 * (machine.g_norms[:, None] + z)
    log_weighted_likelihoods = machine.log_weights[:, None] + ll
    return log_weighted_likelihoods
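
# The score computed above is, for each component c and sample x, the
# diagonal-Gaussian log density plus the log mixture weight:
#
#   log(w_c * N(x | mu_c, sigma_c))
#     = log(w_c) - 0.5 * (g_norm_c + sum_d (x_d - mu_{c,d})**2 / sigma2_{c,d})
#
# where g_norm_c = D * log(2*pi) + sum_d log(sigma2_{c,d}) is the
# x-independent normalization term precomputed in GMMMachine.g_norms.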

def reduce_loglikelihood(log_weighted_likelihoods):
    if isinstance(log_weighted_likelihoods, np.ndarray):
        log_likelihood = logaddexp_reduce(log_weighted_likelihoods)
    else:
        # Sum along the gaussians axis (using logaddexp to prevent underflow)
        log_likelihood = da.reduction(
            x=log_weighted_likelihoods,
            chunk=logaddexp_reduce,
            aggregate=logaddexp_reduce,
            axis=0,
            dtype=float,
            keepdims=False,
        )
    return log_likelihood

def log_likelihood(data, machine):
    """Returns the current log likelihood for a set of data in this Machine.

    Parameters
    ----------
    data
        Data to compute the log likelihood on.
    machine
        The GMM machine.

    Returns
    -------
    array of shape (n_samples,)
        The log likelihood of each sample.
    """
    data = np.atleast_2d(data)

    # All likelihoods [array of shape (n_gaussians, n_samples)]
    log_weighted_likelihoods = log_weighted_likelihood(
        data=data,
        machine=machine,
    )
    # Likelihoods of each sample on this machine. [array of shape (n_samples,)]
    ll_reduced = reduce_loglikelihood(log_weighted_likelihoods)
    return ll_reduced
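
# Worked check (hypothetical single-Gaussian machine with weight 1, mean 0,
# and unit variance, so the mixture reduces to a standard normal in 1-D):
#
#   g_norm = 1 * log(2*pi) + log(1) = log(2*pi)
#   log_likelihood([[0.0]], machine) = log(1) - 0.5 * log(2*pi) ~= -0.9189
#
# which matches the standard normal log density at zero.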

def e_step(data, machine):
    """Expectation step of the e-m algorithm."""
    # Ensure data is a series of samples (2D array)
    data = np.atleast_2d(data)

    n_gaussians = len(machine.weights)

    # Initialize empty statistics (as dask or numpy arrays, matching the input)
    statistics = GMMStats(n_gaussians, data.shape[-1], like=data)

    # Log weighted Gaussian likelihoods [array of shape (n_gaussians, n_samples)]
    log_weighted_likelihoods = log_weighted_likelihood(
        data=data, machine=machine
    )

    # Log likelihood [array of shape (n_samples,)]
    log_likelihood = reduce_loglikelihood(log_weighted_likelihoods)

    # Responsibility P [array of shape (n_gaussians, n_samples)]
    responsibility = np.exp(log_weighted_likelihoods - log_likelihood[None, :])

    # Accumulate

    # Total likelihood [float]
    statistics.log_likelihood = log_likelihood.sum()
    # Count of samples [int]
    statistics.t = data.shape[0]
    # Responsibilities [array of shape (n_gaussians,)]
    statistics.n = responsibility.sum(axis=-1)
    sum_px, sum_pxx = [], []
    for i in range(n_gaussians):
        # p * x for gaussian i [array of shape (n_samples, n_features)]
        px = responsibility[i, :, None] * data
        # First order stats for gaussian i [array of shape (n_features,)]
        sum_px.append(np.sum(px, axis=0))
        # Second order stats for gaussian i [array of shape (n_features,)]
        sum_pxx.append(np.sum(px * data, axis=0))

    # Stack into arrays of shape (n_gaussians, n_features)
    statistics.sum_px = np.vstack(sum_px)
    statistics.sum_pxx = np.vstack(sum_pxx)

    return statistics
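
# The sufficient statistics accumulated above are, with P(c|x_i) the
# responsibility of component c for sample x_i:
#
#   n_c       = sum_i P(c|x_i)               (zeroth order)
#   sum_px_c  = sum_i P(c|x_i) * x_i         (first order)
#   sum_pxx_c = sum_i P(c|x_i) * x_i**2      (second order, element-wise)
#
# These are all the m-step needs, which is what makes the chunked
# (dask.delayed) training in GMMMachine.fit possible: statistics simply add
# across chunks.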

def m_step(
    statistics,
    machine,
):
    """Maximization step of the e-m algorithm."""
    m_step_func = map_gmm_m_step if machine.trainer == "map" else ml_gmm_m_step
    statistics = functools.reduce(operator.iadd, statistics)
    m_step_func(
        machine=machine,
        statistics=statistics,
        update_means=machine.update_means,
        update_variances=machine.update_variances,
        update_weights=machine.update_weights,
        mean_var_update_threshold=machine.mean_var_update_threshold,
        reynolds_adaptation=machine.map_relevance_factor is not None,
        alpha=machine.map_alpha,
        relevance_factor=machine.map_relevance_factor,
    )
    average_output = float(statistics.log_likelihood / statistics.t)
    return machine, average_output
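
# Note that m_step first folds the per-chunk statistics into a single GMMStats
# via GMMStats.__iadd__ (defined below), so two chunks combine as, e.g.:
#
#   combined.n = stats_a.n + stats_b.n
#   combined.sum_px = stats_a.sum_px + stats_b.sum_px
#
# The returned average is the total accumulated log likelihood divided by the
# number of accumulated samples, i.e. the mean log likelihood per sample.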

class GMMStats:
    """Stores accumulated statistics of a GMM.

    Attributes
    ----------
    log_likelihood: float
        The sum of log_likelihood of each sample on a GMM.
    t: int
        The number of considered samples.
    n: array of shape (n_gaussians,)
        Sum of responsibility.
    sum_px: array of shape (n_gaussians, n_features)
        First order statistic.
    sum_pxx: array of shape (n_gaussians, n_features)
        Second order statistic.
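
    Example
    -------
    A small accumulation sketch (`data` and `machine` assumed to exist, shapes
    chosen for illustration)::

        stats = GMMStats(n_gaussians=2, n_features=3)
        stats += e_step(data, machine)  # merge statistics chunk by chunk
        stats.nbytes  # memory used by n, sum_px and sum_pxx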

190 """ 

191 

192 def __init__( 

193 self, n_gaussians: int, n_features: int, like=None, **kwargs 

194 ) -> None: 

195 super().__init__(**kwargs) 

196 self.n_gaussians = n_gaussians 

197 self.n_features = n_features 

198 self.log_likelihood = 0 

199 self.t = 0 

200 # create dask arrays if required 

201 kw = dict(like=like) if like is not None else {} 

202 self.n = np.zeros(shape=(self.n_gaussians,), dtype=float, **kw) 

203 self.sum_px = np.zeros( 

204 shape=(self.n_gaussians, self.n_features), dtype=float, **kw 

205 ) 

206 self.sum_pxx = np.zeros( 

207 shape=(self.n_gaussians, self.n_features), dtype=float, **kw 

208 ) 

209 

210 def init_fields( 

211 self, log_likelihood=0.0, t=0, n=None, sum_px=None, sum_pxx=None 

212 ): 

213 """Initializes the statistics values to a defined value, or zero by default.""" 

214 # The accumulated log likelihood of all samples 

215 self.log_likelihood = log_likelihood 

216 # The accumulated number of samples 

217 self.t = t 

218 # For each Gaussian, the accumulated sum of responsibilities, i.e. the sum of 

219 # P(gaussian_i|x) 

220 self.n = ( 

221 np.zeros(shape=(self.n_gaussians,), dtype=float) if n is None else n 

222 ) 

223 # For each Gaussian, the accumulated sum of responsibility times the sample 

224 self.sum_px = ( 

225 np.zeros(shape=(self.n_gaussians, self.n_features), dtype=float) 

226 if sum_px is None 

227 else sum_px 

228 ) 

229 # For each Gaussian, the accumulated sum of responsibility times the sample 

230 # squared 

231 self.sum_pxx = ( 

232 np.zeros(shape=(self.n_gaussians, self.n_features), dtype=float) 

233 if sum_pxx is None 

234 else sum_pxx 

235 ) 

236 

    def reset(self):
        """Sets all statistics to zero."""
        self.init_fields()

    @classmethod
    def from_hdf5(cls, hdf5):
        """Creates a new GMMStats object from an `HDF5File` object."""
        if isinstance(hdf5, str):
            hdf5 = HDF5File(hdf5, "r")
        try:
            version_major, version_minor = hdf5.attrs["file_version"].split(".")
            logger.debug(
                f"Reading a GMMStats HDF5 file of version {version_major}.{version_minor}"
            )
        except (KeyError, RuntimeError):
            version_major, version_minor = 0, 0
        if int(version_major) >= 1:
            if hdf5.attrs["writer_class"] != str(cls):
                logger.warning(f"{hdf5.attrs['writer_class']} is not {cls}.")
            self = cls(
                n_gaussians=hdf5["n_gaussians"][()],
                n_features=hdf5["n_features"][()],
            )
            self.log_likelihood = hdf5["log_likelihood"][()]
            self.t = hdf5["T"][()]
            self.n = hdf5["n"][...]
            self.sum_px = hdf5["sumPx"][...]
            self.sum_pxx = hdf5["sumPxx"][...]
        else:  # Legacy file version
            logger.info("Loading a legacy HDF5 stats file.")
            self = cls(
                n_gaussians=int(hdf5["n_gaussians"][()]),
                n_features=int(hdf5["n_inputs"][()]),
            )
            # Note: the misspelled "log_liklihood" key below is intentional; it
            # matches the key written by the legacy library.
            self.log_likelihood = hdf5["log_liklihood"][()]
            self.t = int(hdf5["T"][()])
            self.n = np.reshape(hdf5["n"], (self.n_gaussians,))
            self.sum_px = np.reshape(hdf5["sumPx"], self.shape)
            self.sum_pxx = np.reshape(hdf5["sumPxx"], self.shape)
        return self

    def save(self, hdf5):
        """Saves the current statistics in an `HDF5File` object."""
        if isinstance(hdf5, str):
            hdf5 = HDF5File(hdf5, "w")
        hdf5.attrs["file_version"] = "1.0"
        hdf5.attrs["writer_class"] = str(self.__class__)
        hdf5["n_gaussians"] = self.n_gaussians
        hdf5["n_features"] = self.n_features
        hdf5["log_likelihood"] = float(self.log_likelihood)
        hdf5["T"] = int(self.t)
        hdf5["n"] = np.array(self.n)
        hdf5["sumPx"] = np.array(self.sum_px)
        hdf5["sumPxx"] = np.array(self.sum_pxx)
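
    # Typical round trip (file name assumed for illustration):
    #   stats.save("stats.hdf5")
    #   restored = GMMStats.from_hdf5("stats.hdf5")
    #   assert restored == stats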

    def load(self, hdf5):
        """Overwrites the current statistics with those in an `HDF5File` object."""
        new_self = self.from_hdf5(hdf5)
        if new_self.shape != self.shape:
            logger.warning("Loaded GMMStats from hdf5 with a different shape.")
            self.resize(*new_self.shape)
        self.init_fields(
            new_self.log_likelihood,
            new_self.t,
            new_self.n,
            new_self.sum_px,
            new_self.sum_pxx,
        )

    def __add__(self, other):
        if (
            self.n_gaussians != other.n_gaussians
            or self.n_features != other.n_features
        ):
            raise ValueError(
                "Statistics could not be added together (shape mismatch)"
            )
        new_stats = GMMStats(self.n_gaussians, self.n_features)
        new_stats.log_likelihood = self.log_likelihood + other.log_likelihood
        new_stats.t = self.t + other.t
        new_stats.n = self.n + other.n
        new_stats.sum_px = self.sum_px + other.sum_px
        new_stats.sum_pxx = self.sum_pxx + other.sum_pxx
        return new_stats

    def __iadd__(self, other):
        if (
            self.n_gaussians != other.n_gaussians
            or self.n_features != other.n_features
        ):
            raise ValueError(
                "Statistics could not be added together (shape mismatch)"
            )
        self.log_likelihood += other.log_likelihood
        self.t += other.t
        self.n += other.n
        self.sum_px += other.sum_px
        self.sum_pxx += other.sum_pxx
        return self

    def __eq__(self, other):
        return (
            self.log_likelihood == other.log_likelihood
            and self.t == other.t
            and np.array_equal(self.n, other.n)
            and np.array_equal(self.sum_px, other.sum_px)
            and np.array_equal(self.sum_pxx, other.sum_pxx)
        )

    def is_similar_to(self, other, rtol=1e-5, atol=1e-8):
        """Returns True if `other` has the same values (within a tolerance)."""
        return (
            np.isclose(
                self.log_likelihood, other.log_likelihood, rtol=rtol, atol=atol
            )
            and np.isclose(self.t, other.t, rtol=rtol, atol=atol)
            and np.allclose(self.n, other.n, rtol=rtol, atol=atol)
            and np.allclose(self.sum_px, other.sum_px, rtol=rtol, atol=atol)
            and np.allclose(self.sum_pxx, other.sum_pxx, rtol=rtol, atol=atol)
        )

    def resize(self, n_gaussians, n_features):
        """Reinitializes the machine with new dimensions."""
        self.n_gaussians = n_gaussians
        self.n_features = n_features
        self.init_fields()

    @property
    def shape(self):
        """The number of gaussians and their dimensionality."""
        return (self.n_gaussians, self.n_features)

    @property
    def nbytes(self):
        """The number of bytes used by the statistics n, sum_px, sum_pxx."""
        return self.n.nbytes + self.sum_px.nbytes + self.sum_pxx.nbytes

class GMMMachine(BaseEstimator):
    """Transformer that stores the parameters of a Gaussian Mixture Model (GMM).

    This class implements the statistical model of a multivariate
    diagonal-covariance Gaussian mixture distribution (GMM), as well as ways to
    train a model on data.

    A GMM is defined as
    :math:`\\sum_{c=1}^{C} \\omega_c \\mathcal{N}(x | \\mu_c, \\sigma_c)`, where
    :math:`C` is the number of Gaussian components, and :math:`\\mu_c`,
    :math:`\\sigma_c`, and :math:`\\omega_c` are respectively the mean, variance,
    and weight of each Gaussian component :math:`c`.
    See Section 2.3.9 of Bishop, "Pattern recognition and machine learning", 2006.

    Two types of training are available, MLE and MAP, chosen with `trainer`:

    * Maximum Likelihood Estimation (:ref:`MLE <mle>`, ML)

      The mixtures are initialized (with k-means by default). The means,
      variances, and weights of the mixtures are then trained on the data to
      increase the likelihood value. (:ref:`MLE <mle>`)

    * Maximum a Posteriori (:ref:`MAP <map>`)

      The MAP machine takes another GMM machine as prior, called the Universal
      Background Model (UBM).
      The means, variances, and weights of the MAP mixtures are then trained on
      the data as an adaptation of the UBM.

    Both training methods use an Expectation-Maximization (e-m) algorithm to
    iteratively train the GMM.

    Note
    ----
    When manually setting any of the means, variances, or variance thresholds,
    the k-means initialization will be skipped in `fit`.

    Attributes
    ----------
    means, variances, variance_thresholds
        Gaussians parameters.
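
    Example
    -------
    A minimal ML-training sketch (random data, shapes chosen for
    illustration)::

        import numpy as np

        data = np.random.normal(size=(100, 3))
        machine = GMMMachine(n_gaussians=2, update_variances=True)
        machine = machine.fit(data)
        assert machine.shape == (2, 3)

    A MAP machine adapts from a previously trained UBM instead::

        adapted = GMMMachine(n_gaussians=2, trainer="map", ubm=machine).fit(data)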

415 """ 

416 

417 def __init__( 

418 self, 

419 n_gaussians: int, 

420 trainer: str = "ml", 

421 ubm: "Union[GMMMachine, None]" = None, 

422 convergence_threshold: float = 1e-5, 

423 max_fitting_steps: Union[int, None] = 200, 

424 random_state: Union[int, np.random.RandomState] = 0, 

425 weights: "Union[np.ndarray[('n_gaussians',), float], None]" = None, # noqa: F821 

426 k_means_trainer: Union[KMeansMachine, None] = None, 

427 update_means: bool = True, 

428 update_variances: bool = False, 

429 update_weights: bool = False, 

430 mean_var_update_threshold: float = EPSILON, 

431 map_alpha: float = 0.5, 

432 map_relevance_factor: Union[None, float] = 4, 

433 **kwargs, 

434 ): 

435 """ 

436 Parameters 

437 ---------- 

438 n_gaussians 

439 The number of gaussians to be represented by the machine. 

440 trainer 

441 `"ml"` for the maximum likelihood estimator method; 

442 `"map"` for the maximum a posteriori method. (MAP Requires `ubm`) 

443 ubm: GMMMachine 

444 Universal Background Model. GMMMachine Required for the MAP method. 

445 convergence_threshold 

446 The threshold value of likelihood difference between e-m steps used for 

447 stopping the training iterations. 

448 max_fitting_steps 

449 The number of e-m iterations to fit the GMM. Stop the training even when 

450 the convergence threshold isn't met. 

451 random_state 

452 Specifies a RandomState or a seed for reproducibility. 

453 weights 

454 The weight of each Gaussian. (defaults to `1/n_gaussians`) 

455 k_means_trainer 

456 Optional trainer for the k-means method, replacing the default one. 

457 update_means 

458 Update the Gaussians means at every m step. 

459 update_variances 

460 Update the Gaussians variances at every m step. 

461 update_weights 

462 Update the GMM weights at every m step. 

463 mean_var_update_threshold 

464 Threshold value used when updating the means and variances. 

465 map_alpha 

466 Ratio for MAP adaptation. Used when `trainer == "map"` and 

467 `relevance_factor is None`) 

468 map_relevance_factor 

469 Factor for the computation of alpha with Reynolds adaptation. (Used when 

470 `trainer == "map"`) 

471 """ 


        super().__init__(**kwargs)

        self.n_gaussians = n_gaussians
        self.trainer = trainer if trainer in ["ml", "map"] else "ml"
        self.m_step_func = (
            map_gmm_m_step if self.trainer == "map" else ml_gmm_m_step
        )
        if self.trainer == "map" and ubm is None:
            raise ValueError("A UBM is required for MAP GMM.")
        if ubm is not None and ubm.n_gaussians != self.n_gaussians:
            raise ValueError(
                "The UBM machine is not compatible with this machine."
            )
        self.ubm = ubm
        if max_fitting_steps is None and convergence_threshold is None:
            raise ValueError(
                "Either or both convergence_threshold and max_fitting_steps must be set"
            )
        self.convergence_threshold = convergence_threshold
        self.max_fitting_steps = max_fitting_steps
        self.random_state = random_state
        self.k_means_trainer = k_means_trainer
        self.update_means = update_means
        self.update_variances = update_variances
        self.update_weights = update_weights
        self.mean_var_update_threshold = mean_var_update_threshold
        self._means = None
        self._variances = None
        self._variance_thresholds = mean_var_update_threshold
        self._g_norms = None

        if self.ubm is not None:
            self.means = copy.deepcopy(self.ubm.means)
            self.variances = copy.deepcopy(self.ubm.variances)
            self.variance_thresholds = copy.deepcopy(
                self.ubm.variance_thresholds
            )
            self.weights = copy.deepcopy(self.ubm.weights)
        else:
            self.weights = np.full(
                (self.n_gaussians,),
                fill_value=(1 / self.n_gaussians),
                dtype=float,
            )
        if weights is not None:
            self.weights = weights
        self.map_alpha = map_alpha
        self.map_relevance_factor = map_relevance_factor

    @property
    def weights(self):
        """The weights of each Gaussian mixture."""
        return self._weights

    @weights.setter
    def weights(
        self, weights: "np.ndarray[('n_gaussians',), float]"  # noqa: F821
    ):
        self._weights = weights
        self._log_weights = np.log(self._weights)

    @property
    def means(self):
        """The means of each Gaussian."""
        if self._means is None:
            raise ValueError("GMMMachine means were never set.")
        return self._means

    @means.setter
    def means(
        self,
        means: "np.ndarray[('n_gaussians', 'n_features'), float]",  # noqa: F821
    ):
        self._means = means

    @property
    def variances(self):
        """The (diagonal) variances of the gaussians."""
        if self._variances is None:
            raise ValueError("GMMMachine variances were never set.")
        return self._variances

    @variances.setter
    def variances(
        self,
        variances: "np.ndarray[('n_gaussians', 'n_features'), float]",  # noqa: F821
    ):
        self._variances = np.maximum(self.variance_thresholds, variances)
        # Recompute g_norm for each gaussian [array of shape (n_gaussians,)]
        n_log_2pi = self._variances.shape[-1] * np.log(2 * np.pi)
        self._g_norms = n_log_2pi + np.log(self._variances).sum(axis=-1)
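
    # g_norm is the x-independent part of -2 * log N(x | mu, sigma), i.e. for
    # component c over D features:
    #
    #   g_norm_c = D * log(2*pi) + sum_d log(sigma2_{c,d})
    #
    # Caching it in the setter above means each log-likelihood call only has
    # to add the data-dependent Mahalanobis term.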

    @property
    def variance_thresholds(self):
        """Threshold below which variances are clamped to prevent precision losses."""
        if self._variance_thresholds is None:
            return EPSILON
        return self._variance_thresholds

    @variance_thresholds.setter
    def variance_thresholds(
        self,
        threshold: "Union[float, np.ndarray[('n_gaussians', 'n_features'), float]]",  # noqa: F821
    ):
        self._variance_thresholds = threshold
        if self._variances is not None:
            self.variances = np.maximum(threshold, self._variances)

    @property
    def g_norms(self):
        """Precomputed g_norms (depends on variances and feature shape)."""
        if self._g_norms is None:
            # Recompute g_norm for each gaussian [array of shape (n_gaussians,)]
            n_log_2pi = self.variances.shape[-1] * np.log(2 * np.pi)
            self._g_norms = n_log_2pi + np.log(self._variances).sum(axis=-1)
        return self._g_norms

    @property
    def log_weights(self):
        """Retrieve the logarithm of the weights."""
        return self._log_weights

    @property
    def shape(self):
        """Shape of the gaussians in the GMM machine."""
        return self.means.shape

    @classmethod
    def from_hdf5(cls, hdf5: Union[str, HDF5File], ubm: "GMMMachine" = None):
        """Creates a new GMMMachine object from an `HDF5File` object."""
        if isinstance(hdf5, str):
            hdf5 = HDF5File(hdf5, "r")
        try:
            version_major, version_minor = hdf5.attrs["file_version"].split(".")
            logger.debug(
                f"Reading a GMMMachine HDF5 file of version {version_major}.{version_minor}"
            )
        except (KeyError, RuntimeError):
            version_major, version_minor = 0, 0
        if int(version_major) >= 1:
            if hdf5.attrs["writer_class"] != str(cls):
                logger.warning(f"{hdf5.attrs['writer_class']} is not {cls}.")
            # Read the scalar string dataset (h5py may return bytes)
            trainer = hdf5["trainer"][()]
            if isinstance(trainer, bytes):
                trainer = trainer.decode("utf-8")
            if trainer == "map" and ubm is None:
                raise ValueError(
                    "The UBM is needed when loading a MAP machine."
                )
            self = cls(
                n_gaussians=hdf5["n_gaussians"][()],
                trainer=trainer,
                ubm=ubm,
                convergence_threshold=1e-5,
                max_fitting_steps=hdf5["max_fitting_steps"][()],
                weights=hdf5["weights"][...],
                k_means_trainer=None,
                update_means=hdf5["update_means"][()],
                update_variances=hdf5["update_variances"][()],
                update_weights=hdf5["update_weights"][()],
            )
            gaussians_group = hdf5["gaussians"]
            self.means = gaussians_group["means"][...]
            self.variances = gaussians_group["variances"][...]
            self.variance_thresholds = gaussians_group["variance_thresholds"][
                ...
            ]
        else:  # Legacy file version
            logger.info("Loading a legacy HDF5 machine file.")
            n_gaussians = hdf5["m_n_gaussians"][()][0]
            g_means = []
            g_variances = []
            g_variance_thresholds = []
            for i in range(n_gaussians):
                gaussian_group = hdf5[f"m_gaussians{i}"]
                g_means.append(gaussian_group["m_mean"][...])
                g_variances.append(gaussian_group["m_variance"][...])
                g_variance_thresholds.append(
                    gaussian_group["m_variance_thresholds"][...]
                )
            weights = np.reshape(hdf5["m_weights"], (n_gaussians,))
            self = cls(n_gaussians=n_gaussians, ubm=ubm, weights=weights)
            self.means = np.array(g_means).reshape(n_gaussians, -1)
            self.variances = np.array(g_variances).reshape(n_gaussians, -1)
            self.variance_thresholds = np.array(g_variance_thresholds).reshape(
                n_gaussians, -1
            )
        return self

    def load(self, hdf5):
        """Overwrites the current state with the one in an `HDF5File` object."""
        new_self = self.from_hdf5(hdf5)
        self.__dict__.update(new_self.__dict__)

    def save(self, hdf5):
        """Saves the current machine parameters in an `HDF5File` object."""
        if isinstance(hdf5, str):
            hdf5 = HDF5File(hdf5, "w")
        hdf5.attrs["file_version"] = "1.0"
        hdf5.attrs["writer_class"] = str(self.__class__)
        hdf5["n_gaussians"] = self.n_gaussians
        hdf5["trainer"] = self.trainer
        hdf5["convergence_threshold"] = self.convergence_threshold
        hdf5["max_fitting_steps"] = self.max_fitting_steps
        hdf5["weights"] = self.weights
        hdf5["update_means"] = self.update_means
        hdf5["update_variances"] = self.update_variances
        hdf5["update_weights"] = self.update_weights
        gaussians_group = hdf5.create_group("gaussians")
        gaussians_group["means"] = self.means
        gaussians_group["variances"] = self.variances
        gaussians_group["variance_thresholds"] = self.variance_thresholds

    def __eq__(self, other):
        if self._means is None:
            return False

        return (
            np.allclose(self.means, other.means)
            and np.allclose(self.variances, other.variances)
            and np.allclose(self.variance_thresholds, other.variance_thresholds)
            and np.allclose(self.weights, other.weights)
        )

    def is_similar_to(self, other, rtol=1e-5, atol=1e-8):
        """Returns True if `other` has the same gaussians (within a tolerance)."""
        return (
            np.allclose(self.means, other.means, rtol=rtol, atol=atol)
            and np.allclose(
                self.variances, other.variances, rtol=rtol, atol=atol
            )
            and np.allclose(
                self.variance_thresholds,
                other.variance_thresholds,
                rtol=rtol,
                atol=atol,
            )
            and np.allclose(self.weights, other.weights, rtol=rtol, atol=atol)
        )

    def initialize_gaussians(
        self,
        data: "Union[np.ndarray[('n_samples', 'n_features'), float], None]" = None,  # noqa: F821
    ):
        """Populates the Gaussian parameters with either k-means or the UBM values."""
        if self.trainer == "map":
            self.means = copy.deepcopy(self.ubm.means)
            self.variances = copy.deepcopy(self.ubm.variances)
            self.variance_thresholds = copy.deepcopy(
                self.ubm.variance_thresholds
            )
            self.weights = copy.deepcopy(self.ubm.weights)
        else:
            logger.debug("GMM means were never set. Initializing with k-means.")
            if data is None:
                raise ValueError("Data is required when training with k-means.")
            logger.info("Initializing GMM with k-means.")
            kmeans_machine = self.k_means_trainer or KMeansMachine(
                self.n_gaussians,
                random_state=self.random_state,
            )
            kmeans_machine = kmeans_machine.fit(data)

            # Set the GMM machine's gaussians with the results of k-means
            self.means = copy.deepcopy(kmeans_machine.centroids_)
            logger.debug(
                "Estimating the variance and weights of each gaussian from kmeans."
            )
            (
                self.variances,
                self.weights,
            ) = kmeans_machine.get_variances_and_weights_for_each_cluster(data)
            logger.debug("Done.")

    def log_weighted_likelihood(
        self,
        data: "np.ndarray[('n_samples', 'n_features'), float]",  # noqa: F821
    ):
        """Returns the weighted log likelihood for each Gaussian for a set of data.

        Parameters
        ----------
        data
            Data to compute the log likelihood on.

        Returns
        -------
        array of shape (n_gaussians, n_samples)
            The weighted log likelihood of each sample of each Gaussian.
        """
        return log_weighted_likelihood(
            data=data,
            machine=self,
        )

    def log_likelihood(
        self,
        data: "np.ndarray[('n_samples', 'n_features'), float]",  # noqa: F821
    ):
        """Returns the current log likelihood for a set of data in this Machine.

        Parameters
        ----------
        data
            Data to compute the log likelihood on.

        Returns
        -------
        array of shape (n_samples,)
            The log likelihood of each sample.
        """
        return log_likelihood(
            data=data,
            machine=self,
        )

    def fit(self, X, y=None):
        """Trains the GMM on data until convergence or the maximum step is reached."""

        input_is_dask, X = check_and_persist_dask_input(X)

        if self._means is None:
            self.initialize_gaussians(X)
        else:
            logger.debug("GMM means already set. Initialization was not run!")

        if self._variances is None:
            logger.warning(
                "Variances were not defined before fit. Using variance=1"
            )
            self.variances = np.ones_like(self.means)

        X = array_to_delayed_list(X, input_is_dask)

        average_output = 0
        logger.info("Training GMM...")
        step = 0
        while self.max_fitting_steps is None or step < self.max_fitting_steps:
            step += 1
            logger.info(
                f"Iteration {step:3d}"
                + (f"/{self.max_fitting_steps:3d}" if self.max_fitting_steps else "")
            )

            average_output_previous = average_output

            # compute the e-m steps
            if input_is_dask:
                stats = [
                    dask.delayed(e_step)(
                        data=xx,
                        machine=self,
                    )
                    for xx in X
                ]
                new_machine, average_output = dask.compute(
                    dask.delayed(m_step)(stats, self)
                )[0]
                for attr in ["weights", "means", "variances"]:
                    setattr(self, attr, getattr(new_machine, attr))
            else:
                stats = [
                    e_step(
                        data=X,
                        machine=self,
                    )
                ]
                _, average_output = m_step(stats, self)

            logger.debug(f"log likelihood = {average_output}")
            if step > 1:
                convergence_value = abs(
                    (average_output_previous - average_output)
                    / average_output_previous
                )
                logger.debug(
                    f"convergence val = {convergence_value} and threshold = {self.convergence_threshold}"
                )

                # Terminate if converged (and a convergence threshold is set)
                if (
                    self.convergence_threshold is not None
                    and convergence_value <= self.convergence_threshold
                ):
                    logger.info("Reached convergence threshold. Training stopped.")
                    break

        else:
            logger.info(
                "Reached the maximum number of steps. Training stopped without convergence."
            )
        return self

    def acc_stats(self, X):
        """Returns the statistics for `X`."""
        # We need this because sometimes the transform function gets overridden
        return e_step(data=X, machine=self)

    def transform(self, X):
        """Returns the statistics of each sample in `X`."""
        return self.stats_per_sample(X)

    def stats_per_sample(self, X):
        """Returns a list of GMMStats, one computed per sample of `X`."""
        return [e_step(data=xx, machine=self) for xx in X]

def ml_gmm_m_step(
    machine: GMMMachine,
    statistics: GMMStats,
    update_means=True,
    update_variances=False,
    update_weights=False,
    mean_var_update_threshold=EPSILON,
    **kwargs,
):
    """Updates a GMM machine's parameters according to the e-step statistics."""
    logger.debug("ML GMM Trainer m-step")

    # Threshold the low n to prevent division by zero
    thresholded_n = np.clip(statistics.n, mean_var_update_threshold, None)

    # Update weights if requested
    # (Equation 9.26 of Bishop, "Pattern recognition and machine learning", 2006)
    if update_weights:
        logger.debug("Update weights.")
        machine.weights = thresholded_n / statistics.t

    # Update GMM parameters using the sufficient statistics:

    # Update means if requested
    # (Equation 9.24 of Bishop, "Pattern recognition and machine learning", 2006)
    if update_means:
        logger.debug("Update means.")
        # Using n with the applied threshold
        machine.means = statistics.sum_px / thresholded_n[:, None]

    # Update variances if requested
    # (Equation 9.25 of Bishop, "Pattern recognition and machine learning", 2006)
    # ...but using the "computational formula for the variance", i.e.
    #   var = 1/n * sum(P * (x - mean)^2)
    #       = 1/n * sum(P * x^2) - mean^2
    if update_variances:
        logger.debug("Update variances.")
        machine.variances = statistics.sum_pxx / thresholded_n[
            :, None
        ] - np.power(machine.means, 2)
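
# In summary, the ML updates above are, per component c:
#
#   w_c      = n_c / T                      (Bishop eq. 9.26)
#   mu_c     = sum_px_c / n_c               (Bishop eq. 9.24)
#   sigma2_c = sum_pxx_c / n_c - mu_c**2    (Bishop eq. 9.25, expanded)
#
# with n_c clipped from below to avoid dividing by zero for empty components.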

def map_gmm_m_step(
    machine: GMMMachine,
    statistics: GMMStats,
    update_means=True,
    update_variances=False,
    update_weights=False,
    reynolds_adaptation=True,
    relevance_factor=4,
    alpha=0.5,
    mean_var_update_threshold=EPSILON,
):
    """Updates a GMMMachine's parameters using statistics adapted from a UBM."""
    if machine.ubm is None:
        raise ValueError("A machine used for MAP must have a UBM.")
    # Calculate the "data-dependent adaptation coefficient", alpha_i
    # [array of shape (n_gaussians,)]
    if reynolds_adaptation:
        alpha = statistics.n / (statistics.n + relevance_factor)
    else:
        if not hasattr(alpha, "ndim"):
            alpha = np.full((machine.n_gaussians,), alpha)

    # - Update weights if requested
    # Equation 11 of Reynolds et al., "Speaker Verification Using Adapted
    # Gaussian Mixture Models", Digital Signal Processing, 2000
    if update_weights:
        # Calculate the maximum likelihood weights [array of shape (n_gaussians,)]
        ml_weights = statistics.n / statistics.t

        # Calculate the new weights
        machine.weights = alpha * ml_weights + (1 - alpha) * machine.ubm.weights

        # Apply the scale factor, gamma, to ensure the new weights sum to unity
        gamma = machine.weights.sum()
        machine.weights /= gamma

    # Update GMM parameters
    # - Update means if requested
    # Equation 12 of Reynolds et al., "Speaker Verification Using Adapted
    # Gaussian Mixture Models", Digital Signal Processing, 2000
    if update_means:
        # Apply a threshold to prevent division by zero below
        n_threshold = np.where(
            statistics.n < mean_var_update_threshold,
            mean_var_update_threshold,
            statistics.n,
        )
        new_means = np.multiply(
            alpha[:, None],
            (statistics.sum_px / n_threshold[:, None]),
        ) + np.multiply((1 - alpha[:, None]), machine.ubm.means)
        machine.means = np.where(
            statistics.n[:, None] < mean_var_update_threshold,
            machine.ubm.means,
            new_means,
        )

    # - Update variances if requested
    # Equation 13 of Reynolds et al., "Speaker Verification Using Adapted
    # Gaussian Mixture Models", Digital Signal Processing, 2000
    if update_variances:
        # Calculate the new variances. Note that equation 13 uses the squared
        # UBM means: sigma_hat^2 = alpha*E(x^2) + (1-alpha)*(sigma^2 + mu^2) - mu_hat^2
        prior_norm_variances = (
            machine.ubm.variances + np.power(machine.ubm.means, 2)
        ) - np.power(machine.means, 2)
        new_variances = (
            alpha[:, None] * statistics.sum_pxx / statistics.n[:, None]
            + (1 - alpha[:, None])
            * (machine.ubm.variances + np.power(machine.ubm.means, 2))
            - np.power(machine.means, 2)
        )
        machine.variances = np.where(
            statistics.n[:, None] < mean_var_update_threshold,
            prior_norm_variances,
            new_variances,
        )
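

# In summary, the MAP (Reynolds) updates above are, per component c, with
# adaptation coefficient alpha_c = n_c / (n_c + relevance_factor):
#
#   w_c      = gamma * (alpha_c * n_c / T + (1 - alpha_c) * w_c^UBM)     (eq. 11)
#   mu_c     = alpha_c * sum_px_c / n_c + (1 - alpha_c) * mu_c^UBM       (eq. 12)
#   sigma2_c = alpha_c * sum_pxx_c / n_c
#              + (1 - alpha_c) * (sigma2_c^UBM + (mu_c^UBM)^2) - mu_c^2  (eq. 13)
#
# Components with too little accumulated responsibility (n_c below the
# threshold) keep the UBM values for their means and variances.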