Coverage for src/bob/bio/base/algorithm/gmm.py: 36%

72 statements  

coverage.py v7.6.0, created at 2024-07-12 22:34 +0200

1"""Interface between the lower level GMM classes and the Algorithm Transformer. 

2 

3Implements the enroll and score methods using the low level GMM implementation. 

4 

5This adds the notions of models, probes, enrollment, and scores to GMM. 

6""" 

7 

8 

9import copy 

10import logging 

11 

12from typing import Callable, Union 

13 

14import dask.array as da 

15import numpy as np 

16 

17from h5py import File as HDF5File 

18 

19from bob.bio.base.pipelines import BioAlgorithm 

20from bob.learn.em import GMMMachine, GMMStats, linear_scoring 

21 

22logger = logging.getLogger(__name__) 

23 

24 

def check_data_dim(data, expected_ndim):
    """Stacks the features into a matrix of shape (n_samples, n_features), or
    labels into shape (n_samples,), if the input data is not already stacked.

    Parameters
    ----------
    data : array-like
        features or labels
    expected_ndim : int
        expected number of dimensions of the data

    Returns
    -------
    stacked_data : array-like
        stacked features or labels if needed
    """
    if expected_ndim not in (1, 2):
        raise ValueError(
            f"expected_ndim must be 1 or 2 but got {expected_ndim}"
        )

    if expected_ndim == 1:
        stack_function = np.concatenate
    else:
        stack_function = np.vstack

    if data[0].ndim == expected_ndim:
        return stack_function(data)

    return data
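
# Editor's illustration (not part of the original module): check_data_dim
# stacks a list of per-sample 2D feature arrays into a single
# (n_samples, n_features) matrix and returns already-stacked input unchanged.
#
#   >>> feats = [np.ones((3, 4)), np.zeros((2, 4))]
#   >>> check_data_dim(feats, expected_ndim=2).shape
#   (5, 4)
#   >>> check_data_dim(np.ones((5, 4)), expected_ndim=2).shape
#   (5, 4)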


class GMM(GMMMachine, BioAlgorithm):
    """Algorithm for computing UBM and Gaussian Mixture Models of the features.

    Features must be normalized to zero mean and unit standard deviation.

    Models are MAP GMM machines trained from a UBM on the enrollment feature set.

    The UBM is an ML GMM machine trained on the training feature set.

    Probes are GMM statistics of features projected on the UBM.
    """

    def __init__(
        self,
        # parameters for the GMM
        n_gaussians: int,
        # parameters of UBM training
        k_means_trainer=None,
        max_fitting_steps: int = 25,  # Maximum number of iterations for GMM training
        convergence_threshold: float = 5e-4,  # Threshold to end the ML training
        mean_var_update_threshold: float = 5e-4,  # Minimum value that a variance can reach
        update_means: bool = True,
        update_variances: bool = True,
        update_weights: bool = True,
        # parameters of the GMM enrollment (MAP)
        enroll_iterations: int = 1,
        enroll_update_means: bool = True,
        enroll_update_variances: bool = False,
        enroll_update_weights: bool = False,
        enroll_relevance_factor: Union[float, None] = 4,
        enroll_alpha: float = 0.5,
        # scoring
        scoring_function: Callable = linear_scoring,
        # RNG
        random_state: int = 5489,
        # other
        return_stats_in_transform: bool = False,
        **kwargs,
    ):
        """Initializes the local UBM-GMM tool chain.

        Parameters
        ----------
        n_gaussians
            The number of Gaussians used in the UBM and the models.
        k_means_trainer
            The k-means machine used to train and initialize the UBM.
        max_fitting_steps
            Number of e-m iterations for training the UBM.
        convergence_threshold
            Convergence threshold to halt the GMM training early.
        mean_var_update_threshold
            Minimum value a variance of the Gaussians can reach.
        update_weights
            Decides whether the weights of the Gaussians are updated while training.
        update_means
            Decides whether the means of the Gaussians are updated while training.
        update_variances
            Decides whether the variances of the Gaussians are updated while training.
        enroll_iterations
            Number of iterations for the MAP GMM used for enrollment.
        enroll_update_weights
            Decides whether the weights of the Gaussians are updated while enrolling.
        enroll_update_means
            Decides whether the means of the Gaussians are updated while enrolling.
        enroll_update_variances
            Decides whether the variances of the Gaussians are updated while enrolling.
        enroll_relevance_factor
            For enrollment: MAP relevance factor as described in Reynolds' paper.
            If None, Reynolds adaptation is not applied.
        enroll_alpha
            For enrollment: MAP adaptation coefficient.
        random_state
            Seed for the random number generation.
        scoring_function
            Function returning a score from a model, a UBM, and a probe.
        """
        super().__init__(
            n_gaussians=n_gaussians,
            trainer="ml",
            max_fitting_steps=max_fitting_steps,
            convergence_threshold=convergence_threshold,
            update_means=update_means,
            update_variances=update_variances,
            update_weights=update_weights,
            mean_var_update_threshold=mean_var_update_threshold,
            k_means_trainer=k_means_trainer,
            random_state=random_state,
            **kwargs,
        )

        self.enroll_relevance_factor = enroll_relevance_factor
        self.enroll_alpha = enroll_alpha
        self.enroll_iterations = enroll_iterations
        self.enroll_update_means = enroll_update_means
        self.enroll_update_weights = enroll_update_weights
        self.enroll_update_variances = enroll_update_variances
        self.scoring_function = scoring_function
        self.return_stats_in_transform = return_stats_in_transform

    def save_model(self, ubm_file):
        """Saves the projector (UBM) to file."""
        super().save(ubm_file)

    def load_model(self, ubm_file):
        """Loads the projector (UBM) from a file."""
        super().load(ubm_file)

    def project(self, array):
        """Computes GMM statistics against a UBM, given a 2D array of feature vectors.

        This is applied to the probes before scoring.
        """
        array = check_data_dim(array, expected_ndim=2)
        logger.debug("Projecting %d feature vectors", array.shape[0])
        # Accumulates statistics
        gmm_stats = self.acc_stats(array)

        # Return the resulting statistics
        return gmm_stats

    def enroll(self, data):
        """Enrolls a GMM using MAP adaptation, given a reference's feature vectors.

        Returns a GMMMachine tuned from the UBM with MAP on the biometric reference data.
        """

        # if the input is a list (or SampleBatch) of 2-dimensional arrays, stack them
        data = check_data_dim(data, expected_ndim=2)

        # Use the array to train a GMM and return it
        logger.info("Enrolling with %d feature vectors", data.shape[0])

        gmm = GMMMachine(
            n_gaussians=self.n_gaussians,
            trainer="map",
            ubm=copy.deepcopy(self),
            convergence_threshold=self.convergence_threshold,
            max_fitting_steps=self.enroll_iterations,
            random_state=self.random_state,
            update_means=self.enroll_update_means,
            update_variances=self.enroll_update_variances,
            update_weights=self.enroll_update_weights,
            mean_var_update_threshold=self.mean_var_update_threshold,
            map_relevance_factor=self.enroll_relevance_factor,
            map_alpha=self.enroll_alpha,
        )
        gmm.fit(data)
        return gmm

    def create_templates(self, list_of_feature_sets, enroll):
        if enroll:
            return [
                self.enroll(feature_set) for feature_set in list_of_feature_sets
            ]
        else:
            return [
                self.project(feature_set)
                for feature_set in list_of_feature_sets
            ]

    def compare(self, enroll_templates, probe_templates):
        return self.scoring_function(
            models_means=enroll_templates,
            ubm=self,
            test_stats=probe_templates,
            frame_length_normalization=True,
        )
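
    # Editor's note (not in the original source): with the default
    # linear_scoring function, `enroll_templates` is a list of MAP-adapted
    # GMMMachine models and `probe_templates` is a list of GMMStats; the call
    # presumably returns an (n_enroll_templates, n_probe_templates) array of
    # scores.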

    def read_biometric_reference(self, model_file):
        """Reads an enrolled reference model, which is a MAP GMMMachine."""
        return GMMMachine.from_hdf5(HDF5File(model_file, "r"), ubm=self)

    def write_biometric_reference(self, model: GMMMachine, model_file):
        """Writes the enrolled reference (MAP GMMMachine) into a file."""
        return model.save(model_file)

    def fit(self, X, y=None, **kwargs):
        """Trains the UBM."""
        # Stack all the samples into a 2D array of features
        if isinstance(X, da.Array):
            X = X.persist()

        # if the input is a list (or SampleBatch) of 2-dimensional arrays, stack them
        X = check_data_dim(X, expected_ndim=2)

        logger.debug(
            f"Training UBM machine with {self.n_gaussians} Gaussians and {len(X)} samples"
        )

        super().fit(X)

        return self

    def transform(self, X, **kwargs):
        """Passthrough. Enroll applies a different transform than score."""
        # The idea would be to apply the projection in transform (going from extracted
        # features to GMMStats), but we must not apply it during training or enrollment
        # (those require the extracted data directly, not the projection).
        # `project` is applied in the score function directly.
        if not self.return_stats_in_transform:
            return X
        return super().transform(X)

    @classmethod
    def custom_enrolled_save_fn(cls, data, path):
        data.save(path)

    def custom_enrolled_load_fn(self, path):
        return GMMMachine.from_hdf5(path, ubm=self)

    def _more_tags(self):
        return {
            "bob_fit_supports_dask_array": True,
            "bob_features_save_fn": GMMStats.save,
            "bob_features_load_fn": GMMStats.from_hdf5,
            "bob_enrolled_save_fn": self.custom_enrolled_save_fn,
            "bob_enrolled_load_fn": self.custom_enrolled_load_fn,
            "bob_checkpoint_features": self.return_stats_in_transform,
        }
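

# Editor's usage sketch (not part of the original module). It uses random
# synthetic features and the class defaults above; the shapes and data are
# illustrative assumptions only.
if __name__ == "__main__":
    rng = np.random.default_rng(0)

    # Train a small UBM on synthetic, roughly zero-mean/unit-variance features.
    algorithm = GMM(n_gaussians=2, max_fitting_steps=5)
    train_features = rng.standard_normal((200, 20))
    algorithm.fit(train_features)

    # Enroll one reference (MAP adaptation) and project one probe (GMMStats).
    enroll_templates = algorithm.create_templates(
        [rng.standard_normal((50, 20))], enroll=True
    )
    probe_templates = algorithm.create_templates(
        [rng.standard_normal((30, 20))], enroll=False
    )

    # Score every enroll template against every probe template.
    scores = algorithm.compare(enroll_templates, probe_templates)
    print(scores)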