Coverage for src/bob/bio/base/annotator/FailSafe.py: 77%

48 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-12 22:34 +0200

1import logging 

2 

3from .. import load_resource 

4from . import Annotator 

5 

6logger = logging.getLogger(__name__) 

7 

8 

class FailSafe(Annotator):
    """A fail-safe annotator.

    This annotator takes a list of annotators and tries them in order until
    all the required annotations are available.
    The annotations gathered from the previous annotators are passed to the
    next one.

    Attributes
    ----------
    annotators : list
        A list of annotators to try
    required_keys : list
        A list of keys that should be available in annotations to stop trying
        different annotators.
    only_required_keys : bool
        If True, the annotations will only contain the ``required_keys``.
    """

    def __init__(
        self, annotators, required_keys, only_required_keys=False, **kwargs
    ):
        """Store the annotators to try and the keys that must be found.

        Parameters
        ----------
        annotators : list
            Annotator instances, or strings that are resolved through
            ``load_resource(name, "annotator")``.
        required_keys : iterable
            Keys that must be present in the annotations for the search to
            stop. Stored as a list.
        only_required_keys : bool
            If True, :meth:`annotate` drops every key that is not in
            ``required_keys`` from its result.
        **kwargs
            Forwarded to the parent class constructor.
        """
        super(FailSafe, self).__init__(**kwargs)
        self.annotators = []
        for annotator in annotators:
            if isinstance(annotator, str):
                annotator = load_resource(annotator, "annotator")
            self.annotators.append(annotator)
        self.required_keys = list(required_keys)
        self.only_required_keys = only_required_keys

    def annotate(self, sample, **kwargs):
        """Annotate one sample, trying each annotator in turn.

        Parameters
        ----------
        sample
            The sample handed to each annotator's ``transform``.
        **kwargs
            Extra per-sample arguments forwarded to every annotator. The
            optional ``annotations`` entry provides initial annotations; it
            is copied, so the caller's dictionary is never modified.

        Returns
        -------
        dict or None
            The accumulated annotations once all ``required_keys`` are
            present (restricted to those keys if ``only_required_keys`` is
            True), or ``None`` when the annotators were exhausted without
            finding all of them.
        """
        # Work on a copy: the accumulated result must not leak into the
        # caller's dictionary, especially when we end up returning None.
        annotations = dict(kwargs.get("annotations") or {})
        kwargs["annotations"] = annotations

        for annotator in self.annotators:
            try:
                # Annotators work on batches, so wrap every argument in a
                # one-element list and unwrap the single result.
                result = annotator.transform(
                    [sample], **{k: [v] for k, v in kwargs.items()}
                )[0]
            except Exception:
                # Deliberate best effort: a failing annotator only means
                # that the next one gets its chance.
                logger.debug(
                    "The annotator `%s' failed to annotate!",
                    annotator,
                    exc_info=True,
                )
                result = None
            if not result:
                logger.debug(
                    "Annotator `%s' returned empty annotations.", annotator
                )
            else:
                logger.debug("Annotator `%s' succeeded!", annotator)
                annotations.update(result)
            # check if we have all the required annotations
            if all(key in annotations for key in self.required_keys):
                break
        else:  # this else is for the for loop
            # We don't want to return half of the annotations. Returning
            # here also keeps the key filtering below from running on None.
            return None

        if self.only_required_keys:
            return {
                key: value
                for key, value in annotations.items()
                if key in self.required_keys
            }
        return annotations

    def transform(self, samples, **kwargs):
        """
        Takes a batch of data and tries annotating them while unsuccessful.

        Tries each annotator given at the creation of FailSafe when the previous
        one fails.

        Each ``kwargs`` value is a list of parameters, with each element of those
        lists corresponding to each element of ``samples`` (for example:
        with ``[s1, s2, ...]`` as ``samples``, ``kwargs['annotations']``
        should contain ``[{<s1_annotations>}, {<s2_annotations>}, ...]``).

        Returns a list with one :meth:`annotate` result per sample.
        """
        kwargs = translate_kwargs(kwargs, len(samples))
        return [
            self.annotate(sample, **kw) for sample, kw in zip(samples, kwargs)
        ]

88 

89 

def translate_kwargs(kwargs, size):
    """Turn a dict of per-sample lists into a list of per-sample dicts.

    Parameters
    ----------
    kwargs : dict
        Maps each argument name to a list holding one value per sample.
    size : int
        The number of samples; every list in ``kwargs`` must have this
        length.

    Returns
    -------
    list of dict
        ``size`` independent dictionaries, the i-th one mapping each
        argument name to its i-th value. The dictionaries are empty when
        ``kwargs`` is empty.

    Raises
    ------
    ValueError
        If one of the lists in ``kwargs`` does not contain ``size`` elements.
    """
    # One fresh dict per sample. ``[{}] * size`` would repeat a single shared
    # dict, making every sample see the values of the last one.
    new_kwargs = [{} for _ in range(size)]

    if not kwargs:
        return new_kwargs

    for k, value_list in kwargs.items():
        if len(value_list) != size:
            raise ValueError(
                f"Got {value_list} in kwargs which is not of the same length "
                f"of samples {size}"
            )
        for kw, v in zip(new_kwargs, value_list):
            kw[k] = v

    return new_kwargs