Coverage for src/bob/pad/face/database/casia_surf.py: 42%

52 statements  

« prev     ^ index     » next       coverage.py v7.6.5, created at 2024-11-14 23:14 +0100

1import logging 

2import os 

3 

4from functools import partial 

5 

6from clapper.rc import UserDefaults 

7from sklearn.preprocessing import FunctionTransformer 

8 

9import bob.io.base 

10 

11from bob.bio.video import VideoLikeContainer 

12from bob.pad.base.database import FileListPadDatabase 

13from bob.pipelines import CSVToSamples, DelayedSample 

14 

# User-level configuration (``bobrc.toml``); the dataset root is read from the
# ``bob.db.casia_surf.directory`` key when the database is instantiated.
rc = UserDefaults("bobrc.toml")
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

17 

18 

def load_multi_stream(path):
    """Load a single image file as a one-frame ``VideoLikeContainer``.

    Parameters
    ----------
    path : str
        File to read with ``bob.io.base.load``.

    Returns
    -------
    VideoLikeContainer
        Container holding the loaded array with an extra leading axis of
        length one (a single frame), indexed by ``[0]``.
    """
    frame = bob.io.base.load(path)
    # Prepend a frame axis so a still image behaves like a 1-frame video.
    return VideoLikeContainer(frame[None, ...], [0])

23 

24 

def casia_surf_multistream_load(samples, original_directory):
    """Convert list-file samples into delayed multi-stream CASIA-SURF samples.

    Parameters
    ----------
    samples : iterable
        Samples exposing ``filename``, ``ir_filename``, ``depth_filename`` and
        ``is_bonafide`` attributes (the list-file columns).
    original_directory : str or None
        Root directory prepended to every file path; ``None`` means the paths
        are used as-is.

    Returns
    -------
    list of DelayedSample
        One sample per input. ``data`` lazily loads the color image; the
        ``depth`` and ``infrared`` streams are delayed attributes. A sample is
        bona fide when its ``is_bonafide`` column equals ``"1"``; otherwise its
        ``attack_type`` is ``"attack"``.
    """
    root = original_directory or ""
    # Stream name -> sample attribute holding that stream's relative path.
    stream_attributes = {
        "color": "filename",
        "infrared": "ir_filename",
        "depth": "depth_filename",
    }

    def _to_delayed(sample):
        # Build one lazy loader per stream; nothing is read from disk here.
        loaders = {
            stream: partial(
                load_multi_stream, os.path.join(root, getattr(sample, attribute))
            )
            for stream, attribute in stream_attributes.items()
        }
        bonafide = sample.is_bonafide == "1"
        return DelayedSample(
            loaders["color"],
            parent=sample,
            subject=None,
            key=sample.filename,
            attack_type=None if bonafide else "attack",
            is_bonafide=bonafide,
            annotations=None,
            delayed_attributes={
                "depth": loaders["depth"],
                "infrared": loaders["infrared"],
            },
        )

    return list(map(_to_delayed, samples))

58 

59 

def CasiaSurfMultiStreamSample(original_directory):
    """Wrap ``casia_surf_multistream_load`` in a scikit-learn transformer.

    Parameters
    ----------
    original_directory : str or None
        Root directory forwarded to ``casia_surf_multistream_load``.

    Returns
    -------
    FunctionTransformer
        Transformer whose ``transform`` calls ``casia_surf_multistream_load``
        with ``original_directory`` bound as a keyword argument.
    """
    loader_kwargs = {"original_directory": original_directory}
    return FunctionTransformer(casia_surf_multistream_load, kw_args=loader_kwargs)

65 

66 

class CasiaSurfPadDatabase(FileListPadDatabase):
    """The CASIA SURF Face PAD database interface.

    The dataset root is read from the ``bob.db.casia_surf.directory``
    configuration key (set it with ``bob config set``). The protocol list
    files are looked up in that same directory (see :meth:`list_file`).

    Each returned sample lazily loads its ``data`` from the color image as a
    single-frame :class:`VideoLikeContainer`. The ``depth`` and ``infrared``
    streams of the same sample are exposed as delayed attributes, each also a
    single-frame :class:`VideoLikeContainer`.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to :class:`FileListPadDatabase`.

    Raises
    ------
    FileNotFoundError
        If ``bob.db.casia_surf.directory`` is not set, or does not point to an
        existing directory.
    """

    def __init__(
        self,
        **kwargs,
    ):
        original_directory = rc.get("bob.db.casia_surf.directory")
        if original_directory is None:
            raise FileNotFoundError(
                "The original_directory is not set. Please set it in the terminal using `bob config set bob.db.casia_surf.directory /path/to/database/CASIA-SURF/`."
            )
        if not os.path.isdir(original_directory):
            # Report a stale or mistyped path separately from an unset key, so
            # the user knows which one to fix.
            raise FileNotFoundError(
                f"The original_directory {original_directory!r} configured in "
                "`bob.db.casia_surf.directory` is not an existing directory. "
                "Please set it in the terminal using `bob config set "
                "bob.db.casia_surf.directory /path/to/database/CASIA-SURF/`."
            )
        transformer = CasiaSurfMultiStreamSample(
            original_directory=original_directory,
        )
        super().__init__(
            name="casia-surf",
            # The protocol list files are resolved against the dataset root.
            dataset_protocols_path=original_directory,
            protocol="all",
            reader_cls=partial(
                CSVToSamples,
                dict_reader_kwargs=dict(
                    # List lines are space separated: color, infrared and
                    # depth paths, then the label ("1" means bona fide).
                    delimiter=" ",
                    fieldnames=[
                        "filename",
                        "ir_filename",
                        "depth_filename",
                        "is_bonafide",
                    ],
                ),
            ),
            transformer=transformer,
            **kwargs,
        )
        # Samples are created with annotations=None, so there is no
        # annotation type or fixed eye positions for this database.
        self.annotation_type = None
        self.fixed_positions = None

    def protocols(self):
        """Return the available protocol names (only ``"all"``)."""
        return ["all"]

    def groups(self):
        """Return the available groups: ``train``, ``dev`` and ``eval``."""
        return ["train", "dev", "eval"]

    def list_file(self, group):
        """Return the path of the list file for ``group``.

        ``train`` maps to ``train_list.txt``, ``dev`` to
        ``val_private_list.txt`` and ``eval`` to ``test_private_list.txt``,
        all inside ``self.dataset_protocols_path``.

        Raises
        ------
        KeyError
            If ``group`` is not one of :meth:`groups`.
        """
        filename = {
            "train": "train_list.txt",
            "dev": "val_private_list.txt",
            "eval": "test_private_list.txt",
        }[group]
        return os.path.join(self.dataset_protocols_path, filename)