Coverage for src/bob/pad/base/script/pad_figure.py: 58%

213 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-06 21:56 +0100

1"""Runs error analysis on score sets, outputs metrics and plots""" 

2 

3import click 

4import numpy as np 

5 

6from tabulate import tabulate 

7 

8import bob.bio.base.script.figure as bio_figure 

9import bob.measure.script.figure as measure_figure 

10 

11from bob.measure import f_score, farfrr, precision_recall, roc_auc_score 

12from bob.measure.utils import get_fta_list 

13 

14from ..error_utils import apcer_bpcer, calc_threshold 

15 

16 

def _normalize_input_scores(input_score, input_name):
    """Sort the scores of one score file and strip NaNs, computing the FTA.

    Parameters
    ----------
    input_score : tuple
        ``(pos, negs)``: ``pos`` holds the bona-fide scores and ``negs`` is a
        dict mapping each PAI (presentation attack instrument) name to its
        attack scores.
    input_name : str
        Name of the score file; returned unchanged.

    Returns
    -------
    tuple
        ``(input_name, pos, negs, all_negs, fta)`` where ``pos`` and
        ``all_negs`` (every attack score pooled together) are sorted and
        passed through ``get_fta_list``, ``negs`` is a NEW dict of sorted,
        NaN-free arrays, and ``fta`` is the value ``get_fta_list`` reports.

    The caller's arrays and dict are not modified: sorting is done on copies.
    """
    pos, negs = input_score
    # np.sort returns a sorted copy; ndarray.sort() on an already contiguous
    # input would have sorted the caller's array in place.
    pos = np.sort(np.ascontiguousarray(pos))

    # Pool every attack score (NaNs included) before cleaning the per-PAI
    # arrays, so that get_fta_list below sees the NaNs.
    neg_arrays = [np.asarray(v) for v in negs.values()]
    if neg_arrays:
        all_negs = np.sort(np.concatenate(neg_arrays))
    else:
        all_negs = np.ascontiguousarray([])

    # FTA is calculated on pos and all_negs, so NaNs are only removed
    # explicitly from the per-PAI arrays. Sorting before masking keeps the
    # result sorted. A new dict is built instead of rewriting the caller's.
    cleaned_negs = {}
    for pai, scores in negs.items():
        scores = np.sort(np.ascontiguousarray(scores))
        cleaned_negs[pai] = scores[~np.isnan(scores)]

    neg_list, pos_list, fta_list = get_fta_list([(all_negs, pos)])
    all_negs, pos, fta = neg_list[0], pos_list[0], fta_list[0]
    return input_name, pos, cleaned_negs, all_negs, fta

32 

33 

class Metrics(bio_figure.Metrics):
    """Compute PAD metrics from score files and print them as a table.

    The decision threshold is chosen on the development scores (or taken from
    the user-provided thresholds) and then applied unchanged to the
    evaluation scores, when evaluation is enabled.
    """

    def __init__(self, ctx, scores, evaluation, func_load, names):
        """Forward everything to the base class.

        ``names`` (the metrics to report) may be a sequence or a single
        comma-separated string, which is split here.
        """
        if isinstance(names, str):
            names = names.split(",")
        super().__init__(ctx, scores, evaluation, func_load, names)

    def get_thres(self, criterion, pos, negs, all_negs, far_value):
        """Return the threshold selected by ``criterion`` on these scores.

        ``pos``, every array in ``negs`` and ``all_negs`` must already be
        sorted (as returned by ``_normalize_input_scores``), because
        ``calc_threshold`` is called with ``is_sorted=True``.
        """
        return calc_threshold(
            criterion,
            pos=pos,
            negs=negs.values(),
            all_negs=all_negs,
            far_value=far_value,
            is_sorted=True,
        )

    def _numbers(self, threshold, pos, negs, all_negs, fta):
        """Compute the numeric PAD metrics at ``threshold``.

        Parameters
        ----------
        threshold : float
            Decision threshold.
        pos : numpy.ndarray
            Bona-fide scores.
        negs : dict
            Maps each PAI name to its attack scores.
        all_negs : numpy.ndarray
            All attack scores pooled together.
        fta : float
            Failure-to-acquire rate; FAR and FRR are derived from FPR/FNR
            with it.

        Returns
        -------
        dict
            Keys, in this order: ``apcer_pais`` (dict PAI -> APCER),
            ``apcer_ap``, ``bpcer``, ``acer``, ``fta``, ``fpr``, ``fnr``,
            ``hter``, ``far``, ``frr``, ``fp``, ``nn``, ``fn``, ``np``,
            ``precision``, ``recall``, ``f1_score``, ``auc`` and
            ``auc-log-scale``. The order matters: ``MultiMetrics`` uses it
            for its structured-array columns.
        """
        pais = list(negs.keys())
        apcer_per_pai, apcer_ap, bpcer = apcer_bpcer(
            threshold, pos, *[negs[k] for k in pais]
        )
        apcer_pais = {pai: apcer_per_pai[i] for i, pai in enumerate(pais)}
        acer = (apcer_ap + bpcer) / 2.0
        fpr, fnr = farfrr(all_negs, pos, threshold)
        hter = (fpr + fnr) / 2.0
        far = fpr * (1 - fta)
        frr = fta + fnr * (1 - fta)

        # Absolute counts behind the FPR/FNR ratios. These locals must not
        # be named ``np``: that would shadow the numpy module in this method.
        n_neg = all_negs.shape[0]  # number of attacks
        n_fp = int(round(fpr * n_neg))  # number of false positives
        n_pos = pos.shape[0]  # number of bona-fide samples
        n_fn = int(round(fnr * n_pos))  # number of false negatives

        precision, recall = precision_recall(all_negs, pos, threshold)
        f1_score = f_score(all_negs, pos, threshold, 1)
        auc = roc_auc_score(all_negs, pos)
        auc_log = roc_auc_score(all_negs, pos, log_scale=True)

        metrics = dict(
            apcer_pais=apcer_pais,
            apcer_ap=apcer_ap,
            bpcer=bpcer,
            acer=acer,
            fta=fta,
            fpr=fpr,
            fnr=fnr,
            hter=hter,
            far=far,
            frr=frr,
            fp=n_fp,
            nn=n_neg,
            fn=n_fn,
            np=n_pos,
            precision=precision,
            recall=recall,
            f1_score=f1_score,
            auc=auc,
        )
        # "auc-log-scale" is not a valid identifier, so it cannot be a kwarg.
        metrics["auc-log-scale"] = auc_log
        return metrics

    def _strings(self, metrics):
        """Format numeric metrics for display.

        Returns a new dict with the same keys; ``metrics`` is not modified.
        Precision, recall, F1 and AUC values are printed as plain numbers.
        The counts ``np``/``nn``/``fp``/``fn`` are kept as-is. FPR/FNR get
        a ``(count/total)`` suffix when the counts are present. Everything
        else (including each APCER per PAI) is printed as a percentage.
        ``self._decimal`` sets the number of decimals.
        """
        n_dec = ".%df" % self._decimal

        def as_percent(value):
            return "%s%%" % format(100 * value, n_dec)

        formatted = {}
        for key, value in metrics.items():
            if key in ("precision", "recall", "f1_score", "auc", "auc-log-scale"):
                formatted[key] = format(value, n_dec)
            elif key in ("np", "nn", "fp", "fn"):
                formatted[key] = value
            elif key in ("fpr", "fnr"):
                if "fp" in metrics:
                    is_fpr = key == "fpr"
                    formatted[key] = "%s%% (%d/%d)" % (
                        format(100 * value, n_dec),
                        metrics["fp" if is_fpr else "fn"],
                        metrics["nn" if is_fpr else "np"],
                    )
                else:
                    formatted[key] = as_percent(value)
            elif key == "apcer_pais":
                formatted[key] = {
                    pai: as_percent(pai_value) for pai, pai_value in value.items()
                }
            else:
                formatted[key] = as_percent(value)
        return formatted

    def _get_all_metrics(self, idx, input_scores, input_names):
        """Compute all metrics for dev and eval scores.

        The threshold is computed on the development scores (or taken from
        ``self._thres[idx]`` when the user provided thresholds) and logged.
        The same threshold is then used for the evaluation scores.

        Returns
        -------
        list
            ``[dev_metrics, eval_metrics]`` as formatted by ``_strings``;
            ``eval_metrics`` is ``None`` when evaluation is disabled.
        """
        # Normalize into a new list instead of overwriting the caller's.
        normalized = [
            _normalize_input_scores(score, name)
            for score, name in zip(input_scores, input_names)
        ]

        dev_file, dev_pos, dev_negs, dev_all_negs, dev_fta = normalized[0]
        if self._eval:
            (
                eval_file,
                eval_pos,
                eval_negs,
                eval_all_negs,
                eval_fta,
            ) = normalized[1]

        if self._thres is None:
            threshold = self.get_thres(
                self._criterion, dev_pos, dev_negs, dev_all_negs, self._far
            )
        else:
            threshold = self._thres[idx]

        title = self._legends[idx] if self._legends is not None else None
        if self._thres is None:
            far_str = ""
            if self._criterion == "far" and self._far is not None:
                far_str = str(self._far)
            click.echo(
                "[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                % (
                    self._criterion.upper(),
                    far_str,
                    title or dev_file,
                    threshold,
                ),
                file=self.log_file,
            )
        else:
            # Prefer the legend over the file name, as in the branch above.
            click.echo(
                "[Min. criterion: user provided] Threshold on "
                "Development set `%s`: %e" % (title or dev_file, threshold),
                file=self.log_file,
            )

        res = [
            self._strings(
                self._numbers(threshold, dev_pos, dev_negs, dev_all_negs, dev_fta)
            )
        ]
        if self._eval:
            # Statistics on the eval set use the a-priori (dev) threshold.
            res.append(
                self._strings(
                    self._numbers(
                        threshold, eval_pos, eval_negs, eval_all_negs, eval_fta
                    )
                )
            )
        else:
            res.append(None)
        return res

    def compute(self, idx, input_scores, input_names):
        """Compute the requested metrics for system ``idx`` and print a table.

        One row is printed per name in ``self.names``; ``"apcer_pais"`` is
        expanded into one ``APCER (<pai>)`` row per PAI. An "Evaluation"
        column is added when evaluation is enabled.
        """
        title = self._legends[idx] if self._legends is not None else None
        all_metrics = self._get_all_metrics(idx, input_scores, input_names)
        # Use the legend as the corner header when there is one. The previous
        # ``" " or title`` always evaluated to " ".
        headers = [title or " ", "Development"]
        if self._eval:
            headers.append("Evaluation")

        rows = []
        for name in self.names:
            if name == "apcer_pais":
                for pai, value in all_metrics[0][name].items():
                    row = [f"APCER ({pai})", value]
                    if self._eval:
                        row.append(all_metrics[1][name][pai])
                    rows.append(row)
                continue
            row = [name.upper(), all_metrics[0][name]]
            if self._eval:
                row.append(all_metrics[1][name])
            rows.append(row)

        click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)

218 

219 

class MultiMetrics(Metrics):
    """Average PAD metrics over several protocols and print one row per system.

    For each system, ``compute`` receives the scores of several protocols.
    These are dev files only, or alternating dev/eval files when evaluation
    is enabled. It computes the metrics of every protocol and adds one table
    row holding ``"mean (std)"`` of each requested metric. ``end_process``
    prints the accumulated table.
    """

    def __init__(self, ctx, scores, evaluation, func_load, names):
        super().__init__(ctx, scores, evaluation, func_load, names=names)
        # Table rows, one per system, appended by ``compute``.
        self.rows = []
        # Table headers; built once the PAI names are known.
        self.headers = None
        # "APCER (<pai>)" column names, taken from the first metrics seen.
        self.pais = None

    def _compute_headers(self, pais):
        """Build ``self.headers`` from the requested metric names.

        Metric names are always upper-cased, matching the row labels of
        ``Metrics.compute``. ``"apcer_pais"`` is expanded into the given
        ``pais`` column names. When evaluation is enabled and HTER is
        requested, an extra "HTER (dev)" column follows "Methods".
        """
        columns = []
        for name in self.names:
            if name == "apcer_pais":
                columns.extend(pais)
            else:
                columns.append(name.upper())
        self.headers = ["Methods"] + columns
        if self._eval and "hter" in self.names:
            self.headers.insert(1, "HTER (dev)")

    def _strings(self, metrics):
        """Format every requested metric column as ``"mean (std)"``.

        ``metrics`` is the structured array built by ``_structured_array``.
        Mean and std are each formatted with ``Metrics._strings``.
        """
        keys = []
        for name in self.names:
            keys.extend(self.pais if name == "apcer_pais" else [name])

        formatted_metrics = {}
        for key in keys:
            column = metrics[key]
            mean = super()._strings({key: column.mean()})[key]
            std = super()._strings({key: column.std()})[key]
            formatted_metrics[key] = f"{mean} ({std})"
        return formatted_metrics

    def _structured_array(self, metrics):
        """Stack a list of ``_numbers`` dicts into a float structured array.

        Fields follow the key order of the first dict. ``apcer_pais`` is
        flattened into one ``APCER (<pai>)`` field per PAI; the first PAI
        names seen are remembered in ``self.pais``.
        """
        names = list(metrics[0].keys())
        if "apcer_pais" in names:
            idx = names.index("apcer_pais")
            pais = [f"APCER ({pai})" for pai in metrics[0]["apcer_pais"].keys()]
            names = names[:idx] + pais + names[idx + 1 :]
            self.pais = self.pais or pais
        dtype = dict(names=names, formats=[float] * len(names))

        records = []
        for each in metrics:
            record = []
            for key, value in each.items():
                if key == "apcer_pais":
                    record.extend(value.values())
                else:
                    record.append(value)
            records.append(tuple(record))
        return np.array(records, dtype=dtype)

    def _metric_cells(self, formatted):
        """Return formatted values in column order, expanding the PAI columns."""
        cells = []
        for name in self.names:
            if name == "apcer_pais":
                cells.extend(formatted[pai] for pai in self.pais)
            else:
                cells.append(formatted[name])
        return cells

    def compute(self, idx, input_scores, input_names):
        """Computes the average of metrics over several protocols."""
        # Normalize into a new list instead of overwriting the caller's.
        normalized = [
            _normalize_input_scores(score, name)
            for score, name in zip(input_scores, input_names)
        ]

        # With evaluation enabled the files alternate dev, eval, dev, eval...
        step = 2 if self._eval else 1
        self._dev_metrics = []
        self._thresholds = []
        for file_name, pos, negs, all_negs, fta in normalized[::step]:
            if self._thres is None:
                threshold = self.get_thres(
                    self._criterion, pos, negs, all_negs, self._far
                )
            else:
                threshold = self._thres[idx]
            self._thresholds.append(threshold)
            self._dev_metrics.append(
                self._numbers(threshold, pos, negs, all_negs, fta)
            )
        self._dev_metrics = self._structured_array(self._dev_metrics)

        if self._eval:
            # Each eval protocol reuses the threshold of its dev protocol.
            self._eval_metrics = []
            for threshold, (file_name, pos, negs, all_negs, fta) in zip(
                self._thresholds, normalized[1::step]
            ):
                self._eval_metrics.append(
                    self._numbers(threshold, pos, negs, all_negs, fta)
                )
            self._eval_metrics = self._structured_array(self._eval_metrics)

        # Without legends, the name of the last file processed is the title.
        title = self._legends[idx] if self._legends is not None else file_name

        dev_metrics = self._strings(self._dev_metrics)
        if self._eval:
            # Only the dev HTER (when requested) is shown; all the other
            # columns come from the evaluation protocols.
            row = [title, dev_metrics["hter"]] if "hter" in dev_metrics else [title]
            row += self._metric_cells(self._strings(self._eval_metrics))
        else:
            row = [title] + self._metric_cells(dev_metrics)
        self.rows.append(row)

        # compute header based on found PAI names
        if self.headers is None:
            self._compute_headers(self.pais)

    def end_process(self):
        """Print the accumulated table, then run the base-class end step."""
        click.echo(
            tabulate(self.rows, self.headers, self._tablefmt),
            file=self.log_file,
        )
        super().end_process()

356 

357 

class Roc(bio_figure.Roc):
    """ROC plot for PAD, with APCER / BPCER axis labels by default.

    Labels given through ``ctx.meta`` ("x_label", "y_label") take precedence.
    Otherwise the y axis reads "1-BPCER" when ``self._tpr`` (set by the base
    class, presumably) is truthy, and "BPCER" when it is not.
    """

    def __init__(self, ctx, scores, evaluation, func_load):
        super().__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._x_label = meta.get("x_label") or "APCER"
        if self._tpr:
            fallback_y_label = "1-BPCER"
        else:
            fallback_y_label = "BPCER"
        self._y_label = meta.get("y_label") or fallback_y_label

366 

367 

class Det(bio_figure.Det):
    """DET plot for PAD, with APCER (%) / BPCER (%) axis labels by default.

    Labels given through ``ctx.meta`` ("x_label", "y_label") take precedence.
    """

    def __init__(self, ctx, scores, evaluation, func_load):
        super().__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._x_label = meta.get("x_label") or "APCER (%)"
        self._y_label = meta.get("y_label") or "BPCER (%)"

373 

374 

class Hist(measure_figure.Hist):
    """Histograms for PAD"""

    def _setup_hist(self, neg, pos):
        """Draw the bona-fide and presentation-attack density histograms.

        Only the first entries, ``pos[0]`` and ``neg[0]``, are plotted. The
        attack histogram is drawn second, semi-transparent and hatched.
        """
        self._title_base = "PAD"

        self._density_hist(pos[0], n=0, label="Bona-fide", color="C1")

        attack_style = {"alpha": 0.4, "color": "C7", "hatch": "\\\\"}
        self._density_hist(
            neg[0], n=1, label="Presentation attack", **attack_style
        )