Coverage for src/bob/bio/base/script/vuln_figure.py: 92%

499 statements  

« prev     ^ index     » next       coverage.py v7.6.5, created at 2024-11-14 21:41 +0100

1"""Runs error analysis on score sets, outputs metrics and plots""" 

2 

3import logging 

4 

5import click 

6import matplotlib.pyplot as mpl 

7import numpy as np 

8 

9from tabulate import tabulate 

10 

11import bob.measure.script.figure as measure_figure 

12 

13from bob.measure import ( 

14 f_score, 

15 far_threshold, 

16 farfrr, 

17 frr_threshold, 

18 min_weighted_error_rate_threshold, 

19 plot, 

20 ppndf, 

21 precision_recall, 

22 roc_auc_score, 

23) 

24from bob.measure.utils import get_thres, remove_nan 

25 

26from . import error_utils 

27 

28logger = logging.getLogger(__name__) 

29 

30 

def clean_scores(input_scores):
    """Return a copy of ``input_scores`` with every score group cleaned.

    Parameters
    ----------
    input_scores: dict
        Maps a group name to its scores.

    Returns
    -------
    clean_scores: dict
        Same keys as ``input_scores``; each value is the first element of
        what ``remove_nan`` returns for that group.
    """
    return {
        group: remove_nan(values)[0] for group, values in input_scores.items()
    }

46 

47 

class Metrics(measure_figure.Metrics):
    """Compute metrics from score files

    Attributes
    ----------

    names: dict {str:str}
        pairs of metrics keys and corresponding row titles to display.
    """

    def __init__(
        self,
        ctx,
        scores,
        evaluation,
        func_load,
        names=None,
        **kwargs,
    ):
        """Forward everything to the base class, filling in default ``names``.

        Parameters
        ----------
        names: dict {str:str} or None
            Metric keys and the row titles to display. When ``None``, the
            default licit/IAPMR rows below are used.
        """
        # Built per call: a dict literal in the signature would be a single
        # object shared (and mutable) across every instance.
        if names is None:
            names = {
                "fta": "Licit Failure to Acquire",
                "fmr": "Licit False Match Rate",
                "fnmr": "Licit False Non Match Rate",
                "far": "Licit False Accept Rate",
                "frr": "Licit False Reject Rate",
                "hter": "Licit Half Total Error Rate",
                "iapmr": "Attack Presentation Match Rate",
            }
        super().__init__(ctx, scores, evaluation, func_load, names, **kwargs)

    @staticmethod
    def _clean_set(scores):
        """Strip NaNs from one score set and derive its failure-to-acquire rate.

        Returns
        -------
        tuple
            ``(neg, pos, spoof, fta)`` where ``fta`` is the sum of the second
            values returned by ``remove_nan`` divided by the sum of the third
            values (presumably NaN counts over total counts).
        """
        neg, neg_na, neg_count = remove_nan(scores["licit_neg"])
        pos, pos_na, pos_count = remove_nan(scores["licit_pos"])
        spoof, spoof_na, spoof_count = remove_nan(scores["spoof"])
        fta = (neg_na + pos_na + spoof_na) / (
            neg_count + pos_count + spoof_count
        )
        return neg, pos, spoof, fta

    def _get_all_metrics(self, idx, input_scores, input_names):
        """Compute all metrics for dev and eval scores

        Returns
        -------
        list
            Two entries: the formatted dev metrics, then the formatted eval
            metrics (``None`` when no evaluation set is used).
        """
        # Parse input and remove/count failed samples (NaN)
        dev_neg, dev_pos, dev_spoof, dev_fta = self._clean_set(input_scores[0])
        if self._eval:
            eval_neg, eval_pos, eval_spoof, eval_fta = self._clean_set(
                input_scores[1]
            )
        dev_file = input_names[0]

        # The threshold always comes from the dev set (or from the user)
        if self._thres is None:
            threshold = self.get_thres(
                self._criterion, dev_neg, dev_pos, self._far
            )
        else:
            threshold = self._thres[idx]

        title = self._legends[idx] if self._legends is not None else None
        if self._thres is None:
            far_str = ""
            if self._criterion == "far" and self._far is not None:
                far_str = str(self._far)
            click.echo(
                f"[Min. criterion: {self._criterion.upper()} {far_str}] "
                f"Threshold on Development set `{title or dev_file}`: {threshold:e}",
                file=self.log_file,
            )
        else:
            click.echo(
                "[Min. criterion: user provided] Threshold on "
                f"Development set `{dev_file or title}`: {threshold:e}",
                file=self.log_file,
            )

        res = [
            self._strings(
                self._numbers(dev_neg, dev_pos, dev_spoof, threshold, dev_fta)
            )
        ]

        if self._eval:
            # statistics for the eval set use the threshold computed a priori
            # on the dev set
            res.append(
                self._strings(
                    self._numbers(
                        eval_neg, eval_pos, eval_spoof, threshold, eval_fta
                    )
                )
            )
        else:
            res.append(None)

        return res

    def _numbers(self, neg, pos, spoof, threshold, fta):
        """Computes each metric value

        Returns
        -------
        dict
            Raw (unformatted) metric values keyed by metric name.
        """
        # fpr and fnr
        fmr, fnmr = farfrr(neg, pos, threshold)
        hter = (fmr + fnmr) / 2.0
        # FAR/FRR fold the failure-to-acquire rate into FMR/FNMR
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)

        ni = neg.shape[0]  # number of impostors
        fm = int(round(fmr * ni))  # number of false accepts
        nc = pos.shape[0]  # number of clients
        fnm = int(round(fnmr * nc))  # number of false rejects

        # precision and recall
        precision, recall = precision_recall(neg, pos, threshold)

        # f_score
        f1_score = f_score(neg, pos, threshold, 1)

        # AUC ROC
        auc = roc_auc_score(neg, pos)
        auc_log = roc_auc_score(neg, pos, log_scale=True)

        # IAPMR at threshold: the spoof scores take the place of negatives,
        # [0.0] is a placeholder for positives whose result is discarded
        iapmr, _ = farfrr(spoof, [0.0], threshold)
        spoof_total = len(spoof)
        spoof_match = int(round(iapmr * spoof_total))

        return {
            "fta": fta,
            "fmr": fmr,
            "fnmr": fnmr,
            "hter": hter,
            "far": far,
            "frr": frr,
            "fm": fm,
            "ni": ni,
            "fnm": fnm,
            "nc": nc,
            "precision": precision,
            "recall": recall,
            "f1_score": f1_score,
            "auc": auc,
            "auc_log": auc_log,
            "iapmr": iapmr,
            "spoof_match": spoof_match,
            "spoof_total": spoof_total,
        }

    def _strings(self, metrics):
        """Formats the metrics values into strings

        Rates are rendered as percentages with ``self._decimal`` decimals;
        FMR, FNMR and IAPMR also show their ``(count/total)`` breakdown.
        """
        decimals = self._decimal

        def percent(key):
            return f"{100 * metrics[key]:.{decimals}f}%"

        def plain(key):
            return f"{metrics[key]:.{decimals}f}"

        return {
            "fta": percent("fta"),
            "fmr": f"{percent('fmr')} ({metrics['fm']}/{metrics['ni']})",
            "fnmr": f"{percent('fnmr')} ({metrics['fnm']}/{metrics['nc']})",
            "far": percent("far"),
            "frr": percent("frr"),
            "hter": percent("hter"),
            "precision": plain("precision"),
            "recall": plain("recall"),
            "f1_score": plain("f1_score"),
            "auc": plain("auc"),
            "auc_log": plain("auc_log"),
            "iapmr": (
                f"{percent('iapmr')} "
                f"({metrics['spoof_match']}/{metrics['spoof_total']})"
            ),
        }

    def compute(self, idx, input_scores, input_names):
        """Compute metrics thresholds and tables for given system inputs

        Echoes one table to ``self.log_file``: one row per entry of
        ``self.names``, a "Development" column and, when an evaluation set is
        used, an "Evaluation" column.
        """
        title = self._legends[idx] if self._legends is not None else None

        # Tables rows
        all_metrics = self._get_all_metrics(idx, input_scores, input_names)
        # First header cell is the system legend when there is one; the
        # previous `" " or title` always evaluated to " " and hid the legend.
        headers = [title or " ", "Development"]

        rows = []
        for key, name in self.names.items():
            if key not in all_metrics[0]:
                logger.warning(f"{key} not present in metrics.")
            rows.append([name, all_metrics[0].get(key, "N/A")])

        if self._eval:
            # eval statistics are based on the threshold computed a priori
            headers.append("Evaluation")
            for row, key in zip(rows, self.names.keys()):
                row.append(all_metrics[1].get(key, "N/A"))

        click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)

247 

248 

def _iapmr_dot(threshold, iapmr, real_data, **kwargs):
    """Mark the operating point on the threshold-versus-IAPMR curve.

    A dot is drawn at ``(threshold, 100 * iapmr)``. With ``real_data`` the
    IAPMR value is written next to it; otherwise a generic arrow annotation
    is added instead.
    """
    # Left x-limit as it is before the dot is drawn
    x_min = mpl.axis()[0]
    iapmr_percent = 100.0 * iapmr
    mpl.plot(threshold, iapmr_percent, "o", color="C3", **kwargs)

    if real_data:
        mpl.text(
            threshold + (threshold - x_min) / 12,
            iapmr_percent,
            "%.1f%%" % (iapmr_percent,),
            color="C3",
        )
        return

    mpl.annotate(
        "IAPMR at\noperating point",
        xy=(threshold, iapmr_percent),
        xycoords="data",
        xytext=(0.85, 0.6),
        textcoords="axes fraction",
        color="black",
        size="large",
        arrowprops=dict(facecolor="black", shrink=0.05, width=2),
        horizontalalignment="center",
        verticalalignment="top",
    )

273 

274 

def _iapmr_line_plot(scores, n_points=100, **kwargs):
    """Draw the IAPMR (in %) as a function of the threshold.

    Thresholds are sampled along the current x-axis range, skipping the two
    first steps and the last one.
    """
    x_limits = mpl.axis()
    x_min, x_max = x_limits[0], x_limits[1]
    step = (x_max - x_min) / float(n_points)

    thres = [(k * step) + x_min for k in range(2, n_points - 1)]
    mix_prob_y = [
        100.0 * error_utils.calc_pass_rate(value, scores) for value in thres
    ]

    mpl.plot(thres, mix_prob_y, label="IAPMR", color="C3", **kwargs)

284 

285 

def _iapmr_plot(scores, threshold, iapmr, real_data, **kwargs):
    """Draw the operating-point dot, then the 100-point IAPMR curve."""
    _iapmr_dot(threshold, iapmr, real_data, **kwargs)
    _iapmr_line_plot(scores, n_points=100, **kwargs)

289 

290 

class HistVuln(measure_figure.Hist):
    """Histograms for vulnerability"""

    def __init__(self, ctx, scores, evaluation, func_load):
        # Three histograms per system: genuine, zero-effort impostors, attacks
        super().__init__(
            ctx, scores, evaluation, func_load, nhist_per_system=3
        )

    def _setup_hist(self, neg, pos):
        """Draw the three density histograms of one system.

        ``pos[0]`` holds the genuine scores, ``neg[0]`` the zero-effort
        impostor scores and ``neg[1]`` the presentation-attack scores.
        """
        self._title_base = " "
        histograms = (
            (pos[0], dict(label="Genuine", color="C2")),
            (
                neg[0],
                dict(label="Zero-effort impostors", alpha=0.8, color="C0"),
            ),
            (
                neg[1],
                dict(
                    label="Presentation attack",
                    alpha=0.4,
                    color="C7",
                    hatch="\\\\",
                ),
            ),
        )
        for position, (data, style) in enumerate(histograms):
            self._density_hist(data, n=position, **style)

    def _get_neg_pos_thres(self, idx, input_scores, input_names):
        """Get scores and threshold for the given system at index idx for vuln

        Returns
        -------
        dev_neg, dev_pos, eval_neg, eval_pos: list of arrays
            The scores negatives and positives for each set. The negatives
            hold two entries: licit [0] and spoof [1]; the positives hold the
            licit positives only.
        threshold: float
            The threshold computed on the `dev` set licit scores, or the
            user-provided one for this system.
        """
        dev = clean_scores(input_scores[0])
        if self._eval:
            evaluation = clean_scores(input_scores[1])
        else:
            evaluation = {"licit_neg": [], "licit_pos": [], "spoof": []}

        if self._thres is None:
            threshold = get_thres(
                self._criterion, dev["licit_neg"], dev["licit_pos"]
            )
        else:
            threshold = self._thres[idx]

        dev_neg = [dev["licit_neg"], dev["spoof"]]
        dev_pos = [dev["licit_pos"]]
        eval_neg = [evaluation["licit_neg"], evaluation["spoof"]]
        eval_pos = [evaluation["licit_pos"]]
        return dev_neg, dev_pos, eval_neg, eval_pos, threshold

    def _lines(self, threshold, label, neg, pos, idx, **kwargs):
        """Draw the threshold line and, unless disabled, the IAPMR curve."""
        spoof = neg[1]
        licit_neg = neg[0]
        licit_pos = pos[0]
        # plot the threshold vertical line (delegated to the base class)
        super()._lines(threshold, label, licit_neg, licit_pos, idx, **kwargs)

        # The IAPMR line is drawn by default; only an explicit falsy
        # "iapmr_line" entry in the context disables it.
        if not self._ctx.meta.get("iapmr_line", True):
            return

        # Plot iapmr_line (accepted PA vs threshold) on a secondary y-axis
        iapmr, _ = farfrr(spoof, [0.0], threshold)
        ax2 = mpl.twinx()
        # we never want grid lines on axis 2
        ax2.grid(False)
        _iapmr_plot(
            spoof,
            threshold,
            iapmr,
            real_data=self._ctx.meta.get("real_data", True),
        )

        n = idx % self._step_print
        col = n % self._ncols
        rest_print = (
            self.n_systems - int(idx / self._step_print) * self._step_print
        )
        # Only label the secondary axis on the last column or the last plot
        is_last_column = col == self._ncols - 1
        if is_last_column or n == rest_print - 1:
            ax2.set_ylabel("IAPMR (%)", color="C3")
        ax2.tick_params(axis="y", colors="C3")
        ax2.yaxis.label.set_color("C3")
        ax2.spines["right"].set_color("C3")

382 

383 

class Epc(measure_figure.PlotBase):
    """Handles the plotting of EPC"""

    def __init__(self, ctx, scores, evaluation, func_load):
        super().__init__(ctx, scores, evaluation, func_load)
        self._iapmr = self._ctx.meta.get("iapmr", True)
        if not self._titles:
            self._titles = ["EPC and IAPMR" if self._iapmr else "EPC"]
        self._x_label = self._x_label or "Weight $\\beta$"
        self._y_label = self._y_label or "HTER (%)"
        self._eval = True  # always eval data with EPC
        self._split = False
        self._nb_figs = 1

        if self._min_arg != 2:
            raise click.BadParameter(
                "You must provide 2 scores files: scores-{dev,eval}.csv"
            )

    def compute(self, idx, input_scores, input_names):
        """Plot EPC with IAPMR for vuln

        The licit EPC is drawn on the primary axis. When IAPMR plotting is
        enabled, the pass rate of the eval spoof scores is drawn on a twin
        y-axis, at thresholds chosen on the dev licit scores.
        """
        dev = clean_scores(input_scores[0])
        if self._eval:
            evaluation = clean_scores(input_scores[1])
        else:
            evaluation = {"licit_neg": [], "licit_pos": [], "spoof": []}

        mpl.gcf().clear()
        mpl.grid()
        logger.info(f"EPC using {input_names[0]} and {input_names[1]}")
        plot.epc(
            dev["licit_neg"],
            dev["licit_pos"],
            evaluation["licit_neg"],
            evaluation["licit_pos"],
            self._points,
            color="C0",
            linestyle=self._linestyles[idx],
            label=self._label("HTER (licit)", idx),
        )
        mpl.xlabel(self._x_label)
        mpl.ylabel(self._y_label)

        if not self._iapmr:
            return

        ax1 = mpl.gca()
        ax1.set_axisbelow(True)
        prob_ax = ax1.twinx()

        # Weights sampled evenly in [0, 1], with 1.0 appended explicitly
        step = 1.0 / float(self._points)
        thres = [float(k * step) for k in range(self._points)] + [1.0]
        apply_thres = [
            min_weighted_error_rate_threshold(
                dev["licit_neg"], dev["licit_pos"], weight
            )
            for weight in thres
        ]
        mix_prob_y = [
            100.0 * error_utils.calc_pass_rate(value, evaluation["spoof"])
            for value in apply_thres
        ]

        logger.info(
            f"IAPMR in EPC plot using {input_names[0]}, {input_names[1]}"
        )
        mpl.plot(
            thres,
            mix_prob_y,
            label=self._label("IAPMR (spoof)", idx),
            color="C3",
        )

        # Colour each y-axis like the curve it belongs to
        prob_ax.tick_params(axis="y", colors="C3")
        prob_ax.yaxis.label.set_color("C3")
        prob_ax.spines["right"].set_color("C3")
        prob_ax.set_ylabel("IAPMR (%)", color="C3")
        prob_ax.set_axisbelow(True)
        ax1.yaxis.label.set_color("C0")
        ax1.tick_params(axis="y", colors="C0")
        ax1.spines["left"].set_color("C0")
        mpl.sca(ax1)

465 

466 

class Epsc(measure_figure.GridSubplot):
    """Handles the plotting of EPSC"""

    def __init__(self, ctx, scores, func_load, criteria, var_param, **kwargs):
        super().__init__(
            ctx, scores, ctx.meta.get("evaluation", True), func_load
        )
        self._iapmr = self._ctx.meta.get("iapmr", False)
        self._wer = self._ctx.meta.get("wer", True)
        self._criteria = criteria or "eer"
        self._var_param = var_param or "omega"
        self._fixed_params = ctx.meta.get("fixed_params", [0.5])
        self._nb_subplots = 2 if (self._wer and self._iapmr) else 1
        # Repeat each title so that every subplot gets one
        if len(self._titles) < self._nb_figs * self._nb_subplots:
            self._titles = [
                title
                for title in self._titles
                for _ in range(self._nb_subplots)
            ]
        self._eval = True  # always eval data with EPSC
        self._split = False
        self._nb_figs = 1
        self._sampling = ctx.meta.get("sampling", 5)
        self._axis1 = None
        self._axis2 = None

        if self._min_arg != 2:
            raise click.BadParameter(
                "You must provide 2 scores files: scores-{dev,eval}.csv"
            )

        # One column per enabled plot type (IAPMR and/or WER)
        self._ncols = (1 if self._iapmr else 0) + (1 if self._wer else 0)

    def compute(self, idx, input_scores, input_names):
        """Plot EPSC for vuln

        For every fixed parameter value, thresholds are computed on the dev
        scores and error rates on the eval scores. The WER curve goes on the
        first axis and the IAPMR curve on the second one (which is the first
        axis when only one plot type is enabled).
        """
        dev = clean_scores(input_scores[0])
        if self._eval:
            evaluation = clean_scores(input_scores[1])
        else:
            evaluation = {"licit_neg": [], "licit_pos": [], "spoof": []}

        # With a single fixed value and several systems, all systems share
        # the plot and are told apart by colour.
        single_fixed = (
            self._fixed_params is None or len(self._fixed_params) == 1
        )
        merge_sys = single_fixed and self.n_systems > 1

        if self._legends is not None and idx < len(self._legends):
            legend = self._legends[idx]
        elif self.n_systems > 1:
            legend = "Sys%d" % (idx + 1)
        else:
            legend = ""

        if self._axis1 is None:
            # axes should only be created once
            self._axis1 = self.create_subplot(0)
            if self._ncols == 2:
                self._axis2 = self.create_subplot(1)
            else:
                self._axis2 = self._axis1

        points = 10
        vary_omega = self._var_param == "omega"
        for pi, fp in enumerate(self._fixed_params):
            if merge_sys:
                assert pi == 0
                pi = idx

            # The parameter that is not varied is pinned to ``fp``
            pinned = {"beta": fp} if vary_omega else {"omega": fp}
            omega, beta, thrs = error_utils.epsc_thresholds(
                dev["licit_neg"],
                dev["licit_pos"],
                dev["spoof"],
                dev["licit_pos"],
                points=points,
                criteria=self._criteria,
                **pinned,
            )

            # error rates are returned in a list in the
            # following order: frr, far, IAPMR, far_w, wer_w
            errors = error_utils.all_error_rates(
                evaluation["licit_neg"],
                evaluation["licit_pos"],
                evaluation["spoof"],
                evaluation["licit_pos"],
                thrs,
                omega,
                beta,
            )

            x_values = omega if vary_omega else beta
            if vary_omega:
                default_x_label = "Weight $\\omega$"
            else:
                default_x_label = "Weight $\\beta$"

            mpl.sca(self._axis1)
            # Between zero-effort impostors and Presentation attacks
            if self._wer:
                logger.debug(
                    f"Plotting EPSC: WER for system {idx+1}, fix param {pi}: {fp}"
                )
                set_title = None
                if self._titles:
                    set_title = self._titles[
                        (idx // self.n_systems) * self._nb_subplots
                    ]
                # A title made only of spaces suppresses the title
                if set_title is None or set_title.replace(" ", ""):
                    self._axis1.set_title(set_title or "EPSC")

                base = f"({legend}) " if legend.strip() else ""
                if vary_omega:
                    label = f"{base}$\\beta={fp:.1f}$"
                else:
                    label = f"{base}$\\omega={fp:.1f}$"
                self._axis1.plot(
                    x_values,
                    100.0 * errors[4].flatten(),
                    color=self._colors[pi],
                    linestyle="-",
                    label=label,
                )
                self._axis1.set_xlabel(self._x_label or default_x_label)
                self._axis1.set_ylabel(
                    self._y_label or "WER$_{\\omega,\\beta}$ (%)"
                )
                self._axis1.grid(True)
                self._axis1.legend(loc=self._legend_loc)

            if self._iapmr:
                logger.debug(
                    f"Plotting EPSC: IAPMR for system {idx+1}, fix param {pi}: {fp}"
                )
                mpl.sca(self._axis2)
                set_title = None
                if self._titles:
                    set_title = self._titles[
                        (idx // self.n_systems) * self._nb_subplots + 1
                    ]
                if set_title is None or set_title.replace(" ", ""):
                    self._axis2.set_title(set_title or "EPSC")

                base = f"({legend}) " if legend.strip() else ""
                if vary_omega:
                    label = f"{base} $\\beta={fp:.1f}$"
                else:
                    label = f"{base} $\\omega={fp:.1f}$"
                self._axis2.plot(
                    x_values,
                    100.0 * errors[2].flatten(),
                    color=self._colors[pi],
                    linestyle="-",
                    label=label,
                )
                self._axis2.set_xlabel(self._x_label or default_x_label)
                self._axis2.set_ylabel(self._y_label or "IAPMR (%)")
                self._axis2.grid(True)
                self._axis2.legend(loc=self._legend_loc)

    def end_process(self):
        """Sets the legend."""
        if self._end_setup_plot:
            for number in range(1, self._nb_figs + 1):
                fig = mpl.figure(number)
                if self._disp_legend:
                    mpl.legend(loc=self._legend_loc)
                self._pdf_page.savefig(fig)
        # The figures are saved above; keep the base class from redoing it
        self._end_setup_plot = False
        super().end_process()

660 

661 

class Epsc3D(Epsc):
    """3D EPSC plots for vuln"""

    def __init__(self, ctx, scores, func_load, criteria, var_param, **kwargs):
        """Same arguments as ``Epsc``.

        Raises
        ------
        ValueError
            If both WER and IAPMR are requested: a 3D plot shows one surface
            type only.
        """
        super().__init__(ctx, scores, func_load, criteria, var_param, **kwargs)
        if self._nb_subplots != 1:
            raise ValueError(
                "You cannot plot more than one type of plot (WER or IAPMR)."
            )

    def compute(self, idx, input_scores, input_names):
        """Draw the wireframe surface of one system.

        Thresholds are computed on the dev scores and error rates on the eval
        scores. The surface shows IAPMR when ``self._iapmr`` is set and WER
        otherwise.
        """
        dev_scores = clean_scores(input_scores[0])
        if self._eval:
            eval_scores = clean_scores(input_scores[1])
        else:
            eval_scores = {"licit_neg": [], "licit_pos": [], "spoof": []}

        default_title = "EPSC 3D"
        title = self._titles[idx] if self._titles else default_title

        # NOTE(review): drops an rcParams entry literally named "key" when
        # present; the purpose is not evident from this file -- confirm.
        mpl.rcParams.pop("key", None)

        points = self._sampling or 5

        # Compute threshold values on dev
        omega, beta, thrs = error_utils.epsc_thresholds(
            dev_scores["licit_neg"],
            dev_scores["licit_pos"],
            dev_scores["spoof"],
            dev_scores["licit_pos"],
            points=points,
            criteria=self._criteria,
        )

        # Compute errors on eval
        errors = error_utils.all_error_rates(
            eval_scores["licit_neg"],
            eval_scores["licit_pos"],
            eval_scores["spoof"],
            eval_scores["licit_pos"],
            thrs,
            omega,
            beta,
        )
        # error rates are returned in a list as 2D numpy.ndarrays in
        # the following order: frr, far, IAPMR, far_w, wer_wb, hter_wb
        show_iapmr = bool(self._iapmr)
        surface = 100 * errors[2 if show_iapmr else 4]

        if self._axis1 is None:
            self._axis1 = mpl.gcf().add_subplot(111, projection="3d")

        W, B = np.meshgrid(omega, beta)

        label = self._legends[idx] if self._legends else f"Sys {idx+1}"
        self._axis1.plot_wireframe(
            W,
            B,
            surface,
            color=self._colors[idx],
            antialiased=False,
            label=label,
        )

        if show_iapmr:
            self._axis1.azim = -30
            self._axis1.elev = 50

        self._axis1.set_xlabel(self._x_label or r"Weight $\omega$")
        self._axis1.set_ylabel(self._y_label or r"Weight $\beta$")
        # The label follows the same flag as the plotted data. It used to
        # follow ``self._wer``, which mislabelled the WER surface as IAPMR
        # when neither flag was set.
        self._axis1.set_zlabel(
            "IAPMR (%)" if show_iapmr else r"WER$_{\omega,\beta}$ (%)"
        )

        # A title made only of spaces suppresses the title
        if title.replace(" ", ""):
            mpl.title(title)

738 

739 

class BaseVulnDetRoc(measure_figure.PlotBase):
    """Base for DET and ROC

    Subclasses provide the actual drawing through ``_plot``, ``_get_farfrr``
    and ``_draw_fnmrs``; the versions defined here do nothing.
    """

    def __init__(self, ctx, scores, evaluation, func_load, real_data, no_spoof):
        super().__init__(ctx, scores, evaluation, func_load)
        self._no_spoof = no_spoof
        fnmrs_at = ctx.meta.get("fnmr", [])
        self._fnmrs_at = [] if fnmrs_at is None else fnmrs_at
        self._real_data = True if real_data is None else real_data
        self._min_dig = -4 if self._min_dig is None else self._min_dig
        self._tpr = ctx.meta.get("tpr", True)

    def compute(self, idx, input_scores, input_names):
        """Implements plots

        The dev curve is always drawn on figure 1. With an evaluation set,
        the eval curve follows, on figure 2 when ``self._split`` is set and
        dashed on figure 1 otherwise. When no FNMR values were requested, the
        operating point is the FNMR at the dev EER threshold.
        """
        dev_scores = clean_scores(input_scores[0])
        if self._eval:
            eval_scores = clean_scores(input_scores[1])

        mpl.figure(1)

        # The dev curve is drawn the same way with or without an eval set
        logger.info(f"dev curve using {input_names[0]}")
        self._plot(
            dev_scores["licit_neg"],
            dev_scores["licit_pos"],
            dev_scores["spoof"],
            npoints=self._points,
            tpr=self._tpr,
            min_far=self._min_dig,
            color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label("dev", idx),
            alpha=self._alpha,
        )
        dev_threshold = None
        if self._fnmrs_at:
            fnmrs_dev = self._fnmrs_at
        else:
            logger.info("Plotting fnmr line at dev eer threshold for dev")
            dev_threshold = get_thres(
                criter="eer",
                neg=dev_scores["licit_neg"],
                pos=dev_scores["licit_pos"],
            )
            # [0.0] is a placeholder for negatives; only the FNMR is kept
            _, fnmr_at_dev_threshold = farfrr(
                [0.0], dev_scores["licit_pos"], dev_threshold
            )
            fnmrs_dev = [fnmr_at_dev_threshold]
        self._draw_fnmrs(idx, dev_scores, fnmrs_dev)

        # Only dev scores available
        if not self._eval:
            return

        if self._split:
            mpl.figure(2)

        # Add the eval plot
        linestyle = self._linestyles[idx] if self._split else "--"
        logger.info(f"eval curve using {input_names[1]}")
        self._plot(
            eval_scores["licit_neg"],
            eval_scores["licit_pos"],
            eval_scores["spoof"],
            linestyle=linestyle,
            npoints=self._points,
            tpr=self._tpr,
            min_far=self._min_dig,
            color=self._colors[idx],
            label=self._label("eval", idx),
            alpha=self._alpha,
        )
        if self._fnmrs_at:
            fnmrs_eval = self._fnmrs_at
        else:
            logger.info("printing fnmr at dev eer threshold for eval")
            _, fnmr_at_dev_threshold = farfrr(
                [0.0], eval_scores["licit_pos"], dev_threshold
            )
            fnmrs_eval = [fnmr_at_dev_threshold]
        self._draw_fnmrs(idx, eval_scores, fnmrs_eval, True)

    def _get_farfrr(self, x, y, thres):
        """Hook: return the error rates and their plot coordinates."""
        return None, None

    def _plot(self, x, y, s, npoints, **kwargs):
        """Hook: draw the curve for negatives ``x``, positives ``y``, spoof ``s``."""
        pass

    def _draw_fnmrs(self, idx, scores, fnmrs=None, eval=False):
        """Hook: mark the operating points for the given FNMR values."""
        pass

850 

851 

class DetVuln(BaseVulnDetRoc):
    """DET for vuln"""

    # Axis limits used when none (or incomplete ones) are configured, in the
    # order [xmin, xmax, ymin, ymax] and in the unit ``plot.det_axis`` takes
    # (percentages, per the bob.measure documentation).
    _DEFAULT_AXLIM = (0.01, 99, 0.01, 99)

    def __init__(self, ctx, scores, evaluation, func_load, real_data, no_spoof):
        super().__init__(
            ctx, scores, evaluation, func_load, real_data, no_spoof
        )
        self._x_label = self._x_label or "FMR (%)"
        self._y_label = self._y_label or "FNMR (%)"
        self._semilogx = ctx.meta.get("semilogx", False)
        if not self._titles:
            self._titles = [""] * self._nb_figs
        add = "" if self._no_spoof else " and overlaid SPOOF scenario"
        # Fill in every missing title; odd entries are the eval figures
        for i, t in enumerate(self._titles):
            if not self._eval:
                dev_eval = ""
            elif i % 2:
                dev_eval = ", eval group"
            else:
                dev_eval = ", dev group"
            self._titles[i] = t or ("DET: LICIT" + add + dev_eval)
        self._legend_loc = self._legend_loc or "upper right"

    def _axis_limits(self):
        """Return the configured axis limits, or the defaults when unusable."""
        if self._axlim is not None and None not in self._axlim:
            return self._axlim
        return list(self._DEFAULT_AXLIM)

    def _set_axis(self):
        plot.det_axis(self._axis_limits())

    def _get_farfrr(self, x, y, thres):
        """Return ``farfrr``'s result and the same values mapped by ``ppndf``."""
        points = farfrr(x, y, thres)
        return points, [ppndf(i) for i in points]

    def _plot(self, x, y, s, npoints, **kwargs):
        """Draw the licit DET and, unless disabled, the spoof DET on a twin axis.

        Only ``color``, ``linestyle`` and ``label`` are read from ``kwargs``.
        """
        logger.info("Plotting DET")
        plot.det(
            x,
            y,
            npoints,
            min_far=self._min_dig,
            color=kwargs.get("color"),
            linestyle=kwargs.get("linestyle"),
            label=kwargs.get("label"),
        )
        if self._no_spoof or s is None:
            return

        ax1 = mpl.gca()
        ax2 = ax1.twiny()
        ax2.set_xlabel("IAPMR (%)", color="C3")
        ax2.tick_params(
            axis="x",
            colors="C3",
            labelrotation=self._x_rotation,
            labelcolor="C3",
        )
        # Prevent tick labels overlap
        ax2.tick_params(axis="both", which="major", labelsize="x-small")
        ax1.tick_params(axis="both", which="major", labelsize="x-small")
        ax2.spines["top"].set_color("C3")
        plot.det(
            s,
            y,
            npoints,
            min_far=self._min_dig,
            color="C3",
            linestyle=":",
            label="Spoof " + kwargs.get("label"),
        )
        self._set_axis()
        mpl.sca(ax1)

    def _draw_fnmrs(self, idx, scores, fnmrs=None, eval=False):
        """Mark the licit and spoof operating points for each FNMR value.

        Parameters
        ----------
        idx: int
            Index of the system, used to pick its colour.
        scores: dict
            Holds the "licit_neg", "licit_pos" and "spoof" scores.
        fnmrs: list or None
            FNMR values at which a horizontal line and the two points are
            drawn; nothing is drawn when empty or ``None``.
        eval: bool
            Unused here.
        """
        for line in fnmrs or ():
            thres_baseline = frr_threshold(
                scores["licit_neg"], scores["licit_pos"], line
            )

            axlim = mpl.axis()

            farfrr_licit, farfrr_licit_det = self._get_farfrr(
                scores["licit_neg"], scores["licit_pos"], thres_baseline
            )
            if farfrr_licit is None:
                return

            # Spoof operating point at the FNMR actually reached on licit
            farfrr_spoof, farfrr_spoof_det = self._get_farfrr(
                scores["spoof"],
                scores["licit_pos"],
                frr_threshold(
                    scores["spoof"], scores["licit_pos"], farfrr_licit[1]
                ),
            )

            # NOTE(review): matplotlib documents axhline's xmin/xmax as axes
            # fractions in [0, 1], yet data-space limits are passed here (and
            # the y-limits in the first branch) -- confirm the intent.
            if not self._real_data:
                # Takes specified FNMR value as EER
                mpl.axhline(
                    y=farfrr_licit_det[1],
                    xmin=axlim[2],
                    xmax=axlim[3],
                    color="k",
                    linestyle="--",
                    label="%s @ EER" % self._y_label,
                )
                label_licit = "%s @ operating point" % self._x_label
                label_spoof = "IAPMR @ operating point"
            else:
                mpl.axhline(
                    y=farfrr_licit_det[1],
                    xmin=axlim[0],
                    xmax=axlim[1],
                    color="k",
                    linestyle="--",
                    label="%s = %.2f%%" % ("FNMR", farfrr_licit[1] * 100),
                )
                label_licit = "FMR=%.2f%%" % (farfrr_licit[0] * 100)
                label_spoof = "IAPMR=%.2f%%" % (farfrr_spoof[0] * 100)

            # Annotations and drawing of the points
            text_x_offset = 2
            text_y_offset = 5
            # Licit
            mpl.annotate(
                xy=(farfrr_licit_det[0], farfrr_licit_det[1]),
                text=label_licit,
                xytext=(text_x_offset, text_y_offset),
                textcoords="offset points",
                fontsize="small",
            )
            mpl.plot(
                farfrr_licit_det[0],
                farfrr_licit_det[1],
                "o",
                color=self._colors[idx],
            )  # FAR point, licit scenario

            # Spoof: only drawn when it falls inside the plotted range. The
            # limits are percentages, so the comparison uses the IAPMR in
            # percent rather than its ppndf-transformed plot coordinate.
            x_min, x_max = self._axis_limits()[:2]
            iapmr_percent = 100.0 * farfrr_spoof[0]
            if x_min < iapmr_percent < x_max:
                mpl.annotate(
                    xy=(farfrr_spoof_det[0], farfrr_spoof_det[1]),
                    text=label_spoof,
                    xytext=(text_x_offset, text_y_offset),
                    textcoords="offset points",
                    fontsize="small",
                )
                mpl.plot(
                    farfrr_spoof_det[0],
                    farfrr_spoof_det[1],
                    "o",
                    color="C3",
                )  # FAR point, spoof scenario
            else:
                logger.warning(
                    f"The IAPMR for an FNMR of {line} is outside the plot."
                )

1013 

1014 

1015class RocVuln(BaseVulnDetRoc): 

1016 """ROC for vuln""" 

1017 

def __init__(self, ctx, scores, evaluation, func_load, real_data, no_spoof):
    """Set the ROC labels, per-figure titles and legend location defaults."""
    super(RocVuln, self).__init__(
        ctx, scores, evaluation, func_load, real_data, no_spoof
    )
    self._x_label = self._x_label or "FMR"
    self._y_label = self._y_label or "1 - FNMR"
    self._semilogx = ctx.meta.get("semilogx", True)
    if not self._titles:
        self._titles = [""] * self._nb_figs
    add = "" if self._no_spoof else " and overlaid SPOOF scenario"
    # Fill in every missing title; odd entries are the eval figures
    for i, t in enumerate(self._titles):
        if not self._eval:
            dev_eval = ""
        elif i % 2:
            dev_eval = ", eval group"
        else:
            dev_eval = ", dev group"
        self._titles[i] = t or ("ROC: LICIT" + add + dev_eval)
    if self._legend_loc == "best":
        if self._semilogx:
            self._legend_loc = "lower right"
        else:
            self._legend_loc = "upper right"

1042 

1043 def _plot(self, x, y, s, npoints, **kwargs): 

1044 logger.info("Plotting ROC") 

1045 plot.roc( 

1046 x, 

1047 y, 

1048 npoints=npoints, 

1049 semilogx=self._semilogx, 

1050 tpr=self._tpr, 

1051 min_far=self._min_dig, 

1052 color=kwargs.get("color"), 

1053 linestyle=kwargs.get("linestyle"), 

1054 label=kwargs.get("label"), 

1055 ) 

1056 if not self._no_spoof and s is not None: 

1057 ax1 = mpl.gca() 

1058 ax1.plot( 

1059 [0], 

1060 [0], 

1061 linestyle=":", 

1062 color="C3", 

1063 label="Spoof " + kwargs.get("label"), 

1064 ) 

1065 ax2 = ax1.twiny() 

1066 ax2.set_xlabel("IAPMR (%)", color="C3") 

1067 mpl.xticks(rotation=self._x_rotation) 

1068 ax2.tick_params( 

1069 axis="x", 

1070 colors="C3", 

1071 labelrotation=self._x_rotation, 

1072 labelcolor="C3", 

1073 ) 

1074 ax2.spines["top"].set_color("C3") 

1075 plot.roc( 

1076 s, 

1077 y, 

1078 npoints=npoints, 

1079 semilogx=self._semilogx, 

1080 tpr=self._tpr, 

1081 min_far=self._min_dig, 

1082 color="C3", 

1083 linestyle=":", 

1084 label="Spoof " + kwargs.get("label"), 

1085 ) 

1086 self._set_axis() 

1087 mpl.sca(ax1) 

1088 

1089 def _get_farfrr(self, x, y, thres): 

1090 points = farfrr(x, y, thres) 

1091 points2 = (points[0], 1 - points[1]) 

1092 return points, points2 

1093 

1094 def _draw_fnmrs(self, idx, scores, fnmrs=[], evaluation=False): 

1095 for line in fnmrs: 

1096 thres_baseline = frr_threshold( 

1097 scores["licit_neg"], scores["licit_pos"], line 

1098 ) 

1099 

1100 axlim = mpl.axis() 

1101 

1102 farfrr_licit, farfrr_licit_roc = self._get_farfrr( 

1103 scores["licit_neg"], scores["licit_pos"], thres_baseline 

1104 ) 

1105 if farfrr_licit is None: 

1106 return 

1107 

1108 farfrr_spoof, farfrr_spoof_roc = self._get_farfrr( 

1109 scores["spoof"], 

1110 scores["licit_pos"], 

1111 frr_threshold( 

1112 scores["spoof"], scores["licit_pos"], farfrr_licit[1] 

1113 ), 

1114 ) 

1115 

1116 if not self._real_data and not evaluation: 

1117 mpl.axhline( 

1118 y=farfrr_licit_roc[1], 

1119 xmin=axlim[2], 

1120 xmax=axlim[3], 

1121 color="k", 

1122 linestyle="--", 

1123 label=f"{self._y_label} @ EER", 

1124 ) 

1125 elif not evaluation: 

1126 mpl.axhline( 

1127 y=farfrr_licit_roc[1], 

1128 xmin=axlim[0], 

1129 xmax=axlim[1], 

1130 color="k", 

1131 linestyle="--", 

1132 label=f"FNMR = {farfrr_licit[1] * 100:.2f}%", 

1133 ) 

1134 

1135 if not self._real_data: 

1136 label_licit = f"{self._x_label} @ operating point" 

1137 label_spoof = "IAPMR @ operating point" 

1138 else: 

1139 label_licit = f"FMR={farfrr_licit[0] * 100:.2f}%" 

1140 label_spoof = f"IAPMR={farfrr_spoof[0] * 100:.2f}%" 

1141 

1142 mpl.plot( 

1143 farfrr_licit_roc[0], 

1144 farfrr_licit_roc[1], 

1145 "o", 

1146 color=self._colors[idx], 

1147 label=label_licit, 

1148 ) # FAR point, licit scenario 

1149 mpl.plot( 

1150 farfrr_spoof_roc[0], 

1151 farfrr_spoof_roc[1], 

1152 "o", 

1153 color="C3", 

1154 label=label_spoof, 

1155 ) # FAR point, spoof scenario 

1156 

1157 

1158class FmrIapmr(measure_figure.PlotBase): 

1159 """FMR vs IAPMR""" 

1160 

1161 def __init__(self, ctx, scores, evaluation, func_load): 

1162 super(FmrIapmr, self).__init__(ctx, scores, evaluation, func_load) 

1163 self._eval = True # Always ask for eval data 

1164 self._split = False 

1165 self._nb_figs = 1 

1166 self._semilogx = ctx.meta.get("semilogx", False) 

1167 if not self._titles: 

1168 self._titles = [""] * self._nb_figs 

1169 for i, t in enumerate(self._titles): 

1170 self._titles[i] = t or "FMR vs IAPMR" 

1171 self._x_label = self._x_label or "FMR" 

1172 self._y_label = self._y_label or "IAPMR" 

1173 if self._min_arg != 2: 

1174 raise click.BadParameter( 

1175 "You must provide 2 scores files: " "scores-{dev,eval}.csv" 

1176 ) 

1177 

1178 def compute(self, idx, input_scores, input_names): 

1179 """Implements plots""" 

1180 dev_scores = clean_scores(input_scores[0]) 

1181 if self._eval: 

1182 eval_scores = clean_scores(input_scores[1]) 

1183 fmr_list = np.linspace(0, 1, 100) 

1184 iapmr_list = [] 

1185 for i, fmr in enumerate(fmr_list): 

1186 thr = far_threshold( 

1187 dev_scores["licit_neg"], dev_scores["licit_pos"], fmr 

1188 ) 

1189 iapmr_list.append(farfrr(eval_scores["spoof"], [0.0], thr)[0]) 

1190 # re-calculate fmr since threshold might give a different result 

1191 # for fmr. 

1192 fmr_list[i], _ = farfrr(eval_scores["licit_neg"], [0.0], thr) 

1193 label = ( 

1194 self._legends[idx] 

1195 if self._legends is not None 

1196 else f"system {idx+1}" 

1197 ) 

1198 logger.info(f"Plot FmrIapmr using: {input_names[1]}") 

1199 if self._semilogx: 

1200 mpl.semilogx(fmr_list, iapmr_list, label=label) 

1201 else: 

1202 mpl.plot(fmr_list, iapmr_list, label=label)