Coverage for src/bob/bio/base/algorithm/gmm.py: 36%
72 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-12 22:34 +0200
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-12 22:34 +0200
"""Interface between the lower level GMM classes and the Algorithm Transformer.

Implements the enroll and score methods using the low level GMM implementation.

This adds the notions of models, probes, enrollment, and scores to GMM.
"""
9import copy
10import logging
12from typing import Callable, Union
14import dask.array as da
15import numpy as np
17from h5py import File as HDF5File
19from bob.bio.base.pipelines import BioAlgorithm
20from bob.learn.em import GMMMachine, GMMStats, linear_scoring
# Module-scoped logger; its name is this module's dotted import path.
logger = logging.getLogger(name=__name__)
def check_data_dim(data, expected_ndim):
    """Stacks features into ``(n_samples, n_features)`` or labels into ``(n_samples,)``.

    If the elements of ``data`` (e.g. the entries of a list or SampleBatch)
    already have ``expected_ndim`` dimensions, they are stacked:
    ``np.vstack`` for ``expected_ndim == 2`` and ``np.concatenate`` for
    ``expected_ndim == 1``. Otherwise ``data`` is considered already stacked
    and is returned unchanged.

    Parameters
    ----------
    data : array-like
        Features or labels: either an already stacked array, or a sequence of
        per-sample arrays (or nested lists) to stack.
    expected_ndim : int
        Expected number of dimensions of the stacked result; must be 1 or 2.

    Returns
    -------
    stacked_data : array-like
        The stacked features or labels if stacking was needed, otherwise
        ``data`` itself. Input with no elements (an empty list, or an array
        whose first dimension is 0) is returned unchanged.

    Raises
    ------
    ValueError
        If ``expected_ndim`` is not 1 or 2.
    """
    if expected_ndim not in (1, 2):
        raise ValueError(
            f"expected_ndim must be 1 or 2 but got {expected_ndim}"
        )

    stack_function = np.concatenate if expected_ndim == 1 else np.vstack

    try:
        first = data[0]
    except IndexError:
        # No elements (e.g. an empty list or a (0, n_features) array): there
        # is nothing to stack, so leave the input untouched instead of crashing.
        return data

    # ``np.ndim`` also handles elements without an ``ndim`` attribute (e.g.
    # plain Python lists), where ``first.ndim`` would raise AttributeError.
    if np.ndim(first) == expected_ndim:
        return stack_function(data)

    return data
class GMM(GMMMachine, BioAlgorithm):
    """Algorithm for computing UBM and Gaussian Mixture Models of the features.

    Features must be normalized to zero mean and unit standard deviation.

    Models are MAP GMM machines trained from a UBM on the enrollment feature set.

    The UBM is a ML GMM machine trained on the training feature set.

    Probes are GMM statistics of features projected on the UBM.
    """

    def __init__(
        self,
        # parameters for the GMM
        n_gaussians: int,
        # parameters of UBM training
        k_means_trainer=None,
        max_fitting_steps: int = 25,  # Maximum number of iterations for GMM Training
        convergence_threshold: float = 5e-4,  # Threshold to end the ML training
        mean_var_update_threshold: float = 5e-4,  # Minimum value that a variance can reach
        update_means: bool = True,
        update_variances: bool = True,
        update_weights: bool = True,
        # parameters of the GMM enrollment (MAP)
        enroll_iterations: int = 1,
        enroll_update_means: bool = True,
        enroll_update_variances: bool = False,
        enroll_update_weights: bool = False,
        enroll_relevance_factor: Union[float, None] = 4,
        enroll_alpha: float = 0.5,
        # scoring
        scoring_function: Callable = linear_scoring,
        # RNG
        random_state: int = 5489,
        # other
        return_stats_in_transform: bool = False,
        **kwargs,
    ):
        """Initializes the local UBM-GMM tool chain.

        Parameters
        ----------
        n_gaussians
            The number of Gaussians used in the UBM and the models.
        k_means_trainer
            The k-means machine used to train and initialize the UBM; forwarded
            as-is to the parent constructor.
        max_fitting_steps
            Number of e-m iterations for training the UBM.
        convergence_threshold
            Convergence threshold to halt the GMM training early. Also passed to
            the MAP machine built in ``enroll``.
        mean_var_update_threshold
            Minimum value a variance of the Gaussians can reach. Also passed to
            the MAP machine built in ``enroll``.
        update_means
            Decides whether the means of the Gaussians are updated while training.
        update_variances
            Decides whether the variances of the Gaussians are updated while training.
        update_weights
            Decides whether the weights of the Gaussians are updated while training.
        enroll_iterations
            Number of iterations for the MAP GMM used for enrollment.
        enroll_update_means
            Decides whether the means of the Gaussians are updated while enrolling.
        enroll_update_variances
            Decides whether the variances of the Gaussians are updated while enrolling.
        enroll_update_weights
            Decides whether the weights of the Gaussians are updated while enrolling.
        enroll_relevance_factor
            For enrollment: MAP relevance factor as described in Reynolds paper.
            If None, will not apply Reynolds adaptation.
        enroll_alpha
            For enrollment: MAP adaptation coefficient.
        scoring_function
            Function returning a score from a model, a UBM, and a probe. It is
            called by ``compare`` with the keyword arguments ``models_means``,
            ``ubm``, ``test_stats`` and ``frame_length_normalization``.
        random_state
            Seed for the random number generation (UBM training and enrollment).
        return_stats_in_transform
            If True, ``transform`` returns the parent class' ``transform`` output
            (the GMM statistics); otherwise ``transform`` returns its input as is.
            Also controls the ``bob_checkpoint_features`` tag.
        **kwargs
            Extra keyword arguments forwarded to ``super().__init__``
            (``GMMMachine`` comes first in the MRO).
        """
        super().__init__(
            n_gaussians=n_gaussians,
            trainer="ml",
            max_fitting_steps=max_fitting_steps,
            convergence_threshold=convergence_threshold,
            update_means=update_means,
            update_variances=update_variances,
            update_weights=update_weights,
            mean_var_update_threshold=mean_var_update_threshold,
            k_means_trainer=k_means_trainer,
            random_state=random_state,
            **kwargs,
        )

        self.enroll_relevance_factor = enroll_relevance_factor
        self.enroll_alpha = enroll_alpha
        self.enroll_iterations = enroll_iterations
        self.enroll_update_means = enroll_update_means
        self.enroll_update_weights = enroll_update_weights
        self.enroll_update_variances = enroll_update_variances
        self.scoring_function = scoring_function
        self.return_stats_in_transform = return_stats_in_transform

    def save_model(self, ubm_file):
        """Saves the projector (UBM) to file."""
        super().save(ubm_file)

    def load_model(self, ubm_file):
        """Loads the projector (UBM) from a file."""
        super().load(ubm_file)

    def project(self, array):
        """Computes GMM statistics against the UBM, given a 2D array of feature vectors.

        A list (or SampleBatch) of 2D arrays is stacked first. This is applied
        to the probes before scoring.
        """
        array = check_data_dim(array, expected_ndim=2)
        logger.debug("Projecting %d feature vectors", array.shape[0])
        # Accumulates statistics
        gmm_stats = self.acc_stats(array)

        # Return the resulting statistics
        return gmm_stats

    def enroll(self, data):
        """Enrolls a GMM using MAP adaptation given a reference's feature vectors.

        Returns a ``GMMMachine`` (trainer ``"map"``) built on a deep copy of this
        UBM and fitted on the reference data.
        """

        # if input is a list (or SampleBatch) of 2 dimensional arrays, stack them
        data = check_data_dim(data, expected_ndim=2)

        # Use the array to train a GMM and return it
        logger.info("Enrolling with %d feature vectors", data.shape[0])

        gmm = GMMMachine(
            n_gaussians=self.n_gaussians,
            trainer="map",
            ubm=copy.deepcopy(self),
            convergence_threshold=self.convergence_threshold,
            max_fitting_steps=self.enroll_iterations,
            random_state=self.random_state,
            update_means=self.enroll_update_means,
            update_variances=self.enroll_update_variances,
            update_weights=self.enroll_update_weights,
            mean_var_update_threshold=self.mean_var_update_threshold,
            map_relevance_factor=self.enroll_relevance_factor,
            map_alpha=self.enroll_alpha,
        )
        gmm.fit(data)
        return gmm

    def create_templates(self, list_of_feature_sets, enroll):
        """Builds one template per feature set.

        If ``enroll`` is true, each template is a MAP ``GMMMachine`` from
        ``enroll``; otherwise it is the GMM statistics from ``project``.
        """
        if enroll:
            return [
                self.enroll(feature_set) for feature_set in list_of_feature_sets
            ]
        else:
            return [
                self.project(feature_set)
                for feature_set in list_of_feature_sets
            ]

    def compare(self, enroll_templates, probe_templates):
        """Scores probe statistics against enrolled models with ``scoring_function``.

        Frame-length normalization is always enabled; the return value is
        whatever ``scoring_function`` returns.
        """
        return self.scoring_function(
            models_means=enroll_templates,
            ubm=self,
            test_stats=probe_templates,
            frame_length_normalization=True,
        )

    def read_biometric_reference(self, model_file):
        """Reads an enrolled reference model, which is a MAP GMMMachine."""
        # Close the file deterministically instead of leaving the handle to the
        # garbage collector. NOTE(review): this assumes ``from_hdf5`` copies the
        # datasets into memory rather than keeping lazy h5py references.
        with HDF5File(model_file, "r") as hdf5:
            return GMMMachine.from_hdf5(hdf5, ubm=self)

    def write_biometric_reference(self, model: GMMMachine, model_file):
        """Write the enrolled reference (MAP GMMMachine) into a file."""
        return model.save(model_file)

    def fit(self, X, y=None, **kwargs):
        """Trains the UBM and returns ``self``.

        Dask arrays are persisted first; a list (or SampleBatch) of 2D arrays is
        stacked into a single 2D array. ``y`` and ``kwargs`` are ignored.
        """
        # Stack all the samples in a 2D array of features
        if isinstance(X, da.Array):
            X = X.persist()

        # if input is a list (or SampleBatch) of 2 dimensional arrays, stack them
        X = check_data_dim(X, expected_ndim=2)

        logger.debug(
            "Training UBM machine with %s gaussians and %s samples",
            self.n_gaussians,
            len(X),
        )

        super().fit(X)

        return self

    def transform(self, X, **kwargs):
        """Returns ``X`` unchanged unless ``return_stats_in_transform`` is set.

        When ``return_stats_in_transform`` is true, the parent class'
        ``transform`` output (the GMM statistics) is returned instead.
        """
        # The idea would be to apply the projection in Transform (going from extracted
        # to GMMStats), but we must not apply this during the training or enrollment
        # (those require extracted data directly, not projected).
        # `project` is applied in the score function directly.
        if not self.return_stats_in_transform:
            return X
        return super().transform(X)

    @classmethod
    def custom_enrolled_save_fn(cls, data, path):
        """Saves an enrolled model (anything with a ``save`` method) to ``path``."""
        data.save(path)

    def custom_enrolled_load_fn(self, path):
        """Loads an enrolled MAP GMMMachine from ``path``.

        ``path`` may be a filesystem path (str, bytes or path-like), which is
        opened and closed here, or an already open HDF5 object, which is passed
        straight to ``GMMMachine.from_hdf5`` and left for the caller to close.
        """
        import os  # local import: only needed for the path-like check below

        if not isinstance(path, (str, bytes, os.PathLike)):
            return GMMMachine.from_hdf5(path, ubm=self)
        # NOTE(review): assumes ``from_hdf5`` reads the datasets eagerly, so the
        # file can be closed as soon as it returns.
        with HDF5File(path, "r") as hdf5:
            return GMMMachine.from_hdf5(hdf5, ubm=self)

    def _more_tags(self):
        return {
            "bob_fit_supports_dask_array": True,
            "bob_features_save_fn": GMMStats.save,
            "bob_features_load_fn": GMMStats.from_hdf5,
            "bob_enrolled_save_fn": self.custom_enrolled_save_fn,
            "bob_enrolled_load_fn": self.custom_enrolled_load_fn,
            "bob_checkpoint_features": self.return_stats_in_transform,
        }