Coverage for src/bob/bio/base/pipelines/wrappers.py: 95%

115 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-12 22:34 +0200

1import functools 

2import logging 

3import os 

4 

5from typing import Any, Callable, Optional 

6 

7import dask 

8import h5py 

9import numpy as np 

10 

11from sklearn.pipeline import Pipeline 

12 

13import bob.pipelines 

14 

15from bob.bio.base.pipelines import PipelineSimple 

16from bob.pipelines import ( 

17 CheckpointWrapper, 

18 DelayedSample, 

19 Sample, 

20 is_instance_nested, 

21) 

22from bob.pipelines.wrappers import BaseWrapper, _frmt, get_bob_tags 

23 

24from .abstract_classes import BioAlgorithm 

25 

26logger = logging.getLogger(__name__) 

27 

28 

def default_save(data: np.ndarray, path: str):
    """Save ``data`` into an HDF5 file at ``path`` under the ``"data"`` key.

    Missing parent directories are created on the fly.
    """
    parent_dir = os.path.dirname(path)
    os.makedirs(parent_dir, exist_ok=True)
    with h5py.File(path, "w") as hdf5_file:
        hdf5_file["data"] = data

33 

34 

def default_load(path: str) -> np.ndarray:
    """Load and return the ``"data"`` dataset from the HDF5 file at ``path``."""
    with h5py.File(path, "r") as hdf5_file:
        return hdf5_file["data"][()]

38 

39 

def get_bio_alg_tags(estimator=None, force_tags=None):
    """Return bob tags extended with BioAlgorithm-specific defaults.

    Tag precedence, from lowest to highest: generic bob tags, the
    bio-algorithm defaults below, the estimator's own sklearn tags, and
    finally ``force_tags``.
    """
    bob_tags = get_bob_tags(estimator=estimator, force_tags=force_tags)
    bio_defaults = {
        "bob_enrolled_extension": ".h5",
        "bob_enrolled_save_fn": default_save,
        "bob_enrolled_load_fn": default_load,
    }
    merged = dict(bob_tags)
    merged.update(bio_defaults)
    if estimator is not None:
        merged.update(estimator._get_tags())
    merged.update(force_tags or {})
    return merged

50 

51 

class BioAlgorithmBaseWrapper(BioAlgorithm, BaseWrapper):
    """Base class for wrappers around a :any:`BioAlgorithm`.

    Forwards the template-creation and template-comparison API to the
    wrapped ``biometric_algorithm`` attribute (set by subclasses).
    """

    def create_templates(self, feature_sets, enroll):
        """Delegate template creation to the wrapped algorithm."""
        wrapped = self.biometric_algorithm
        return wrapped.create_templates(feature_sets, enroll)

    def compare(self, enroll_templates, probe_templates):
        """Delegate template comparison to the wrapped algorithm."""
        wrapped = self.biometric_algorithm
        return wrapped.compare(enroll_templates, probe_templates)

60 

61 

class BioAlgCheckpointWrapper(BioAlgorithmBaseWrapper):
    """Wrapper used to checkpoint enrolled and Scoring samples.

    Parameters
    ----------
    biometric_algorithm
        An implemented :any:`BioAlgorithm`

    base_dir
        Path to store biometric references and scores

    extension
        Default extension of the enrolled references files.
        If None, will use the ``bob_checkpoint_extension`` tag in the estimator, or
        default to ``.h5``.

    save_func
        Pointer to a customized function that saves an enrolled reference to the disk.
        If None, will use the ``bob_enrolled_save_fn`` tag in the estimator, or default
        to h5py.

    load_func
        Pointer to a customized function that loads an enrolled reference from disk.
        If None, will use the ``bob_enrolled_load_fn`` tag in the estimator, or default
        to h5py.

    group
        group of the data, used to save different group's checkpoints in different dirs.

    force
        If True, will recompute scores and biometric references no matter if a file
        exists

    hash_fn
        Pointer to a hash function. This hash function maps
        `sample.key` to a hash code and this hash code corresponds a relative directory
        where a single `sample` will be checkpointed.
        This is useful when is desirable file directories with less than a certain
        number of files.

    Examples
    --------

    >>> from bob.bio.base.algorithm import Distance
    >>> from bob.bio.base.pipelines import BioAlgCheckpointWrapper
    >>> biometric_algorithm = BioAlgCheckpointWrapper(Distance(), base_dir="./")
    >>> biometric_algorithm.create_templates(samples, enroll=True) # doctest: +SKIP

    """

    def __init__(
        self,
        biometric_algorithm: BioAlgorithm,
        base_dir: str,
        extension: Optional[str] = None,
        save_func: Optional[Callable[[str, Any], None]] = None,
        load_func: Optional[Callable[[str], Any]] = None,
        group: Optional[str] = None,
        force: bool = False,
        hash_fn: Optional[Callable[[str], str]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.base_dir = base_dir
        self.set_score_references_path(group)
        self.group = group
        self.biometric_algorithm = biometric_algorithm
        self.force = force
        self.hash_fn = hash_fn
        # Any I/O option left unset falls back to the wrapped estimator's tags.
        bob_tags = get_bio_alg_tags(self.biometric_algorithm)
        self.extension = extension or bob_tags["bob_enrolled_extension"]
        self.save_func = save_func or bob_tags["bob_enrolled_save_fn"]
        self.load_func = load_func or bob_tags["bob_enrolled_load_fn"]

    def set_score_references_path(self, group):
        """Set ``biometric_reference_dir``, namespaced by ``group`` when given."""
        if group is None:
            self.biometric_reference_dir = os.path.join(
                self.base_dir, "biometric_references"
            )
        else:
            self.biometric_reference_dir = os.path.join(
                self.base_dir, group, "biometric_references"
            )

    def write_biometric_reference(self, sample, path):
        """Persist an enrolled template (``sample.data``) to ``path``.

        Raises
        ------
        RuntimeError
            If the template data is ``None``.
        """
        data = sample.data
        if data is None:
            raise RuntimeError("Cannot checkpoint template of None")
        # Reuse the data already fetched above: accessing ``sample.data`` a
        # second time could trigger another (possibly expensive) delayed load.
        return self.save_func(data, path)

    def _enroll_sample_set(self, sampleset):
        """
        Enroll a sample set with checkpointing
        """

        # If sampleset has a key use it, otherwise use the first sample's key
        model_key = getattr(sampleset, "key", sampleset.samples[0].key)

        # Amending `models` directory: optional hashed subdirectory to keep
        # the number of files per directory small.
        hash_dir_name = (
            self.hash_fn(str(model_key)) if self.hash_fn is not None else ""
        )

        path = os.path.join(
            self.biometric_reference_dir,
            hash_dir_name,
            str(model_key) + self.extension,
        )

        if self.force or not os.path.exists(path):
            enrolled_sample = (
                self.biometric_algorithm.create_templates_from_samplesets(
                    [sampleset], enroll=True
                )[0]
            )

            # saving the new sample
            os.makedirs(os.path.dirname(path), exist_ok=True)
            self.write_biometric_reference(enrolled_sample, path)

        # This seems inefficient, but it's crucial for large datasets:
        # returning a DelayedSample backed by the checkpoint keeps the
        # template out of memory until it is actually needed.
        delayed_enrolled_sample = DelayedSample(
            functools.partial(self.load_func, path), parent=sampleset
        )

        return delayed_enrolled_sample

    def create_templates_from_samplesets(self, list_of_samplesets, enroll):
        """Create templates, checkpointing the enrollment ones.

        Probe templates (``enroll=False``) are delegated directly to the
        wrapped algorithm without checkpointing.
        """
        logger.debug(
            f"{_frmt(self, attr='biometric_algorithm')}.create_templates_from_samplesets(... enroll={enroll})"
        )
        if not enroll:
            return self.biometric_algorithm.create_templates_from_samplesets(
                list_of_samplesets, enroll
            )
        retval = []
        for sampleset in list_of_samplesets:
            # if it exists, load it!
            sample = self._enroll_sample_set(sampleset)
            retval.append(sample)
        return retval

204 

205 

class BioAlgDaskWrapper(BioAlgorithmBaseWrapper):
    """
    Wrap :any:`bob.bio.base.pipelines.BioAlgorithm` to work with DASK
    """

    def __init__(self, biometric_algorithm: BioAlgorithm, **kwargs):
        # Forward extra kwargs to the cooperative __init__ chain instead of
        # silently discarding them (consistent with BioAlgCheckpointWrapper).
        super().__init__(**kwargs)
        self.biometric_algorithm = biometric_algorithm

    def create_templates_from_samplesets(self, list_of_samplesets, enroll):
        """Create templates partition-wise on a dask bag of samplesets."""
        logger.debug(
            f"{_frmt(self, attr='biometric_algorithm')}.create_templates_from_samplesets(... enroll={enroll})"
        )
        templates = list_of_samplesets.map_partitions(
            self.biometric_algorithm.create_templates_from_samplesets,
            enroll=enroll,
        )
        return templates

    def score_sample_templates(
        self, probe_samples, enroll_samples, score_all_vs_all
    ):
        """Score probe vs. enroll templates, distributing over dask partitions.

        All enroll references are gathered into a single delayed list so every
        probe partition can be scored against the full set of references.
        """
        logger.debug(
            f"{_frmt(self, attr='biometric_algorithm')}.score_sample_templates(... score_all_vs_all={score_all_vs_all})"
        )
        # load the templates into memory because they could be delayed samples
        enroll_samples = enroll_samples.map_partitions(
            _delayed_samples_to_samples
        )
        probe_samples = probe_samples.map_partitions(
            _delayed_samples_to_samples
        )

        all_references = dask.delayed(list)(enroll_samples)
        scores = probe_samples.map_partitions(
            self.biometric_algorithm.score_sample_templates,
            all_references,
            score_all_vs_all=score_all_vs_all,
        )
        return scores

245 

246 

def _delayed_samples_to_samples(delayed_samples):
    """Materialize delayed samples into plain in-memory ``Sample`` objects."""

    def _materialize(delayed):
        return Sample(delayed.data, parent=delayed)

    return [_materialize(delayed) for delayed in delayed_samples]

249 

250 

def dask_bio_pipeline(
    pipeline: PipelineSimple,
    npartitions: Optional[int] = None,
    partition_size: Optional[int] = None,
):
    """
    Given a :any:`PipelineSimple`, wraps their :attr:`transformer` and
    :attr:`biometric_algorithm` to be executed with dask.

    Parameters
    ----------

    pipeline
        pipeline to be dasked

    npartitions
        Number of partitions for the initial `dask.bag`

    partition_size
        Size of the partition for the initial `dask.bag`
    """
    # `partition_size` takes precedence; otherwise fall back to `npartitions`.
    if partition_size is not None:
        bag_kwargs = {"partition_size": partition_size}
    else:
        bag_kwargs = {"npartitions": npartitions}

    pipeline.transformer = bob.pipelines.wrap(
        ["dask"], pipeline.transformer, **bag_kwargs
    )
    pipeline.biometric_algorithm = BioAlgDaskWrapper(
        pipeline.biometric_algorithm
    )

    # Keep the original writer reachable and route dask bags through it
    # partition by partition.
    pipeline.write_scores_on_dask = pipeline.write_scores

    def _write_scores(scores):
        return scores.map_partitions(pipeline.write_scores_on_dask)

    pipeline.write_scores = _write_scores

    if hasattr(pipeline, "post_processor"):
        # cannot use bob.pipelines.wrap here because the input is already a dask bag.
        pipeline.post_processor = bob.pipelines.DaskWrapper(
            pipeline.post_processor
        )

    return pipeline

298 

299 

def checkpoint_pipeline_simple(
    pipeline: PipelineSimple,
    base_dir: str,
    biometric_algorithm_dir: Optional[str] = None,
    hash_fn: Optional[Callable[[str], str]] = None,
    force: bool = False,
):
    """
    Given a :any:`PipelineSimple`, wraps their :attr:`transformer` and
    :attr:`biometric_algorithm` to be checkpointed.

    If an estimator of the pipeline is already checkpointed, it will not be
    wrapped again.

    Parameters
    ----------

    pipeline
        pipeline to be checkpointed

    base_dir
        Path to store transformed input data and possibly biometric references and scores

    biometric_algorithm_dir
        If set, it will checkpoint the biometric references and scores to this path.
        If not, `base_dir` will be used.
        This is useful when it's suitable to have the transformed data path, and biometric references and scores
        in different paths.

    hash_fn
        Pointer to a hash function. This hash function will map
        `sample.key` to a hash code and this hash code will be the
        relative directory where a single `sample` will be checkpointed.
        This is useful when it is desirable to have file directories with
        less than a certain number of files.
    force
        Overwrite existing checkpoint files.
    """

    bio_ref_scores_dir = (
        biometric_algorithm_dir
        if biometric_algorithm_dir is not None
        else base_dir
    )

    def _wrap_checkpoint(transformer, directory):
        # Checkpoint a single transformer, storing features and model state
        # under `directory`.
        return bob.pipelines.wrap(
            ["checkpoint"],
            transformer,
            features_dir=directory,
            model_path=directory,
            hash_fn=hash_fn,
            force=force,
        )

    if isinstance(pipeline.transformer, Pipeline):
        steps = pipeline.transformer.steps
        for index, (name, transformer) in enumerate(steps):
            if is_instance_nested(transformer, "estimator", CheckpointWrapper):
                continue  # already checkpointed; do not wrap twice
            steps[index] = (
                name,
                _wrap_checkpoint(transformer, os.path.join(base_dir, name)),
            )
    else:  # The pipeline.transformer is a lone transformer
        if not is_instance_nested(
            pipeline.transformer, "estimator", CheckpointWrapper
        ):
            pipeline.transformer = _wrap_checkpoint(
                pipeline.transformer, base_dir
            )

    pipeline.biometric_algorithm = BioAlgCheckpointWrapper(
        pipeline.biometric_algorithm,
        base_dir=bio_ref_scores_dir,
        hash_fn=hash_fn,
        force=force,
    )

    return pipeline

380 

381 

def is_biopipeline_checkpointed(pipeline: PipelineSimple) -> bool:
    """
    Check if :any:`PipelineSimple` is checkpointed

    Parameters
    ----------

    pipeline
        pipeline to check if checkpointed by a :any:`BioAlgCheckpointWrapper`.

    """
    # Checkpointing is signalled by the biometric_algorithm (possibly nested
    # inside other wrappers) being a BioAlgCheckpointWrapper.
    checkpointed = is_instance_nested(
        pipeline, "biometric_algorithm", BioAlgCheckpointWrapper
    )
    return checkpointed