Coverage for src/bob/pad/base/database/csv_dataset.py: 33%

33 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-06 21:56 +0100

1#!/usr/bin/env python 

2# vim: set fileencoding=utf-8 : 

3 

4 

5from bob.bio.base.database.legacy import check_parameters_for_validity 

6from bob.pad.base.pipelines.abstract_classes import Database 

7from bob.pipelines.dataset import FileListDatabase 

8 

9 

def validate_pad_sample(sample):
    """Check that ``sample`` has the attributes a PAD sample needs and normalise it.

    The sample is modified in place:

    * an ``attack_type`` equal to ``""`` is replaced by ``None``;
    * ``is_bonafide`` is set to ``True`` exactly when ``attack_type`` is ``None``;
    * if the sample has no ``key`` attribute, ``key`` is set to ``sample.filename``
      (so a sample lacking both ``key`` and ``filename`` raises ``AttributeError``).

    Parameters
    ----------
    sample : object
        Any object exposing the attributes above as plain attributes.

    Returns
    -------
    object
        The same ``sample`` object, after in-place normalisation.

    Raises
    ------
    RuntimeError
        If ``sample`` has no ``subject`` or no ``attack_type`` attribute.
    """
    if not hasattr(sample, "subject"):
        raise RuntimeError(
            "PAD samples should contain a `subject` attribute which "
            "identifies the person from whom the sample is created."
        )
    if not hasattr(sample, "attack_type"):
        raise RuntimeError(
            "PAD samples should contain an `attack_type` attribute which "
            "should be '' for bona fide samples and something like "
            "print, replay, mask, etc. for attacks. This attribute is "
            "considered the PAI type of each attack and is used to compute APCER."
        )
    # An empty string marks a bona fide sample; store it as None so the
    # is_bonafide test below has a single sentinel to check.
    if sample.attack_type == "":
        sample.attack_type = None
    sample.is_bonafide = sample.attack_type is None
    # Fall back to the file name as the sample's unique key.
    if not hasattr(sample, "key"):
        sample.key = sample.filename
    return sample

29 

30 

class FileListPadDatabase(Database, FileListDatabase):
    """PAD database interface whose samples are listed in CSV files.

    Mixes the PAD ``Database`` interface with ``FileListDatabase``; every
    sample returned by :meth:`samples` is passed through
    ``validate_pad_sample`` before being handed out.
    """

    def __init__(
        self,
        name,
        dataset_protocols_path,
        protocol,
        transformer=None,
        **kwargs,
    ):
        # Forward everything unchanged to the next class in the MRO.
        forwarded = dict(
            name=name,
            dataset_protocols_path=dataset_protocols_path,
            protocol=protocol,
            transformer=transformer,
        )
        super().__init__(**forwarded, **kwargs)

    def __repr__(self) -> str:
        return (
            f"FileListPadDatabase(dataset_protocols_path='{self.dataset_protocols_path}', "
            f"protocol='{self.protocol}', transformer={self.transformer})"
        )

    def purposes(self):
        """Return the purposes samples can be filtered by."""
        return "real", "attack"

    def samples(self, groups=None, purposes=None):
        """Return validated samples of ``groups`` whose purpose is in ``purposes``.

        Parameters
        ----------
        groups
            Forwarded unchanged to the parent ``samples`` implementation.
        purposes
            Subset of :meth:`purposes`; resolved by
            ``check_parameters_for_validity`` with ``self.purposes()`` as
            both the valid set and the default (presumably used when
            ``None`` is given -- confirm against that helper).

        Returns
        -------
        list
            Bona fide samples when ``"real"`` is requested and attack samples
            when ``"attack"`` is requested, in their original order.
        """
        all_samples = super().samples(groups=groups)
        wanted = check_parameters_for_validity(
            purposes, "purposes", self.purposes(), self.purposes()
        )
        want_real = "real" in wanted
        want_attack = "attack" in wanted

        # Validate every sample first (this may raise), then select by purpose.
        checked = [validate_pad_sample(item) for item in all_samples]
        return [
            item
            for item in checked
            if (want_real if item.is_bonafide else want_attack)
        ]

    def fit_samples(self):
        """Return all samples of the ``train`` group."""
        return self.samples(groups="train")

    def predict_samples(self, group="dev"):
        """Return all samples of ``group`` (``"dev"`` by default)."""
        return self.samples(groups=group)