Coverage for src/bob/bio/video/transformer.py: 80%
45 statements
coverage.py v7.6.5, created at 2024-11-14 22:56 +0100
1import logging
3from sklearn.base import BaseEstimator, TransformerMixin
5from bob.pipelines.wrappers import _check_n_input_output, _frmt
7from . import utils
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class VideoWrapper(TransformerMixin, BaseEstimator):
    """Wrapper class to run image preprocessing algorithms on video data.

    Every frame of each input video is passed through the wrapped
    ``estimator``. Frames that are ``None`` (e.g. because a previous
    transformer returned ``None`` for them) are not sent to the estimator;
    ``None`` is put back at their original positions in the output.

    **Parameters:**

    estimator : ``sklearn.base.BaseEstimator`` instance
        The transformer to be used to preprocess the frames. Its
        ``transform`` method is called with a list of (non-``None``) frames.
    """

    def __init__(
        self,
        estimator,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.estimator = estimator

    def transform(self, videos, **kwargs):
        """Apply ``self.estimator.transform`` frame-wise to every video.

        Parameters
        ----------
        videos : iterable
            Video-like objects. Each must be iterable over its frames and
            expose an ``indices`` attribute.
        **kwargs
            Extra arguments forwarded to the estimator. Each value is indexed
            by video position (``value[i]`` is used for the i-th video); every
            non-``None`` per-video value is treated as a per-frame sequence
            and filtered alongside the frames. An ``annotations`` entry, when
            not ``None``, is looked up per frame index (int key first, then
            its ``str`` form) and turned into a per-frame list.

        Returns
        -------
        list
            One ``utils.VideoLikeContainer`` per input video, holding the
            transformed frames (``None`` where the input frame was ``None``)
            and the input video's ``indices``.

        Raises
        ------
        ValueError
            If a video has no ``indices`` attribute.
        """
        transformed_videos = []
        for video_pos, video in enumerate(videos):
            if not hasattr(video, "indices"):
                raise ValueError(
                    f"The input video: {video}\n does not have indices.\n "
                    f"Processing failed in {self}"
                )

            # Select this video's slice of every extra argument.
            kw = {k: v[video_pos] for k, v in kwargs.items()}
            annotations = kw.get("annotations")
            if annotations is not None:
                # Annotations may be keyed by int or str frame index.
                kw["annotations"] = [
                    annotations.get(index, annotations.get(str(index)))
                    for index in video.indices
                ]

            # Isolate invalid frames (previous transformers returned None) in a
            # single pass; they are re-inserted as None after processing.
            invalid_ids = []
            valid_frames = []
            for frame_pos, frame in enumerate(video):
                if frame is None:
                    invalid_ids.append(frame_pos)
                else:
                    valid_frames.append(frame)

            # Drop the per-frame kwarg entries of invalid frames as well. A set
            # keeps the membership test O(1) per element.
            invalid_set = set(invalid_ids)
            for key, value in kw.items():
                if value is None:
                    continue
                kw[key] = [
                    item for pos, item in enumerate(value) if pos not in invalid_set
                ]

            # Process only the valid frames.
            output = None
            if valid_frames:
                output = self.estimator.transform(valid_frames, **kw)
                _check_n_input_output(
                    valid_frames, output, f"{_frmt(self.estimator)}.transform"
                )

            if output is None:
                output = [None] * len(valid_frames)

            # Rebuild the full frame list. invalid_ids is ascending, so each
            # insert lands at that frame's original position.
            if invalid_ids:
                output = list(output)
                for pos in invalid_ids:
                    output.insert(pos, None)

            transformed_videos.append(
                utils.VideoLikeContainer(output, video.indices)
            )
        return transformed_videos

    def _more_tags(self):
        """Return the wrapped estimator's tags plus video save/load functions.

        The ``bob_features_save_fn`` / ``bob_features_load_fn`` tags are set
        to ``utils.VideoLikeContainer.save_function`` and
        ``utils.VideoLikeContainer.load``.
        """
        # NOTE(review): relies on sklearn's private ``_get_tags``; newer
        # scikit-learn releases replace it with ``__sklearn_tags__`` — confirm
        # against the pinned sklearn version.
        tags = self.estimator._get_tags()
        tags["bob_features_save_fn"] = utils.VideoLikeContainer.save_function
        tags["bob_features_load_fn"] = utils.VideoLikeContainer.load
        return tags

    def fit(self, X, y=None, **fit_params):
        """Does nothing; the wrapper is stateless. Returns ``self``."""
        return self