Coverage for src/bob/pad/base/script/cross.py: 0%

125 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-06 21:56 +0100

1"""Prints Cross-db metrics analysis 

2""" 

3import itertools 

4import json 

5import logging 

6import math 

7import os 

8 

9import click 

10import jinja2 

11import yaml 

12 

13from clapper.click import log_parameters, verbosity_option 

14from tabulate import tabulate 

15 

16from bob.bio.base.score.load import get_negatives_positives, load_score 

17from bob.measure import farfrr 

18from bob.measure.script import common_options 

19from bob.measure.utils import get_fta 

20 

21from ..error_utils import calc_threshold 

22from .pad_commands import CRITERIA 

23 

# Module-level logger, named after this module (``__name__``).
logger = logging.getLogger(name=__name__)

25 

26 

def bool_option(name, short_name, desc, dflt=False, **kwargs):
    """Create a click decorator adding a ``--name/--no-name`` boolean flag.

    Besides being passed to the command as a regular parameter, the parsed
    value is mirrored into ``ctx.meta`` under ``name`` with dashes replaced by
    underscores.

    Parameters
    ----------
    name : str
        Long name of the option; produces ``--<name>`` / ``--no-<name>``.
    short_name : str
        Short name of the option; produces ``-<short_name>`` /
        ``-n<short_name>``.
    desc : str
        Help text of the option.
    dflt : bool or None
        Default value of the flag.
    **kwargs
        Forwarded unchanged to ``click.option``.

    Returns
    -------
    ``callable``
        A decorator that attaches this option to a click command.
    """

    def _store_in_meta(ctx, param, value):
        # Publish the flag value so other parts of the command can read it.
        ctx.meta[name.replace("-", "_")] = value
        return value

    def custom_bool_option(func):
        option = click.option(
            f"-{short_name}/-n{short_name}",
            f"--{name}/--no-{name}",
            default=dflt,
            help=desc,
            show_default=True,
            callback=_store_in_meta,
            is_eager=True,
            **kwargs,
        )
        return option(func)

    return custom_bool_option

66 

67 

def _ordered_load(stream, Loader=yaml.Loader, object_pairs_hook=dict):
    """Load a YAML stream, building every mapping with ``object_pairs_hook``.

    Each YAML mapping is turned into its list of ``(key, value)`` pairs, in
    document order, and handed to ``object_pairs_hook``. With the default
    ``dict`` the result keeps insertion order (Python 3.7+). Pass
    :py:class:`collections.OrderedDict` to obtain ``OrderedDict`` instances.

    Adapted from:
    https://stackoverflow.com/questions/5121931/in-python-how-can-you-load-yaml-mappings-as-ordereddicts

    Parameters
    ----------
    stream : str or file-like
        The YAML document to parse.
    Loader : type
        PyYAML loader class to extend. NOTE(review): the default
        ``yaml.Loader`` is not safe for untrusted input; ``expand`` in this
        module passes ``yaml.SafeLoader``.
    object_pairs_hook : callable
        Called with the list of ``(key, value)`` pairs of every mapping; its
        return value replaces the mapping.

    Returns
    -------
    object
        The loaded document.
    """

    def _pairs_to_mapping(loader, node):
        # Expand YAML merge keys ("<<") before collecting the pairs.
        loader.flatten_mapping(node)
        pairs = loader.construct_pairs(node)
        return object_pairs_hook(pairs)

    # A throw-away subclass so the custom constructor does not leak into the
    # caller's Loader class.
    class OrderedLoader(Loader):
        pass

    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
    OrderedLoader.add_constructor(mapping_tag, _pairs_to_mapping)

    return yaml.load(stream, Loader=OrderedLoader)

87 

88 

def expand(data):
    r"""Generates configuration sets based on the YAML input contents

    For an introduction to the YAML mark-up, see
    https://en.wikipedia.org/wiki/YAML

    A configuration set holds settings for **all** variables in the input
    template that need replacing. For example, if your template mentions the
    variables ``name`` and ``version``, then each configuration set contains
    values for both ``name`` and ``version``.

    For example:

    .. code-block:: yaml

       name: [john, lisa]
       version: [v1, v2]


    yields the following configuration sets:

    .. code-block:: python

       [
         {'name': 'john', 'version': 'v1'},
         {'name': 'john', 'version': 'v2'},
         {'name': 'lisa', 'version': 'v1'},
         {'name': 'lisa', 'version': 'v2'},
       ]


    Each top-level key in the input maps either to a YAML array or to any
    other object. Arrays are iterated over, producing every possible
    combination of their elements. Any other value is considered unique and
    is repeated in every yielded configuration set. Example:

    .. code-block:: yaml

       name: [john, lisa]
       version: [v1, v2]
       text: >
          hello,
          world!

    yields the following configuration sets (a folded ``>`` scalar keeps a
    single trailing newline):

    .. code-block:: python

       [
         {'name': 'john', 'version': 'v1', 'text': 'hello, world!\n'},
         {'name': 'john', 'version': 'v2', 'text': 'hello, world!\n'},
         {'name': 'lisa', 'version': 'v1', 'text': 'hello, world!\n'},
         {'name': 'lisa', 'version': 'v2', 'text': 'hello, world!\n'},
       ]

    Keys starting with an ``_`` (underscore) are always treated as "unique"
    objects, even when their value is an array. Example:

    .. code-block:: yaml

       name: [john, lisa]
       version: [v1, v2]
       _unique: [i1, i2]

    yields the following configuration sets:

    .. code-block:: python

       [
         {'name': 'john', 'version': 'v1', '_unique': ['i1', 'i2']},
         {'name': 'john', 'version': 'v2', '_unique': ['i1', 'i2']},
         {'name': 'lisa', 'version': 'v1', '_unique': ['i1', 'i2']},
         {'name': 'lisa', 'version': 'v2', '_unique': ['i1', 'i2']},
       ]

    An empty YAML document is treated as an empty mapping and yields a single
    empty configuration set.


    Parameters:

      data (str): YAML data to be parsed (loaded with ``yaml.SafeLoader``)


    Yields:

      dict: A dictionary of key-value pairs for building the templates

    """

    loaded = _ordered_load(data, yaml.SafeLoader)
    if loaded is None:
        # An empty YAML document loads as ``None``; treat it as "no variables".
        loaded = {}

    # Split values to iterate over from values repeated in every set.
    iterables = {}
    unique = {}
    for key, value in loaded.items():
        # Non-string keys (e.g. integers) can never be "private".
        is_private = isinstance(key, str) and key.startswith("_")
        if isinstance(value, list) and not is_private:
            iterables[key] = value
        else:
            unique[key] = value

    # The key order is fixed, so compute it once outside the loop.
    keys = list(iterables)

    # Every combination of the iterable values, in key order.
    for values in itertools.product(*iterables.values()):
        retval = dict(unique)
        retval.update(zip(keys, values))
        yield retval

195 

196 

@click.command(
    epilog="""\b
Examples:
  $ bob pad cross 'results/{{ evaluation.database }}/{{ algorithm }}/{{ evaluation.protocol }}/scores/scores-{{ group }}' \\
    -td replaymobile \\
    -d replaymobile -p grandtest \\
    -d oulunpu -p Protocol_1 \\
    -a replaymobile_grandtest_frame-diff-svm \\
    -a replaymobile_grandtest_qm-svm-64 \\
    -a replaymobile_grandtest_lbp-svm-64 \\
    > replaymobile.rst &
"""
)
@click.argument("score_jinja_template")
@click.option(
    "-d",
    "--database",
    "databases",
    multiple=True,
    required=True,
    show_default=True,
    help="Names of the evaluation databases",
)
@click.option(
    "-p",
    "--protocol",
    "protocols",
    multiple=True,
    required=True,
    show_default=True,
    help="Names of the protocols of the evaluation databases",
)
@click.option(
    "-a",
    "--algorithm",
    "algorithms",
    multiple=True,
    required=True,
    show_default=True,
    help="Names of the algorithms",
)
@click.option(
    "-n",
    "--names",
    type=click.File("r"),
    help="Name of algorithms to show in the table. Provide a path "
    "to a json file maps algorithm names to names that you want to "
    "see in the table.",
)
@click.option(
    "-td",
    "--train-database",
    required=True,
    help="The database that was used to train the algorithms.",
)
@click.option(
    "-pn",
    "--pai-names",
    type=click.File("r"),
    help="Name of PAIs to compute the errors per PAI. Provide a path "
    "to a json file maps attack_type in scores to PAIs that you want to "
    "see in the table.",
)
@click.option(
    "-g",
    "--group",
    "groups",
    multiple=True,
    show_default=True,
    default=["train", "dev", "eval"],
)
@bool_option("sort", "s", "whether the table should be sorted.", True)
@common_options.criterion_option(lcriteria=CRITERIA, check=False)
@common_options.far_option()
@common_options.table_option()
@common_options.output_log_metric_option()
@common_options.decimal_option(dflt=2, short="-dec")
@verbosity_option(logger)
@click.pass_context
def cross(
    ctx,
    score_jinja_template,
    databases,
    protocols,
    algorithms,
    names,
    train_database,
    pai_names,
    groups,
    sort,
    decimal,
    verbose,
    **kwargs,
):
    """Cross-db analysis metrics"""
    # For every (evaluation database/protocol, algorithm, group) the score
    # file path is rendered from ``score_jinja_template``. A threshold is
    # computed on non-"eval" groups with the selected criterion. "eval" reuses
    # the matching "dev" threshold. The resulting error rates are then printed
    # as one table row per algorithm, to stdout or to the ``--log`` file.
    # NOTE(review): ``pai_names`` is accepted but currently unused.
    log_parameters(logger)

    # Validate the option combination up front, before loading any scores.
    if len(databases) != len(protocols):
        raise click.BadParameter(
            "got {} database(s) but {} protocol(s); every --database needs "
            "a matching --protocol.".format(len(databases), len(protocols)),
            param_hint="'-d' / '-p'",
        )
    if train_database not in databases:
        raise click.BadParameter(
            "the training database {!r} must also be given with "
            "--database.".format(train_database),
            param_hint="'-td' / '--train-database'",
        )
    if "eval" in groups and (
        "dev" not in groups or groups.index("dev") > groups.index("eval")
    ):
        # "eval" uses the threshold computed on "dev", so dev must come first.
        raise click.BadParameter(
            "the 'eval' group uses the threshold computed on 'dev', so 'dev' "
            "must be listed before 'eval'.",
            param_hint="'-g' / '--group'",
        )

    names = {} if names is None else json.load(names)

    env = jinja2.Environment(undefined=jinja2.StrictUndefined)
    # The template never changes; compile it once.
    template = env.from_string(score_jinja_template)

    data = {
        "evaluation": [
            {"database": db, "protocol": proto}
            for db, proto in zip(databases, protocols)
        ],
        "algorithm": algorithms,
        "group": groups,
    }

    # (database, protocol, algorithm, group) -> (hter, threshold, fta, far, frr)
    metrics = {}

    for variables in expand(yaml.dump(data, Dumper=yaml.SafeDumper)):
        logger.debug(variables)

        score_path = template.render(variables)
        logger.info(score_path)

        database = variables["evaluation"]["database"]
        protocol = variables["evaluation"]["protocol"]
        algorithm = variables["algorithm"]
        group = variables["group"]
        key = (database, protocol, algorithm, group)

        # if algorithm name does not have train_database name in it.
        if train_database not in algorithm and database != train_database:
            score_path = score_path.replace(
                algorithm, database + "_" + algorithm
            )
            logger.info("Score path changed to: %s", score_path)

        if not os.path.exists(score_path):
            logger.warning("Score file not found: %s", score_path)
            metrics[key] = (float("nan"),) * 5
            continue

        scores = load_score(score_path)
        neg, pos = get_negatives_positives(scores)
        (neg, pos), fta = get_fta((neg, pos))

        if group == "eval":
            # Groups iterate in the given order and "dev" precedes "eval"
            # (validated above), so the dev entry already exists here.
            threshold = metrics[(database, protocol, algorithm, "dev")][1]
        else:
            try:
                threshold = calc_threshold(
                    ctx.meta["criterion"],
                    pos,
                    [neg],
                    neg,
                    ctx.meta["far_value"],
                )
            except RuntimeError:
                logger.error("Something wrong with %s", score_path)
                raise

        far, frr = farfrr(neg, pos, threshold)
        hter = (far + frr) / 2

        metrics[key] = (hter, threshold, fta, far, frr)

    logger.debug("metrics: %s", metrics)

    # Column labels per database, derived from the requested groups so that
    # they always line up with the cells built below.
    column_labels = []
    for grp in groups:
        if grp == "eval":
            column_labels += ["APCER", "BPCER", "ACER"]
        else:
            column_labels.append("EER_" + grp[0])

    headers = ["Algorithms"]
    for db in databases:
        for index, label in enumerate(column_labels):
            # Only the first column of each database carries its name.
            headers.append((db if index == 0 else "") + "\n" + label)

    train_protocol = protocols[databases.index(train_database)]

    # sort the algorithms based on HTER test, EER dev, EER train
    if sort:

        def sort_key(alg):
            ranks = []
            for grp in ("eval", "dev", "train"):
                hter = metrics.get(
                    (train_database, train_protocol, alg, grp), (float("nan"),)
                )[0]
                # Missing results rank as 100% error (i.e. last).
                ranks.append(1 if math.isnan(hter) else hter)
            return tuple(ranks)

        algorithms = sorted(algorithms, key=sort_key)

    rows = []
    for algorithm in algorithms:
        # Strip the training db/protocol prefixes to get a short display name.
        name = algorithm.replace(train_database + "_", "")
        name = name.replace(train_protocol + "_", "")
        row = [names.get(name, name)]
        for database, protocol in zip(databases, protocols):
            cell = []
            for group in groups:
                hter, _threshold, _fta, far, frr = metrics[
                    (database, protocol, algorithm, group)
                ]
                if group == "eval":
                    cell += [far, frr, hter]
                else:
                    cell += [hter]
            row.extend(round(c * 100, decimal) for c in cell)
        rows.append(row)

    title = " Trained on {} ".format(train_database)
    title_line = "\n" + "=" * len(title) + "\n"

    # open log file for writing if any, and make sure it gets closed
    log_path = ctx.meta["log"]
    log_file = None if log_path is None else open(log_path, "w")
    ctx.meta["log"] = log_file
    try:
        click.echo(title_line + title + title_line, file=log_file)
        click.echo(
            tabulate(
                rows,
                headers,
                ctx.meta["tablefmt"],
                # Honour --decimal instead of a fixed single decimal.
                floatfmt=".{}f".format(decimal),
            ),
            file=log_file,
        )
    finally:
        if log_file is not None:
            log_file.close()