Coverage for src/bob/bio/base/pipelines/abstract_classes.py: 89%

145 statements  

« prev     ^ index     » next       coverage.py v7.6.5, created at 2024-11-14 21:41 +0100

1#!/usr/bin/env python 

2# vim: set fileencoding=utf-8 : 

3 

4 

5import logging 

6import os 

7 

8from abc import ABCMeta, abstractmethod 

9from typing import Any, Callable, Optional, Union 

10 

11import numpy as np 

12 

13from sklearn.base import BaseEstimator 

14 

15from bob.pipelines import Sample, SampleBatch, SampleSet 

16from bob.pipelines.wrappers import _frmt 

17 

18logger = logging.getLogger(__name__) 

19 

20 

def reduce_scores(
    scores: np.ndarray,
    axis: int,
    fn: Union[str, Callable[[np.ndarray, int], np.ndarray]] = "max",
):
    """Reduce scores along an axis using a reduction function.

    Parameters
    ----------
    scores
        Scores to reduce.

    axis
        Axis of ``scores`` along which the reduction is applied.

    fn
        Function to use for reduction. You can also provide a string like
        ``max`` to use the corresponding function from numpy. Some possible
        values are: ``max``, ``min``, ``mean``, ``median``, ``sum``.

    Returns
    -------
    Reduced scores.
    """
    # A string is resolved to the numpy function of the same name,
    # e.g. "max" -> np.max.
    if isinstance(fn, str):
        fn = getattr(np, fn)
    return fn(scores, axis=axis)

46 

47 

48def _data_valid(data: Any) -> bool: 

49 """Check if data is valid. 

50 

51 Parameters: 

52 ----------- 

53 data 

54 Data to check. 

55 

56 Returns: 

57 -------- 

58 True if data is valid, False otherwise. 

59 """ 

60 if data is None: 

61 return False 

62 if isinstance(data, np.ndarray): 

63 return data.size > 0 

64 # we also have to check for [[]] 

65 if isinstance(data, list) and len(data) > 0: 

66 if isinstance(data[0], (list, tuple)): 

67 return len(data[0]) > 0 

68 return bool(data) 

69 

70 

class BioAlgorithm(BaseEstimator, metaclass=ABCMeta):
    """Describes a base biometric comparator for the PipelineSimple
    :ref:`bob.bio.base.biometric_algorithm`.

    A biometric algorithm converts each SampleSet (which is a list of
    samples/features) into a single template. Template creation is done for both
    enroll and probe samples but the format of the templates can be different
    between enrollment and probe samples. After the creation of the templates,
    the algorithm computes one similarity score for comparison of an enroll
    template with a probe template.

    Examples
    --------
    >>> import numpy as np
    >>> from bob.bio.base.pipelines import BioAlgorithm
    >>> class MyAlgorithm(BioAlgorithm):
    ...
    ...     def create_templates(self, list_of_feature_sets, enroll):
    ...         # you cannot call np.mean(list_of_feature_sets, axis=1) because the
    ...         # number of features in each feature set may vary.
    ...         return [np.mean(feature_set, axis=0) for feature_set in list_of_feature_sets]
    ...
    ...     def compare(self, enroll_templates, probe_templates):
    ...         scores = []
    ...         for enroll_template in enroll_templates:
    ...             scores.append([])
    ...             for probe_template in probe_templates:
    ...                 similarity = 1 / np.linalg.norm(enroll_template - probe_template)
    ...                 scores[-1].append(similarity)
    ...         scores = np.array(scores, dtype=float)
    ...         return scores
    """

    def __init__(
        self,
        probes_score_fusion: Union[
            str, Callable[[list[np.ndarray], int], np.ndarray]
        ] = "max",
        enrolls_score_fusion: Union[
            str, Callable[[list[np.ndarray], int], np.ndarray]
        ] = "max",
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        probes_score_fusion
            Reduction used to fuse scores coming from several probe templates.
            Either the name of a numpy reduction (``max``, ``min``, ``mean``,
            ...) or a callable taking ``(scores, axis)``.
        enrolls_score_fusion
            Same as ``probes_score_fusion`` but for enroll templates.
        """
        super().__init__(**kwargs)
        self.probes_score_fusion = probes_score_fusion
        self.enrolls_score_fusion = enrolls_score_fusion

    def fuse_probe_scores(self, scores, axis):
        """Fuses scores along ``axis`` using ``self.probes_score_fusion``."""
        return reduce_scores(scores, axis, self.probes_score_fusion)

    def fuse_enroll_scores(self, scores, axis):
        """Fuses scores along ``axis`` using ``self.enrolls_score_fusion``."""
        return reduce_scores(scores, axis, self.enrolls_score_fusion)

    @abstractmethod
    def create_templates(
        self, list_of_feature_sets: list[Any], enroll: bool
    ) -> list[Sample]:
        """Creates enroll or probe templates from multiple sets of features.

        The enroll template format can be different from the probe templates.

        Parameters
        ----------
        list_of_feature_sets
            A list of list of features with the shape of Nx?xD. N templates
            should be computed. Note that you cannot call
            np.array(list_of_feature_sets) because the number of features per
            set can be different depending on the database.
        enroll
            If True, the features are for enrollment. If False, the features are
            for probe.

        Returns
        -------
        templates
            A list of templates which has the same length as
            ``list_of_feature_sets``.
        """
        pass

    @abstractmethod
    def compare(
        self, enroll_templates: list[Sample], probe_templates: list[Sample]
    ) -> np.ndarray:
        """Computes the similarity score between all enrollment and probe templates.

        Parameters
        ----------
        enroll_templates
            A list (length N) of enrollment templates.

        probe_templates
            A list (length M) of probe templates.

        Returns
        -------
        scores
            A matrix of shape (N, M) containing the similarity scores.
        """
        pass

    def create_templates_from_samplesets(
        self, list_of_samplesets: list[SampleSet], enroll: bool
    ) -> list[Sample]:
        """Creates enroll or probe templates from multiple SampleSets.

        Parameters
        ----------
        list_of_samplesets
            A list (length N) of SampleSets.

        enroll
            If True, the SampleSets are for enrollment. If False, the SampleSets
            are for probe.

        Returns
        -------
        templates
            A list of Samples which has the same length as ``list_of_samplesets``.
            Each Sample contains a template.
        """
        logger.debug(
            f"{_frmt(self)}.create_templates_from_samplesets(... enroll={enroll})"
        )
        # create templates from .data attribute of samples inside sample_sets
        list_of_feature_sets = []
        for sampleset in list_of_samplesets:
            data = [s.data for s in sampleset.samples]
            # Samples whose data is None (e.g. failed feature extraction) are
            # dropped before template creation.
            valid_data = [d for d in data if d is not None]
            if len(data) != len(valid_data):
                # Fix: the message now reflects whether these were enroll or
                # probe samples (it previously always said "enrollment").
                logger.warning(
                    f"Removed {len(data)-len(valid_data)} invalid "
                    f"{'enrollment' if enroll else 'probe'} samples."
                )
            if not valid_data and enroll:
                # we do not support failure to enroll cases currently
                raise NotImplementedError(
                    f"None of the enrollment samples were valid for {sampleset}."
                )
            list_of_feature_sets.append(valid_data)

        templates = self.create_templates(list_of_feature_sets, enroll)
        expected_size = len(list_of_samplesets)
        assert len(templates) == expected_size, (
            "The number of (%s) templates (%d) created by the algorithm does not match "
            "the number of sample sets (%d)"
            % (
                "enroll" if enroll else "probe",
                len(templates),
                expected_size,
            )
        )
        # return a list of Samples (one per template)
        templates = [
            Sample(t, parent=sampleset)
            for t, sampleset in zip(templates, list_of_samplesets)
        ]
        return templates

    def score_sample_templates(
        self,
        probe_samples: list[Sample],
        enroll_samples: list[Sample],
        score_all_vs_all: bool,
    ) -> list[SampleSet]:
        """Computes the similarity score between all probe and enroll templates.

        Parameters
        ----------
        probe_samples
            A list (length N) of Samples containing probe templates.

        enroll_samples
            A list (length M) of Samples containing enroll templates.

        score_all_vs_all
            If True, the similarity scores between all probe and enroll templates
            are computed. If False, the similarity scores between the probes and
            their associated enroll templates are computed.

        Returns
        -------
        score_samplesets
            A list of N SampleSets each containing a list of M score Samples if score_all_vs_all
            is True. Otherwise, a list of N SampleSets each containing a list of <=M score Samples
            depending on the database.
        """
        logger.debug(
            f"{_frmt(self)}.score_sample_templates(... score_all_vs_all={score_all_vs_all})"
        )
        # Returns a list of SampleSets where a Sampleset for each probe
        # SampleSet where each Sample inside the SampleSets contains the score
        # for one enroll SampleSet
        score_samplesets = []
        if score_all_vs_all:
            probe_data = [s.data for s in probe_samples]
            # Probes with invalid data are excluded from the comparison; NaN
            # scores are injected for them below.
            valid_probe_indices = [
                i for i, d in enumerate(probe_data) if _data_valid(d)
            ]
            valid_probe_data = [probe_data[i] for i in valid_probe_indices]
            scores = self.compare(SampleBatch(enroll_samples), valid_probe_data)
            scores = np.asarray(scores, dtype=float)

            if len(valid_probe_indices) != len(probe_data):
                # inject None scores for invalid probe samples
                scores: list = scores.T.tolist()
                for i in range(len(probe_data)):
                    if i not in valid_probe_indices:
                        scores.insert(i, [None] * len(enroll_samples))
                # transpose back to original shape; None entries become NaN
                # through the float conversion
                scores = np.array(scores, dtype=float).T

            expected_shape = (len(enroll_samples), len(probe_samples))
            assert scores.shape == expected_shape, (
                "The shape of the similarity scores (%s) does not match the expected shape (%s)"
                % (scores.shape, expected_shape)
            )
            for j, probe in enumerate(probe_samples):
                samples = []
                for i, enroll in enumerate(enroll_samples):
                    samples.append(Sample(scores[i, j], parent=enroll))
                score_samplesets.append(SampleSet(samples, parent=probe))
        else:
            for probe in probe_samples:
                references = [str(ref) for ref in probe.references]
                # get the indices of references for enroll samplesets
                indices = [
                    i
                    for i, enroll in enumerate(enroll_samples)
                    if str(enroll.template_id) in references
                ]
                if not indices:
                    raise ValueError(
                        f"No enroll sampleset found for probe {probe} and its required references {references}. "
                        "Did you mean to set score_all_vs_all=True?"
                    )
                if not _data_valid(probe.data):
                    # Invalid probe: emit NaN scores against every required
                    # reference instead of calling compare().
                    scores = [[None]] * len(indices)
                else:
                    scores = self.compare(
                        SampleBatch([enroll_samples[i] for i in indices]),
                        SampleBatch([probe]),
                    )
                scores = np.asarray(scores, dtype=float)
                expected_shape = (len(indices), 1)
                assert scores.shape == expected_shape, (
                    "The shape of the similarity scores (%s) does not match the expected shape (%s)"
                    % (scores.shape, expected_shape)
                )
                samples = []
                for i, j in enumerate(indices):
                    samples.append(
                        Sample(scores[i, 0], parent=enroll_samples[j])
                    )
                score_samplesets.append(SampleSet(samples, parent=probe))

        return score_samplesets

327 

328 

class Database(metaclass=ABCMeta):
    """Base class for PipelineSimple databases"""

    def __init__(
        self,
        protocol: Optional[str] = None,
        score_all_vs_all: bool = False,
        annotation_type: Optional[str] = None,
        fixed_positions: Optional[str] = None,
        memory_demanding: bool = False,
        **kwargs,
    ):
        """
        Parameters
        ----------
        protocol
            Name of the database protocol to use.
        score_all_vs_all
            Whether to allow scoring of all the probes against all the
            references, or to provide a list ``references`` with each probe to
            indicate against which references it needs to be compared.
        annotation_type
            The type of annotation passed to the annotation loading function.
        fixed_positions
            The constant eyes positions passed to the annotation loading function.
            TODO why keep this face-related name here? Which one is it, too (position
            when annotations are missing, or ending position in the result image)?
            --> move this when the FaceCrop annotator is correctly implemented.
        memory_demanding
            Flag to indicate that this should not be loaded locally.
            TODO Where is it used?
        """
        super().__init__(**kwargs)
        # A subclass may already expose ``protocol`` (e.g. as a property);
        # only set the plain attribute when it does not exist yet.
        if not hasattr(self, "protocol"):
            self.protocol = protocol
        self.score_all_vs_all = score_all_vs_all
        self.annotation_type = annotation_type
        self.fixed_positions = fixed_positions
        self.memory_demanding = memory_demanding

    def __str__(self) -> str:
        """Returns a readable representation listing the public attributes."""
        public_attrs = (
            f"{key}={value}"
            for key, value in self.__dict__.items()
            if not key.startswith("_")
        )
        return f"{self.__class__.__name__}({', '.join(public_attrs)})"

    @abstractmethod
    def background_model_samples(self) -> list[Sample]:
        """Returns samples for training a background model.

        Returns
        -------
        samples
            List of samples for background model training.
        """
        pass

    @abstractmethod
    def references(self, group: str = "dev") -> list[SampleSet]:
        """Returns the samples used to enroll biometric references.

        Parameters
        ----------
        group
            Limits samples to this group

        Returns
        -------
        references
            List of samples for the creation of biometric references.
        """
        pass

    @abstractmethod
    def probes(self, group: str = "dev") -> list[SampleSet]:
        """Returns the probes scored against enrolled biometric references.

        Parameters
        ----------
        group
            Limits samples to this group

        Returns
        -------
        probes
            List of samples for the creation of biometric probes.
        """
        pass

    @abstractmethod
    def all_samples(self, groups: Optional[str] = None) -> list[Sample]:
        """Returns every sample of the dataset.

        Parameters
        ----------
        groups
            List of groups to consider (like 'dev' or 'eval'). If `None`, will
            return samples from all the groups.

        Returns
        -------
        samples
            List of all the samples of the dataset.
        """
        pass

    @abstractmethod
    def groups(self) -> list[str]:
        """Returns all the possible groups for the current protocol."""
        pass

    @abstractmethod
    def protocols(self) -> list[str]:
        """Returns all the possible protocols of the database."""
        pass

    def template_ids(self, group: str) -> list[Any]:
        """Returns the ``template_id`` attribute of each reference."""
        ids = []
        for reference in self.references(group=group):
            ids.append(reference.template_id)
        return ids

460 

461 

class ScoreWriter(metaclass=ABCMeta):
    """
    Defines base methods to read, write scores and concatenate scores
    for :any:`bob.bio.base.pipelines.BioAlgorithm`
    """

    def __init__(self, path, extension=".txt", **kwargs):
        """
        Parameters
        ----------
        path
            Base directory where the score files are written.
        extension
            File extension of the written score files.
        """
        super().__init__(**kwargs)
        self.path = path
        self.extension = extension

    @abstractmethod
    def write(self, sampleset, path):
        """Writes the scores of one SampleSet to ``path``."""
        pass

    def post_process(self, score_paths, filename):
        """Concatenates the individual score files into ``filename``.

        If ``score_paths`` is a dask bag, the concatenation is returned as a
        delayed task; otherwise it is performed immediately.

        Parameters
        ----------
        score_paths
            Paths of the individual score files: a list of paths or a dask bag.
        filename
            Path of the concatenated output file.

        Returns
        -------
        ``filename``, possibly wrapped in a dask delayed.
        """

        def _post_process(score_paths, filename):
            # ``or "."`` keeps this working for a bare filename with no
            # directory component (makedirs("") would fail).
            os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)
            with open(filename, "w") as f:
                for path in score_paths:
                    with open(path) as f2:
                        f.writelines(f2.readlines())
            return filename

        # dask is imported lazily and optionally so that the plain-list code
        # path still works when dask is not installed. Previously a missing
        # dask raised ImportError even for regular lists of paths.
        try:
            import dask
            import dask.bag
        except ImportError:
            dask = None

        if dask is not None and isinstance(score_paths, dask.bag.Bag):
            all_paths = dask.delayed(list)(score_paths)
            return dask.delayed(_post_process)(all_paths, filename)
        return _post_process(score_paths, filename)