Coverage for src/bob/pad/face/database/casia_surf.py: 42%
52 statements
« prev ^ index » next coverage.py v7.6.5, created at 2024-11-14 23:14 +0100
1import logging
2import os
4from functools import partial
6from clapper.rc import UserDefaults
7from sklearn.preprocessing import FunctionTransformer
9import bob.io.base
11from bob.bio.video import VideoLikeContainer
12from bob.pad.base.database import FileListPadDatabase
13from bob.pipelines import CSVToSamples, DelayedSample
# Module-level logger for this database interface.
logger = logging.getLogger(__name__)
# User defaults backed by ``bobrc.toml`` (via clapper). CasiaSurfPadDatabase
# reads the ``bob.db.casia_surf.directory`` key from it at construction time.
rc = UserDefaults("bobrc.toml")
def load_multi_stream(path):
    """Read one stream file and wrap it as a single-frame video.

    Parameters
    ----------
    path : str
        File to read with :func:`bob.io.base.load`.

    Returns
    -------
    VideoLikeContainer
        The loaded array with an extra leading axis of length 1 (presumably
        the frame axis; assumes ``load`` returns a NumPy array — TODO
        confirm) and frame index ``[0]``.
    """
    frame = bob.io.base.load(path)
    return VideoLikeContainer(frame[None, ...], [0])
def casia_surf_multistream_load(samples, original_directory):
    """Convert CASIA-SURF list rows into delayed multi-stream samples.

    Parameters
    ----------
    samples : iterable
        Rows produced by the CSV reader. Each must expose ``filename``,
        ``ir_filename``, ``depth_filename`` and ``is_bonafide`` attributes.
    original_directory : str or None
        Prefix joined in front of every relative path. ``None`` (or an
        empty string) leaves the paths unchanged.

    Returns
    -------
    list of DelayedSample
        One sample per row:

        - ``data`` lazily loads the color stream.
        - ``depth`` and ``infrared`` are delayed attributes.
        - ``key`` is the color filename.
        - ``is_bonafide`` is True only when the row's label is the string
          ``"1"``; other rows get ``attack_type="attack"``.
        - ``subject`` and ``annotations`` are ``None``.
    """
    # Stream name -> attribute of the CSV row holding that stream's path.
    attribute_of = {
        "color": "filename",
        "infrared": "ir_filename",
        "depth": "depth_filename",
    }
    root = original_directory or ""

    def _to_delayed(row):
        stream_paths = {
            stream: os.path.join(root, getattr(row, attr))
            for stream, attr in attribute_of.items()
        }
        sample_key = row.filename
        bonafide = row.is_bonafide == "1"
        return DelayedSample(
            partial(load_multi_stream, stream_paths["color"]),
            parent=row,
            subject=None,
            key=sample_key,
            attack_type=None if bonafide else "attack",
            is_bonafide=bonafide,
            annotations=None,
            delayed_attributes={
                "depth": partial(load_multi_stream, stream_paths["depth"]),
                "infrared": partial(load_multi_stream, stream_paths["infrared"]),
            },
        )

    return list(map(_to_delayed, samples))
def CasiaSurfMultiStreamSample(original_directory):
    """Build a scikit-learn transformer that turns list rows into samples.

    Parameters
    ----------
    original_directory : str or None
        Root directory prefixed to every path in the list files. It is
        forwarded to :func:`casia_surf_multistream_load`.

    Returns
    -------
    sklearn.preprocessing.FunctionTransformer
        Wraps :func:`casia_surf_multistream_load` with
        ``original_directory`` bound as a keyword argument.
    """
    loader_kwargs = {"original_directory": original_directory}
    return FunctionTransformer(casia_surf_multistream_load, kw_args=loader_kwargs)
class CasiaSurfPadDatabase(FileListPadDatabase):
    """The CASIA SURF Face PAD database interface.

    The database root is read from the ``bob.db.casia_surf.directory``
    configuration key. Set it with
    ``bob config set bob.db.casia_surf.directory /path/to/database/CASIA-SURF/``.
    The same directory holds the data files and the list files
    ``train_list.txt``, ``val_private_list.txt`` and ``test_private_list.txt``.

    A single protocol, ``all``, is available. Its groups map to list files
    as follows:

    - ``train``: ``train_list.txt``
    - ``dev``: ``val_private_list.txt``
    - ``eval``: ``test_private_list.txt``

    Each returned sample is a ``DelayedSample``:

    - ``data`` lazily loads the color stream as a ``VideoLikeContainer``.
    - ``depth`` and ``infrared`` are delayed attributes of the same type.
    - No face annotations are provided.

    Parameters
    ----------
    **kwargs
        Forwarded to :class:`FileListPadDatabase`. ``name``,
        ``dataset_protocols_path``, ``protocol``, ``reader_cls`` and
        ``transformer`` are set by this class and must not be passed again.

    Raises
    ------
    FileNotFoundError
        If the configuration key is unset, or does not point to an existing
        directory.
    """

    _CONFIG_KEY = "bob.db.casia_surf.directory"

    # Group name -> list file (relative to the database root). Insertion
    # order defines the order returned by ``groups()``.
    _LIST_FILES = {
        "train": "train_list.txt",
        "dev": "val_private_list.txt",
        "eval": "test_private_list.txt",
    }

    def __init__(
        self,
        **kwargs,
    ):
        original_directory = rc.get(self._CONFIG_KEY)
        if original_directory is None:
            raise FileNotFoundError(
                "The original_directory is not set. Please set it in the terminal using `bob config set bob.db.casia_surf.directory /path/to/database/CASIA-SURF/`."
            )
        # Distinguish "configured but wrong" from "not configured" so the
        # user knows which one to fix.
        if not os.path.isdir(original_directory):
            raise FileNotFoundError(
                f"The original_directory ({original_directory!r}) does not exist or is not a directory. Please set it in the terminal using `bob config set bob.db.casia_surf.directory /path/to/database/CASIA-SURF/`."
            )
        transformer = CasiaSurfMultiStreamSample(
            original_directory=original_directory,
        )
        super().__init__(
            name="casia-surf",
            dataset_protocols_path=original_directory,
            protocol="all",
            reader_cls=partial(
                CSVToSamples,
                # List files are space-separated and have no header row. The
                # columns are the color, IR and depth paths, then the label
                # ("1" marks bona fide in the loader).
                dict_reader_kwargs=dict(
                    delimiter=" ",
                    fieldnames=[
                        "filename",
                        "ir_filename",
                        "depth_filename",
                        "is_bonafide",
                    ],
                ),
            ),
            transformer=transformer,
            **kwargs,
        )
        # The loader sets annotations=None, so no annotation type or fixed
        # eye positions apply.
        self.annotation_type = None
        self.fixed_positions = None

    def protocols(self):
        """Return the available protocols (only ``"all"``)."""
        return ["all"]

    def groups(self):
        """Return the group names: ``["train", "dev", "eval"]``."""
        return list(self._LIST_FILES)

    def list_file(self, group):
        """Return the path of the list file for ``group``.

        Raises
        ------
        KeyError
            If ``group`` is not one of ``train``, ``dev`` or ``eval``.
        """
        return os.path.join(self.dataset_protocols_path, self._LIST_FILES[group])