Coverage for src/bob/bio/base/pipelines/wrappers.py: 95%
115 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-12 22:34 +0200
1import functools
2import logging
3import os
5from typing import Any, Callable, Optional
7import dask
8import h5py
9import numpy as np
11from sklearn.pipeline import Pipeline
13import bob.pipelines
15from bob.bio.base.pipelines import PipelineSimple
16from bob.pipelines import (
17 CheckpointWrapper,
18 DelayedSample,
19 Sample,
20 is_instance_nested,
21)
22from bob.pipelines.wrappers import BaseWrapper, _frmt, get_bob_tags
24from .abstract_classes import BioAlgorithm
26logger = logging.getLogger(__name__)
def default_save(data: np.ndarray, path: str):
    """Save ``data`` to an HDF5 file at ``path``.

    Parent directories are created as needed. The array is stored under the
    ``"data"`` key so that :func:`default_load` can retrieve it.

    Parameters
    ----------
    data
        Array to checkpoint.
    path
        Destination file path (conventionally ending in ``.h5``).
    """
    dir_name = os.path.dirname(path)
    # A bare filename has an empty dirname; os.makedirs("") would raise
    # FileNotFoundError, so only create directories when one is present.
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)
    with h5py.File(path, "w") as f:
        f["data"] = data
def default_load(path: str) -> np.ndarray:
    """Read back the ``"data"`` dataset written by :func:`default_save`.

    Parameters
    ----------
    path
        HDF5 file to read.

    Returns
    -------
    The stored array, fully materialized in memory.
    """
    with h5py.File(path, "r") as hdf5_file:
        dataset = hdf5_file["data"]
        return dataset[()]
def get_bio_alg_tags(estimator=None, force_tags=None):
    """Collect bob tags plus the enrollment-checkpointing defaults.

    The resolution order (later wins) is: generic bob tags, the enrollment
    defaults below, the estimator's own tags, and finally ``force_tags``.

    Parameters
    ----------
    estimator
        Optional estimator whose ``_get_tags()`` output is merged in.
    force_tags
        Optional mapping whose entries override everything else.

    Returns
    -------
    dict
        The merged tag dictionary.
    """
    tags = dict(get_bob_tags(estimator=estimator, force_tags=force_tags))
    tags.update(
        {
            "bob_enrolled_extension": ".h5",
            "bob_enrolled_save_fn": default_save,
            "bob_enrolled_load_fn": default_load,
        }
    )
    if estimator is not None:
        tags.update(estimator._get_tags())
    tags.update(force_tags or {})
    return tags
class BioAlgorithmBaseWrapper(BioAlgorithm, BaseWrapper):
    """Common base for wrappers around a ``biometric_algorithm``.

    Subclasses set ``self.biometric_algorithm``; the template-creation and
    comparison entry points below simply forward to it unchanged.
    """

    def create_templates(self, feature_sets, enroll):
        # Pure delegation — the wrapped algorithm does the real work.
        wrapped = self.biometric_algorithm
        return wrapped.create_templates(feature_sets, enroll)

    def compare(self, enroll_templates, probe_templates):
        wrapped = self.biometric_algorithm
        return wrapped.compare(enroll_templates, probe_templates)
class BioAlgCheckpointWrapper(BioAlgorithmBaseWrapper):
    """Wrapper used to checkpoint enrolled and Scoring samples.

    Parameters
    ----------
    biometric_algorithm
        An implemented :any:`BioAlgorithm`

    base_dir
        Path to store biometric references and scores

    extension
        Default extension of the enrolled references files.
        If None, will use the ``bob_checkpoint_extension`` tag in the estimator, or
        default to ``.h5``.

    save_func
        Pointer to a customized function that saves an enrolled reference to the disk.
        If None, will use the ``bob_enrolled_save_fn`` tag in the estimator, or default
        to h5py.

    load_func
        Pointer to a customized function that loads an enrolled reference from disk.
        If None, will use the ``bob_enrolled_load_fn`` tag in the estimator, or default
        to h5py.

    group
        group of the data, used to save different group's checkpoints in different dirs.

    force
        If True, will recompute scores and biometric references no matter if a file
        exists

    hash_fn
        Pointer to a hash function. This hash function maps
        `sample.key` to a hash code and this hash code corresponds a relative directory
        where a single `sample` will be checkpointed.
        This is useful when it is desirable to have file directories with less than a
        certain number of files.

    Examples
    --------

    >>> from bob.bio.base.algorithm import Distance
    >>> from bob.bio.base.pipelines import BioAlgCheckpointWrapper
    >>> biometric_algorithm = BioAlgCheckpointWrapper(Distance(), base_dir="./")
    >>> biometric_algorithm.create_templates(samples, enroll=True) # doctest: +SKIP

    """

    def __init__(
        self,
        biometric_algorithm: BioAlgorithm,
        base_dir: str,
        extension: Optional[str] = None,
        save_func: Optional[Callable[[str, Any], None]] = None,
        load_func: Optional[Callable[[str], Any]] = None,
        group: Optional[str] = None,
        force: bool = False,
        hash_fn: Optional[Callable[[str], str]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.base_dir = base_dir
        self.set_score_references_path(group)
        self.group = group
        self.biometric_algorithm = biometric_algorithm
        self.force = force
        self.hash_fn = hash_fn
        # Explicit arguments win; otherwise fall back to the estimator's tags
        # (which themselves fall back to the h5py-based defaults).
        bob_tags = get_bio_alg_tags(self.biometric_algorithm)
        self.extension = extension or bob_tags["bob_enrolled_extension"]
        self.save_func = save_func or bob_tags["bob_enrolled_save_fn"]
        self.load_func = load_func or bob_tags["bob_enrolled_load_fn"]

    def set_score_references_path(self, group):
        """Set ``biometric_reference_dir``, namespaced by ``group`` when given."""
        if group is None:
            self.biometric_reference_dir = os.path.join(
                self.base_dir, "biometric_references"
            )
        else:
            self.biometric_reference_dir = os.path.join(
                self.base_dir, group, "biometric_references"
            )

    def write_biometric_reference(self, sample, path):
        """Persist an enrolled template to ``path`` using ``save_func``.

        Raises
        ------
        RuntimeError
            If the sample's data is None (nothing to checkpoint).
        """
        data = sample.data
        if data is None:
            raise RuntimeError("Cannot checkpoint template of None")
        # Reuse the local ``data``: re-reading ``sample.data`` could trigger a
        # second (potentially expensive) load on delayed samples.
        return self.save_func(data, path)

    def _enroll_sample_set(self, sampleset):
        """
        Enroll a sample set with checkpointing
        """

        # If sampleset has a key use it, otherwise use the first sample's key
        model_key = getattr(sampleset, "key", sampleset.samples[0].key)

        # Amending `models` directory
        hash_dir_name = (
            self.hash_fn(str(model_key)) if self.hash_fn is not None else ""
        )

        path = os.path.join(
            self.biometric_reference_dir,
            hash_dir_name,
            str(model_key) + self.extension,
        )

        # Only (re)compute when forced or when no checkpoint exists yet.
        if self.force or not os.path.exists(path):
            enrolled_sample = (
                self.biometric_algorithm.create_templates_from_samplesets(
                    [sampleset], enroll=True
                )[0]
            )

            # saving the new sample
            os.makedirs(os.path.dirname(path), exist_ok=True)
            self.write_biometric_reference(enrolled_sample, path)

        # Always return a delayed sample so the template is loaded lazily.
        # This seems inefficient, but it's crucial for large datasets
        delayed_enrolled_sample = DelayedSample(
            functools.partial(self.load_func, path), parent=sampleset
        )

        return delayed_enrolled_sample

    def create_templates_from_samplesets(self, list_of_samplesets, enroll):
        """Create templates, checkpointing only the enrollment ones.

        Probe templates (``enroll=False``) are delegated straight to the
        wrapped algorithm without checkpointing.
        """
        logger.debug(
            f"{_frmt(self, attr='biometric_algorithm')}.create_templates_from_samplesets(... enroll={enroll})"
        )
        if not enroll:
            return self.biometric_algorithm.create_templates_from_samplesets(
                list_of_samplesets, enroll
            )
        retval = []
        for sampleset in list_of_samplesets:
            # if it exists, load it!
            sample = self._enroll_sample_set(sampleset)
            retval.append(sample)
        return retval
class BioAlgDaskWrapper(BioAlgorithmBaseWrapper):
    """
    Wrap :any:`bob.bio.base.pipelines.BioAlgorithm` to work with DASK
    """

    def __init__(self, biometric_algorithm: BioAlgorithm, **kwargs):
        # Forward extra kwargs up the MRO; previously they were accepted but
        # silently dropped (inconsistent with BioAlgCheckpointWrapper).
        super().__init__(**kwargs)
        self.biometric_algorithm = biometric_algorithm

    def create_templates_from_samplesets(self, list_of_samplesets, enroll):
        """Create templates partition-wise on a dask bag of samplesets."""
        logger.debug(
            f"{_frmt(self, attr='biometric_algorithm')}.create_templates_from_samplesets(... enroll={enroll})"
        )
        templates = list_of_samplesets.map_partitions(
            self.biometric_algorithm.create_templates_from_samplesets,
            enroll=enroll,
        )
        return templates

    def score_sample_templates(
        self, probe_samples, enroll_samples, score_all_vs_all
    ):
        """Score probe vs. enroll templates, distributing probes with dask.

        All enroll templates are gathered into a single delayed list so every
        probe partition can be compared against the full set of references.
        """
        logger.debug(
            f"{_frmt(self, attr='biometric_algorithm')}.score_sample_templates(... score_all_vs_all={score_all_vs_all})"
        )
        # load the templates into memory because they could be delayed samples
        enroll_samples = enroll_samples.map_partitions(
            _delayed_samples_to_samples
        )
        probe_samples = probe_samples.map_partitions(
            _delayed_samples_to_samples
        )

        all_references = dask.delayed(list)(enroll_samples)
        scores = probe_samples.map_partitions(
            self.biometric_algorithm.score_sample_templates,
            all_references,
            score_all_vs_all=score_all_vs_all,
        )
        return scores
def _delayed_samples_to_samples(delayed_samples):
    """Materialize each (possibly delayed) sample into an in-memory Sample."""
    materialized = []
    for delayed in delayed_samples:
        materialized.append(Sample(delayed.data, parent=delayed))
    return materialized
def dask_bio_pipeline(
    pipeline: PipelineSimple,
    npartitions: Optional[int] = None,
    partition_size: Optional[int] = None,
):
    """
    Given a :any:`PipelineSimple`, wraps their :attr:`transformer` and
    :attr:`biometric_algorithm` to be executed with dask.

    Parameters
    ----------

    pipeline
        pipeline to be dasked

    npartitions
        Number of partitions for the initial `dask.bag`

    partition_size
        Size of the partition for the initial `dask.bag`
    """
    # partition_size takes precedence over npartitions when both are given.
    if partition_size is not None:
        dask_kwargs = {"partition_size": partition_size}
    else:
        dask_kwargs = {"npartitions": npartitions}

    pipeline.transformer = bob.pipelines.wrap(
        ["dask"], pipeline.transformer, **dask_kwargs
    )
    pipeline.biometric_algorithm = BioAlgDaskWrapper(
        pipeline.biometric_algorithm
    )

    # Scores now arrive as a dask bag: keep the original writer around and
    # apply it per partition.
    pipeline.write_scores_on_dask = pipeline.write_scores

    def _write_scores(scores):
        return scores.map_partitions(pipeline.write_scores_on_dask)

    pipeline.write_scores = _write_scores

    if hasattr(pipeline, "post_processor"):
        # cannot use bob.pipelines.wrap here because the input is already a dask bag.
        pipeline.post_processor = bob.pipelines.DaskWrapper(
            pipeline.post_processor
        )

    return pipeline
def checkpoint_pipeline_simple(
    pipeline: PipelineSimple,
    base_dir: str,
    biometric_algorithm_dir: Optional[str] = None,
    hash_fn: Optional[Callable[[str], str]] = None,
    force: bool = False,
):
    """
    Given a :any:`PipelineSimple`, wraps their :attr:`transformer` and
    :attr:`biometric_algorithm` to be checkpointed.

    If an estimator of the pipeline is already checkpointed, it will not be
    wrapped again.

    Parameters
    ----------

    pipeline
        pipeline to be checkpointed

    base_dir
        Path to store transformed input data and possibly biometric references and scores

    biometric_algorithm_dir
        If set, it will checkpoint the biometric references and scores to this path.
        If not, `base_dir` will be used.
        This is useful when it's suitable to have the transformed data path, and biometric references and scores
        in different paths.

    hash_fn
        Pointer to a hash function. This hash function will map
        `sample.key` to a hash code and this hash code will be the
        relative directory where a single `sample` will be checkpointed.
        This is useful when it is desirable to have file directories with more than
        a certain number of files.
    force
        Overwrite existing checkpoint files.
    """

    bio_ref_scores_dir = (
        biometric_algorithm_dir if biometric_algorithm_dir is not None else base_dir
    )

    def _wrap_checkpoint(estimator, directory):
        # Leave estimators alone if a CheckpointWrapper is already nested in them.
        if is_instance_nested(estimator, "estimator", CheckpointWrapper):
            return estimator
        return bob.pipelines.wrap(
            ["checkpoint"],
            estimator,
            features_dir=directory,
            model_path=directory,
            hash_fn=hash_fn,
            force=force,
        )

    if isinstance(pipeline.transformer, Pipeline):
        # Each named step checkpoints into its own sub-directory.
        steps = pipeline.transformer.steps
        for idx, (name, transformer) in enumerate(steps):
            wrapped = _wrap_checkpoint(transformer, os.path.join(base_dir, name))
            steps[idx] = (name, wrapped)
    else:  # The pipeline.transformer is a lone transformer
        pipeline.transformer = _wrap_checkpoint(pipeline.transformer, base_dir)

    pipeline.biometric_algorithm = BioAlgCheckpointWrapper(
        pipeline.biometric_algorithm,
        base_dir=bio_ref_scores_dir,
        hash_fn=hash_fn,
        force=force,
    )

    return pipeline
def is_biopipeline_checkpointed(pipeline: PipelineSimple) -> bool:
    """
    Check if :any:`PipelineSimple` is checkpointed

    Parameters
    ----------

    pipeline
        pipeline to check if checkpointed by a :any:`BioAlgCheckpointWrapper`.

    """
    # Checkpointing is detected via the biometric_algorithm wrapper chain.
    checkpointed = is_instance_nested(
        pipeline, "biometric_algorithm", BioAlgCheckpointWrapper
    )
    return checkpointed