Coverage for src/bob/bio/video/transformer.py: 80%

45 statements  

Generated by coverage.py v7.6.0 on 2024-07-13 00:44 +0200.

1import logging 

2 

3from sklearn.base import BaseEstimator, TransformerMixin 

4 

5from bob.pipelines.wrappers import _check_n_input_output, _frmt 

6 

7from . import utils 

8 

# Logger named after this module, following the standard ``logging`` convention.
logger = logging.getLogger(name=__name__)

10 

11 

class VideoWrapper(TransformerMixin, BaseEstimator):
    """Wrapper class to run image preprocessing algorithms on video data.

    The wrapped estimator is applied to the frames of each video. Frames that
    are ``None`` (e.g. because an earlier transformer returned ``None`` for
    them) are not passed to the estimator. They are re-inserted as ``None``
    at their original positions in the output.

    **Parameters:**

    estimator : ``sklearn.base.BaseEstimator`` instance
        The transformer to be used to preprocess the frames. It must provide
        ``transform(frames, **kwargs)`` and ``_get_tags()``.

    **kwargs
        Forwarded to the parent classes' ``__init__``.
    """

    def __init__(
        self,
        estimator,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.estimator = estimator

    def transform(self, videos, **kwargs):
        """Apply the wrapped estimator to the frames of each video.

        Parameters
        ----------
        videos : iterable
            Videos to transform. Each video must have an ``indices``
            attribute and be iterable over its frames. ``None`` frames are
            treated as invalid: they are skipped and reported as ``None``.
        **kwargs
            Per-video extra arguments. ``kwargs[k][i]`` is forwarded to the
            estimator for the ``i``-th video, after dropping the entries of
            invalid frames. If present and not ``None``, the
            ``annotations`` value is looked up by frame index (or its string
            form) and turned into a list ordered like ``video.indices``.

        Returns
        -------
        list
            One ``utils.VideoLikeContainer`` per input video, built from the
            transformed frames and the video's ``indices``.

        Raises
        ------
        ValueError
            If a video has no ``indices`` attribute.
        """
        transformed_videos = []
        for video_idx, video in enumerate(videos):
            if not hasattr(video, "indices"):
                raise ValueError(
                    f"The input video: {video}\n does not have indices.\n "
                    f"Processing failed in {self}"
                )

            # Select the extra arguments that belong to this video.
            kw = {k: v[video_idx] for k, v in kwargs.items()}
            annotations = kw.get("annotations")
            if annotations is not None:
                # Build one entry per frame, ordered like ``video.indices``.
                # Keys may be stored as the index itself or as its string.
                kw["annotations"] = [
                    annotations.get(index, annotations.get(str(index)))
                    for index in video.indices
                ]

            # Isolate invalid frames (None) in a single pass over the video.
            # Their positions are collected in ascending order so they can be
            # re-inserted after processing.
            invalid_ids = []
            valid_frames = []
            for frame_idx, frame in enumerate(video):
                if frame is None:
                    invalid_ids.append(frame_idx)
                else:
                    valid_frames.append(frame)

            # Drop the per-frame kwargs of invalid frames as well. A set keeps
            # the membership test O(1) instead of rescanning a list.
            invalid_set = set(invalid_ids)
            for k, v in kw.items():
                if v is None:
                    continue
                kw[k] = [vv for j, vv in enumerate(v) if j not in invalid_set]

            # Process only the valid frames.
            output = None
            if valid_frames:
                output = self.estimator.transform(valid_frames, **kw)
                _check_n_input_output(
                    valid_frames, output, f"{_frmt(self.estimator)}.transform"
                )

            if output is None:
                output = [None] * len(valid_frames)

            # Rebuild the full frame sequence, putting a None back at every
            # position that was invalid on input. Ascending order ensures
            # each insert lands at its original index.
            if invalid_ids:
                output = list(output)
                for j in invalid_ids:
                    output.insert(j, None)

            transformed_videos.append(
                utils.VideoLikeContainer(output, video.indices)
            )
        return transformed_videos

    def _more_tags(self):
        """Return the wrapped estimator's tags plus video save/load functions.

        The ``bob_features_save_fn`` / ``bob_features_load_fn`` tags are set
        to ``utils.VideoLikeContainer.save_function`` and
        ``utils.VideoLikeContainer.load``, matching the container type that
        ``transform`` returns.
        """
        tags = self.estimator._get_tags()
        tags["bob_features_save_fn"] = utils.VideoLikeContainer.save_function
        tags["bob_features_load_fn"] = utils.VideoLikeContainer.load
        return tags

    def fit(self, X, y=None, **fit_params):
        """Does nothing; the wrapped estimator is not fitted here.

        Returns
        -------
        VideoWrapper
            ``self``.
        """
        return self