Coverage for src/bob/fusion/base/algorithm/GMM.py: 100%
38 statements
« prev ^ index » next coverage.py v7.6.5, created at 2024-11-14 22:15 +0100
« prev ^ index » next coverage.py v7.6.5, created at 2024-11-14 22:15 +0100
1#!/usr/bin/env python
3from __future__ import absolute_import, division
5import logging
7import numpy as np
9from bob.learn.em import GMMMachine
11from .AlgorithmBob import AlgorithmBob
# Module-level logger, named after this module so its output can be
# filtered/configured per module via the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
class GMM(AlgorithmBob):
    """GMM score fusion.

    Fits a ``bob.learn.em.GMMMachine`` on the *positive* training scores
    only, then fuses each score vector by returning its log-likelihood under
    that model (see :meth:`decision_function`).
    """

    def __init__(
        self,
        # parameters for the GMM
        number_of_gaussians=None,
        # parameters of GMM training
        # Maximum number of iterations for ML GMM Training
        gmm_training_iterations=25,
        # Threshold to end the ML training
        training_threshold=5e-4,
        # Minimum value that a variance can reach
        variance_threshold=5e-4,
        update_weights=True,
        update_means=True,
        update_variances=True,
        init_seed=5489,
        **kwargs,
    ):
        """Store the GMM configuration; the model itself is built in train().

        Parameters
        ----------
        number_of_gaussians : int or None
            Number of mixture components. ``None`` means "derive it from the
            training data" as ``n_features + 1`` each time :meth:`train` runs.
        gmm_training_iterations : int
            Passed to ``GMMMachine`` as ``max_fitting_steps``.
        training_threshold : float
            Passed to ``GMMMachine`` as ``convergence_threshold``.
        variance_threshold : float
            Recorded only. NOTE(review): this value is never forwarded to
            ``GMMMachine`` in this class, so it currently has no effect on
            training — confirm whether it should be wired through.
        update_weights, update_means, update_variances : bool
            Forwarded to ``GMMMachine`` to select which parameters are
            updated during fitting.
        init_seed : int
            Passed to ``GMMMachine`` as ``random_state``.
        **kwargs
            Forwarded to ``AlgorithmBob.__init__``.
        """
        super().__init__(classifier=self, **kwargs)
        # Record the constructor arguments in the base-class ``str`` dict
        # (presumably used for the algorithm's textual description — TODO
        # confirm against AlgorithmBob).
        self.str["number_of_gaussians"] = number_of_gaussians
        self.str["gmm_training_iterations"] = gmm_training_iterations
        self.str["training_threshold"] = training_threshold
        self.str["variance_threshold"] = variance_threshold
        self.str["update_weights"] = update_weights
        self.str["update_means"] = update_means
        self.str["update_variances"] = update_variances
        self.str["init_seed"] = init_seed

        # copy parameters
        self.n_gaussians = number_of_gaussians
        # Keep the user's request apart from ``self.n_gaussians`` (which
        # train() fills in when it is None) so that an automatic choice is
        # re-derived from the data on every train() call instead of sticking
        # to whatever the first training set implied.
        self._requested_n_gaussians = number_of_gaussians
        self.gmm_training_iterations = gmm_training_iterations
        self.training_threshold = training_threshold
        self.variance_threshold = variance_threshold
        self.update_weights = update_weights
        self.update_means = update_means
        self.update_variances = update_variances
        self.init_seed = init_seed

        # this is needed to be able to load the machine
        self.machine = GMMMachine(n_gaussians=1)

    def train(self, train_neg, train_pos, devel_neg=None, devel_pos=None):
        """Fit the GMM on the positive training scores.

        Parameters
        ----------
        train_neg
            Negative training scores. Ignored: only positives are modelled.
        train_pos
            Positive training scores. Must expose ``.shape``; axis 1 is
            taken as the feature dimension when the number of Gaussians is
            derived automatically.
        devel_neg, devel_pos
            Unused; accepted for interface compatibility.

        Side effects
        ------------
        Replaces ``self.machine`` with a freshly fitted ``GMMMachine`` and,
        when no component count was requested, sets ``self.n_gaussians`` to
        the value derived from ``train_pos``.
        """
        logger.info("Using only positive samples for training")
        array = train_pos
        logger.debug("Training files have the shape of %s", array.shape)

        # getattr fallback keeps objects restored without running __init__
        # (e.g. unpickled from an older version) on the previous behavior.
        requested = getattr(
            self, "_requested_n_gaussians", self.n_gaussians
        )
        if requested is None:
            # Re-derive on every call so a retrain on data with a different
            # feature count does not reuse a stale component count.
            self.n_gaussians = array.shape[1] + 1
            logger.warning(
                "Number of Gaussians was None. Using %s.", self.n_gaussians
            )

        # Creates the machines (KMeans and GMM)
        logger.debug("Training GMM machine")
        self.machine = GMMMachine(
            n_gaussians=self.n_gaussians,
            convergence_threshold=self.training_threshold,
            max_fitting_steps=self.gmm_training_iterations,
            random_state=self.init_seed,
            update_means=self.update_means,
            update_variances=self.update_variances,
            update_weights=self.update_weights,
        )
        self.machine.fit(array)

    def decision_function(self, scores):
        """Return the fused score (GMM log-likelihood) of every row.

        ``scores`` must be iterable over rows and expose ``.shape[0]``
        (presumably a 2-D array of shape (n_samples, n_features) — TODO
        confirm). Each row is scored separately with
        ``self.machine.log_likelihood`` and the results are collected into a
        1-D float array of length ``scores.shape[0]``.
        """
        return np.fromiter(
            (self.machine.log_likelihood(s) for s in scores),
            float,
            scores.shape[0],
        )