Coverage for src/bob/bio/base/annotator/FailSafe.py: 77%
48 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-12 22:34 +0200
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-12 22:34 +0200
1import logging
3from .. import load_resource
4from . import Annotator
6logger = logging.getLogger(__name__)
class FailSafe(Annotator):
    """A fail-safe annotator.

    This annotator takes a list of annotators and tries them one after the
    other until all the required annotations are available.

    The annotations produced by the previous annotators are passed on to the
    next one.

    Attributes
    ----------
    annotators : list
        A list of annotators to try. Entries given as ``str`` are resolved with
        ``load_resource(name, "annotator")``.
    required_keys : list
        A list of keys that should be available in annotations to stop trying
        different annotators.
    only_required_keys : bool
        If True, the annotations will only contain the ``required_keys``.
    """

    def __init__(
        self, annotators, required_keys, only_required_keys=False, **kwargs
    ):
        super().__init__(**kwargs)
        self.annotators = []
        for annotator in annotators:
            if isinstance(annotator, str):
                annotator = load_resource(annotator, "annotator")
            self.annotators.append(annotator)
        self.required_keys = list(required_keys)
        self.only_required_keys = only_required_keys

    def annotate(self, sample, **kwargs):
        """Annotate one sample by trying each annotator in order.

        Parameters
        ----------
        sample
            The sample handed to each annotator's ``transform``.
        **kwargs
            Extra per-sample arguments forwarded to each annotator. An
            ``annotations`` entry, when given and not None, is used as the
            starting annotations and is updated in place.

        Returns
        -------
        dict or None
            The accumulated annotations once every key of ``required_keys``
            is present (restricted to those keys when ``only_required_keys``
            is True), or None when no annotator could provide all of them.
        """
        if kwargs.get("annotations") is None:
            kwargs["annotations"] = {}
        for annotator in self.annotators:
            try:
                # Annotators work on batches: wrap everything in
                # single-element lists and unwrap the single result.
                annotations = annotator.transform(
                    [sample], **{k: [v] for k, v in kwargs.items()}
                )[0]
            except Exception:
                # Deliberate best-effort: a failing annotator must not stop
                # the chain, the next one gets its chance.
                logger.debug(
                    "The annotator `%s' failed to annotate!",
                    annotator,
                    exc_info=True,
                )
                annotations = None
            if not annotations:
                logger.debug(
                    "Annotator `%s' returned empty annotations.", annotator
                )
            else:
                logger.debug("Annotator `%s' succeeded!", annotator)
            kwargs["annotations"].update(annotations or {})
            # check if we have all the required annotations
            if all(key in kwargs["annotations"] for key in self.required_keys):
                break
        else:  # this else is for the for loop
            # We don't want to return half of the annotations. Return right
            # away: the key filtering below would otherwise be applied to
            # None and raise an AttributeError.
            return None
        if self.only_required_keys:
            for key in list(kwargs["annotations"].keys()):
                if key not in self.required_keys:
                    del kwargs["annotations"][key]
        return kwargs["annotations"]

    def transform(self, samples, **kwargs):
        """
        Takes a batch of data and tries annotating them while unsuccessful.

        Tries each annotator given at the creation of FailSafe when the previous
        one fails.

        Each ``kwargs`` value is a list of parameters, with each element of those
        lists corresponding to each element of ``samples`` (for example:
        with ``[s1, s2, ...]`` as ``samples``, ``kwargs['annotations']``
        should contain ``[{<s1_annotations>}, {<s2_annotations>}, ...]``).

        Returns a list with one entry per sample: the annotations, or None
        for the samples that could not be fully annotated.
        """
        kwargs = translate_kwargs(kwargs, len(samples))
        return [
            self.annotate(sample, **kw) for sample, kw in zip(samples, kwargs)
        ]
def translate_kwargs(kwargs, size):
    """Turn a dict of per-sample lists into a list of per-sample dicts.

    ``{"a": [a1, a2], "b": [b1, b2]}`` with ``size=2`` becomes
    ``[{"a": a1, "b": b1}, {"a": a2, "b": b2}]``.

    Parameters
    ----------
    kwargs : dict or None
        Maps each keyword name to a list holding one value per sample.
    size : int
        The number of samples; every list in ``kwargs`` must have this length.

    Returns
    -------
    list of dict
        ``size`` independent dictionaries, one per sample (all empty when
        ``kwargs`` is empty or None).

    Raises
    ------
    ValueError
        If one of the lists in ``kwargs`` does not contain ``size`` elements.
    """
    # One distinct dict per sample: ``[{}] * size`` would repeat the *same*
    # dict object, so every sample would end up with the last sample's values.
    new_kwargs = [{} for _ in range(size)]

    if not kwargs:
        return new_kwargs

    for k, value_list in kwargs.items():
        if len(value_list) != size:
            raise ValueError(
                f"Got {value_list} in kwargs which is not of the same length of samples {size}"
            )
        for kw, v in zip(new_kwargs, value_list):
            kw[k] = v

    return new_kwargs