Coverage for src/bob/fusion/base/script/fuse.py: 96%

157 statements  

coverage.py v7.6.5, created at 2024-11-14 22:15 +0100

1"""A script to help for score fusion experiments 

2""" 

from __future__ import absolute_import, division, print_function

import logging
import os
import sys

import click
import numpy as np

from clapper.click import ResourceOption, verbosity_option

from bob.bio.base import utils
from bob.bio.base.score import dump_score, load_score

from ..tools import (
    check_consistency,
    get_2negatives_1positive,
    get_gza_from_lines_list,
    get_score_lines,
    get_scores,
    remove_nan,
)

logger = logging.getLogger(__name__)


def write_info(
    scores,
    algorithm,
    groups,
    output_dir,
    model_file,
    skip_check,
    force,
    **kwargs,
):
    info = """
scores: %s
algorithm: %s
groups: %s
output_dir: %s
model_file: %s
skip_check: %s
force: %s
kwargs: %s
    """ % (
        scores,
        algorithm,
        groups,
        output_dir,
        model_file,
        skip_check,
        force,
        kwargs,
    )
    logger.debug(info)

    info_file = os.path.join(output_dir, "Experiment.info")
    with open(info_file, "w") as f:
        f.write("Command line:\n")
        f.write(str(sys.argv[1:]) + "\n\n")
        f.write("Configuration:\n\n")
        f.write(info)


def save_fused_scores(save_path, fused_scores, score_lines):
    score_lines["score"] = fused_scores
    gen, zei, atk, _, _, _ = get_2negatives_1positive(score_lines)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    dump_score(save_path, score_lines)
    dump_score(save_path + "-licit", np.append(gen, zei))
    dump_score(save_path + "-spoof", np.append(gen, atk))
    dump_score(save_path + "-real", np.append(gen, zei))
    dump_score(save_path + "-attack", atk)
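
# The suffixes written by save_fused_scores follow the common
# presentation-attack evaluation layout (inferred from the gen/zei/atk split
# above): "-licit" and "-real" pair genuine scores with zero-effort impostor
# scores, "-spoof" pairs genuine scores with attack scores, and "-attack"
# contains the attack scores alone.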


def routine_fusion(
    algorithm,
    model_file,
    scores_train_lines,
    scores_train,
    train_neg,
    train_pos,
    fused_train_file,
    scores_dev_lines=None,
    scores_dev=None,
    dev_neg=None,
    dev_pos=None,
    fused_dev_file=None,
    scores_eval_lines=None,
    scores_eval=None,
    fused_eval_file=None,
    force=False,
    min_file_size=1000,
    do_training=True,
):
    # load the model if it exists and training was not requested
    if os.path.exists(model_file) and not do_training:
        logger.info("Loading the algorithm from %s", model_file)
        algorithm = algorithm.load(model_file)

    # train the preprocessors
    if train_neg is not None and do_training:
        train_scores = np.vstack((train_neg, train_pos))
        neg_len = train_neg.shape[0]
        y = np.zeros((train_scores.shape[0],), dtype="bool")
        y[neg_len:] = True
        algorithm.train_preprocessors(train_scores, y)

    # preprocess the data
    if scores_train is not None:
        scores_train = algorithm.preprocess(scores_train)
        train_neg = algorithm.preprocess(train_neg)
        train_pos = algorithm.preprocess(train_pos)

    if scores_dev is not None:
        scores_dev = algorithm.preprocess(scores_dev)
        dev_neg = algorithm.preprocess(dev_neg)
        dev_pos = algorithm.preprocess(dev_pos)

    if scores_eval is not None:
        scores_eval = algorithm.preprocess(scores_eval)

    # train the classifier
    if train_neg is not None and do_training:
        if utils.check_file(model_file, force, min_file_size):
            logger.info("model '%s' already exists.", model_file)
            algorithm = algorithm.load(model_file)
        else:
            algorithm.train(train_neg, train_pos, dev_neg, dev_pos)
            algorithm.save(model_file)

    # fuse the scores (train)
    if scores_train is not None:
        if utils.check_file(fused_train_file, force, min_file_size):
            logger.info("score file '%s' already exists.", fused_train_file)
        else:
            fused_scores_train = algorithm.fuse(scores_train)
            save_fused_scores(
                fused_train_file, fused_scores_train, scores_train_lines
            )

    # fuse the scores (dev)
    if scores_dev is not None:
        if utils.check_file(fused_dev_file, force, min_file_size):
            logger.info("score file '%s' already exists.", fused_dev_file)
        else:
            fused_scores_dev = algorithm.fuse(scores_dev)
            save_fused_scores(
                fused_dev_file, fused_scores_dev, scores_dev_lines
            )

    # fuse the scores (eval)
    if scores_eval is not None:
        if utils.check_file(fused_eval_file, force, min_file_size):
            logger.info("score file '%s' already exists.", fused_eval_file)
        else:
            fused_scores_eval = algorithm.fuse(scores_eval)
            save_fused_scores(
                fused_eval_file, fused_scores_eval, scores_eval_lines
            )
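
# Illustrative sketch: a minimal object satisfying the duck-typed interface
# that routine_fusion() expects from ``algorithm`` (train_preprocessors,
# preprocess, train, fuse, save, load). Real algorithms are provided through
# the ``bob.fusion.algorithm`` entry-point group selected with --algorithm;
# the class below and its use of pickle are assumptions for demonstration
# only and are not used anywhere in this script.
class _ExampleMeanFusion:
    """Averages the scores of all systems; needs no preprocessing/training."""

    def train_preprocessors(self, scores, y):
        # nothing to fit for the plain mean rule
        pass

    def preprocess(self, scores):
        # identity preprocessing
        return scores

    def train(self, neg, pos, dev_neg=None, dev_pos=None):
        # the mean rule has no trainable parameters
        pass

    def fuse(self, scores):
        # assuming a (n_samples, n_systems) array, return one score per sample
        return scores.mean(axis=1)

    def save(self, model_file):
        import pickle

        with open(model_file, "wb") as f:
            pickle.dump(self, f)

    def load(self, model_file):
        import pickle

        with open(model_file, "rb") as f:
            return pickle.load(f)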


@click.command(
    epilog="""\b
Examples:
# normal score fusion using the mean algorithm:
$ bob fusion fuse -vvv sys1/scores-{world,dev,eval} sys2/scores-{world,dev,eval} -a mean
# the same thing, more compactly, using bash brace expansion:
$ bob fusion fuse -vvv {sys1,sys2}/scores-{world,dev,eval} -a mean
# using an already trained algorithm:
$ bob fusion fuse -vvv {sys1,sys2}/scores-{dev,eval} -g dev -g eval -a mean -m /path/saved_model.pkl
# train an algorithm using development set scores:
$ bob fusion fuse -vvv {sys1,sys2}/scores-{dev,dev,eval} -a mean
# run fusion without eval scores:
$ bob fusion fuse -vvv {sys1,sys2}/scores-{world,dev} -g train -g dev -a mean
# run fusion with bio and pad systems:
$ bob fusion fuse -vvv sys_bio/scores-{world,dev,eval} sys_pad/scores-{train,dev,eval} -a mean
"""
)
@click.argument("scores", nargs=-1, required=True, type=click.Path(exists=True))
@click.option(
    "--algorithm",
    "-a",
    required=True,
    cls=ResourceOption,
    entry_point_group="bob.fusion.algorithm",
    help="The fusion algorithm (:any:`bob.fusion.algorithm.Algorithm`).",
)
@click.option(
    "--groups",
    "-g",
    default=("train", "dev", "eval"),
    multiple=True,
    show_default=True,
    type=click.Choice(("train", "dev", "eval")),
    help="The groups of the scores. This should correspond to the "
    "score files that are provided. The order of the options is "
    "important and should follow (train, dev, eval). Repeat this "
    "option for multiple values.",
)
@click.option(
    "--output-dir",
    "-o",
    required=True,
    default="fusion_result",
    show_default=True,
    type=click.Path(writable=True),
    help="The directory where the fused scores and the model are saved.",
)
@click.option(
    "--model-file",
    "-m",
    help="The path to where the algorithm will be saved/loaded.",
)
@click.option(
    "--skip-check",
    is_flag=True,
    show_default=True,
    help="If given, skip checking the consistency between the score files.",
)
@click.option(
    "--force",
    "-f",
    is_flag=True,
    show_default=True,
    help="Whether to overwrite existing files.",
)
@verbosity_option(logger)
def fuse(
    scores,
    algorithm,
    groups,
    output_dir,
    model_file,
    skip_check,
    force,
    **kwargs,
):
    """Score fusion

    The script takes score files from several biometric recognition and
    presentation attack detection (PAD) systems and fuses them with the
    provided algorithm.

    The scores are divided into three sets: train, dev, and eval. Depending
    on which of these you provide, the script will skip parts of the
    execution. Provide train (and optionally dev) score files to train your
    algorithm.

    \b
    Raises
    ------
    click.BadArgumentUsage
        If the number of score files is not divisible by the number of groups.
    click.MissingParameter
        If the algorithm is not provided.
    """

    os.makedirs(output_dir, exist_ok=True)
    if not model_file:
        do_training = True
        model_file = os.path.join(output_dir, "Model.pkl")
    else:
        do_training = False
    fused_train_file = os.path.join(output_dir, "scores-train")
    fused_dev_file = os.path.join(output_dir, "scores-dev")
    fused_eval_file = os.path.join(output_dir, "scores-eval")

    if len(scores) % len(groups) != 0:
        raise click.BadArgumentUsage(
            "The number of scores must be a multiple of the number of groups."
        )

    if algorithm is None:
        raise click.MissingParameter(
            "algorithm must be provided.", param_type="option"
        )

    write_info(
        scores,
        algorithm,
        groups,
        output_dir,
        model_file,
        skip_check,
        force,
        **kwargs,
    )

    # do the actual fusion: split the score files into their groups. The
    # files are expected per system, one file per group, in the order
    # given by --groups.
    train_files, dev_files, eval_files = [], [], []
    for files, grp in zip(
        (train_files, dev_files, eval_files), ("train", "dev", "eval")
    ):
        try:
            idx = groups.index(grp)
            files.extend(scores[idx :: len(groups)])
        except ValueError:
            pass
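
    # Worked example of the splitting described above (illustrative values):
    # with groups=("train", "dev", "eval") and
    # scores=("sys1/scores-world", "sys1/scores-dev", "sys1/scores-eval",
    #         "sys2/scores-world", "sys2/scores-dev", "sys2/scores-eval"),
    # scores[0::3] collects the train files of both systems,
    # scores[1::3] the dev files, and scores[2::3] the eval files.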

306 click.echo("train_files: %s" % train_files) 

307 click.echo("dev_files: %s" % dev_files) 

308 click.echo("eval_files: %s" % eval_files) 

309 

310 # load the scores 

311 if train_files: 

312 score_lines_list_train = [load_score(path) for path in train_files] 

313 if dev_files: 

314 score_lines_list_dev = [load_score(path) for path in dev_files] 

315 if eval_files: 

316 score_lines_list_eval = [load_score(path) for path in eval_files] 

317 

318 # genuine, zero effort impostor, and attack list of 

319 # train, development and evaluation data. 

320 if train_files: 

321 _, gen_lt, zei_lt, atk_lt = get_gza_from_lines_list( 

322 score_lines_list_train 

323 ) 

324 if dev_files: 

325 _, gen_ld, zei_ld, atk_ld = get_gza_from_lines_list( 

326 score_lines_list_dev 

327 ) 

328 if eval_files: 

329 _, gen_le, zei_le, atk_le = get_gza_from_lines_list( 

330 score_lines_list_eval 

331 ) 

332 

333 # check if score lines are consistent 

334 if not skip_check: 

335 if train_files: 

336 logger.info("Checking the training files for consistency ...") 

337 check_consistency(gen_lt, zei_lt, atk_lt) 

338 if dev_files: 

339 logger.info("Checking the development files for consistency ...") 

340 check_consistency(gen_ld, zei_ld, atk_ld) 

341 if eval_files: 

342 logger.info("Checking the evaluation files for consistency ...") 

343 check_consistency(gen_le, zei_le, atk_le) 

344 

345 if train_files: 

346 scores_train = get_scores(gen_lt, zei_lt, atk_lt) 

347 scores_train_lines = get_score_lines( 

348 gen_lt[0:1], zei_lt[0:1], atk_lt[0:1] 

349 ) 

350 train_neg = get_scores(zei_lt, atk_lt) 

351 train_pos = get_scores(gen_lt) 

352 else: 

353 scores_train, scores_train_lines, train_neg, train_pos = ( 

354 None, 

355 None, 

356 None, 

357 None, 

358 ) 

359 

360 if dev_files: 

361 scores_dev = get_scores(gen_ld, zei_ld, atk_ld) 

362 scores_dev_lines = get_score_lines( 

363 gen_ld[0:1], zei_ld[0:1], atk_ld[0:1] 

364 ) 

365 dev_neg = get_scores(zei_ld, atk_ld) 

366 dev_pos = get_scores(gen_ld) 

367 else: 

368 scores_dev, scores_dev_lines, dev_neg, dev_pos = None, None, None, None 

369 

370 if eval_files: 

371 scores_eval = get_scores(gen_le, zei_le, atk_le) 

372 scores_eval_lines = get_score_lines( 

373 gen_le[0:1], zei_le[0:1], atk_le[0:1] 

374 ) 

375 else: 

376 scores_eval, scores_eval_lines = None, None 

377 

378 # check for nan values 

379 found_nan = False 

380 if train_files: 

381 found_nan, nan_train, scores_train = remove_nan(scores_train, found_nan) 

382 scores_train_lines = scores_train_lines[~nan_train] 

383 found_nan, _, train_neg = remove_nan(train_neg, found_nan) 

384 found_nan, _, train_pos = remove_nan(train_pos, found_nan) 

385 if dev_files: 

386 found_nan, nan_dev, scores_dev = remove_nan(scores_dev, found_nan) 

387 scores_dev_lines = scores_dev_lines[~nan_dev] 

388 found_nan, _, dev_neg = remove_nan(dev_neg, found_nan) 

389 found_nan, _, dev_pos = remove_nan(dev_pos, found_nan) 

390 if eval_files: 

391 found_nan, nan_eval, scores_eval = remove_nan(scores_eval, found_nan) 

392 scores_eval_lines = scores_eval_lines[~nan_eval] 

393 

394 if found_nan: 

395 logger.warning("Some nan values were removed.") 

396 

397 routine_fusion( 

398 algorithm, 

399 model_file, 

400 scores_train_lines, 

401 scores_train, 

402 train_neg, 

403 train_pos, 

404 fused_train_file, 

405 scores_dev_lines, 

406 scores_dev, 

407 dev_neg, 

408 dev_pos, 

409 fused_dev_file, 

410 scores_eval_lines, 

411 scores_eval, 

412 fused_eval_file, 

413 force, 

414 do_training=do_training, 

415 )
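
# Outputs written to --output-dir (as wired up above): Experiment.info with
# the command line and configuration, Model.pkl with the trained algorithm
# (unless --model-file points elsewhere), and, for each group provided, a
# fused score file scores-{train,dev,eval} plus its
# -licit/-spoof/-real/-attack variants from save_fused_scores().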