Coverage for src/bob/fusion/base/algorithm/GMM.py: 100%

38 statements  

« prev     ^ index     » next       coverage.py v7.6.5, created at 2024-11-14 22:15 +0100

1#!/usr/bin/env python 

2 

3from __future__ import absolute_import, division 

4 

5import logging 

6 

7import numpy as np 

8 

9from bob.learn.em import GMMMachine 

10 

11from .AlgorithmBob import AlgorithmBob 

12 

13logger = logging.getLogger(__name__) 

14 

15 

class GMM(AlgorithmBob):
    """Score fusion that models the positive class with a Gaussian mixture.

    A ``GMMMachine`` is fitted on the positive training scores only.
    Fused scores are the per-sample values of that machine's
    ``log_likelihood``.

    NOTE(review): ``variance_threshold`` is stored on the instance and in
    ``self.str``, but ``train`` never forwards it to ``GMMMachine``.
    Confirm whether it should be wired through.
    """

    def __init__(
        self,
        # number of mixture components; None -> inferred from the data in train()
        number_of_gaussians=None,
        # maximum number of ML fitting steps
        gmm_training_iterations=25,
        # convergence threshold that ends ML training
        training_threshold=5e-4,
        # minimum value a variance is meant to reach (currently unused, see class NOTE)
        variance_threshold=5e-4,
        update_weights=True,
        update_means=True,
        update_variances=True,
        init_seed=5489,
        **kwargs,
    ):
        super().__init__(classifier=self, **kwargs)

        # Publish every hyper-parameter in ``self.str``, keeping this key order.
        hyper_parameters = (
            ("number_of_gaussians", number_of_gaussians),
            ("gmm_training_iterations", gmm_training_iterations),
            ("training_threshold", training_threshold),
            ("variance_threshold", variance_threshold),
            ("update_weights", update_weights),
            ("update_means", update_means),
            ("update_variances", update_variances),
            ("init_seed", init_seed),
        )
        for key, value in hyper_parameters:
            self.str[key] = value

        # Working copies of the settings; train() reads these attributes.
        self.n_gaussians = number_of_gaussians
        self.gmm_training_iterations = gmm_training_iterations
        self.training_threshold = training_threshold
        self.variance_threshold = variance_threshold
        self.update_weights = update_weights
        self.update_means = update_means
        self.update_variances = update_variances
        self.init_seed = init_seed

        # A placeholder machine must exist for the model to be loadable;
        # train() replaces it with a fitted one.
        self.machine = GMMMachine(n_gaussians=1)

    def train(self, train_neg, train_pos, devel_neg=None, devel_pos=None):
        """Fit a fresh GMM on the positive training scores.

        Only ``train_pos`` is used; ``train_neg``, ``devel_neg`` and
        ``devel_pos`` are accepted for interface compatibility and ignored.

        ``train_pos`` must expose ``shape``. ``shape[1]`` sets the default
        number of Gaussians. Presumably it is a (n_samples, n_systems)
        array -- TODO confirm.

        Side effects: if ``self.n_gaussians`` is None, it is permanently set
        to ``shape[1] + 1``. Note that ``self.str`` is not updated.
        ``self.machine`` is replaced by a new ``GMMMachine`` before fitting.
        """
        logger.info("Using only positive samples for training")
        samples = train_pos
        logger.debug("Training files have the shape of {}".format(samples.shape))

        if self.n_gaussians is None:
            # Default: one more component than there are feature columns.
            self.n_gaussians = samples.shape[1] + 1
            logger.warning(
                "Number of Gaussians was None. Using {}.".format(self.n_gaussians)
            )

        logger.debug("Training GMM machine")
        machine_options = dict(
            n_gaussians=self.n_gaussians,
            convergence_threshold=self.training_threshold,
            max_fitting_steps=self.gmm_training_iterations,
            random_state=self.init_seed,
            update_means=self.update_means,
            update_variances=self.update_variances,
            update_weights=self.update_weights,
        )
        # Assign before fitting, so self.machine already holds the new
        # machine even if fit() raises.
        self.machine = GMMMachine(**machine_options)
        self.machine.fit(samples)

    def decision_function(self, scores):
        """Return one fused score per row of ``scores``.

        Each row is passed separately to ``self.machine.log_likelihood``.
        The results are collected into a 1-D float array whose length is
        ``scores.shape[0]``.

        Assumes ``scores`` is a 2-D numpy array (n_samples, n_systems) and
        that ``log_likelihood`` returns a scalar for a single row -- TODO
        confirm.
        """
        likelihoods = (self.machine.log_likelihood(row) for row in scores)
        n_samples = scores.shape[0]
        return np.fromiter(likelihoods, dtype=float, count=n_samples)