Coverage for src/bob/bio/video/utils.py: 88%

146 statements  

« prev     ^ index     » next       coverage.py v7.6.5, created at 2024-11-14 22:56 +0100

1import functools 

2import importlib 

3import logging 

4import pickle 

5import unittest 

6 

7import h5py 

8import imageio 

9import numpy as np 

10 

11from bob.bio.base import selected_indices 

12from bob.io.image import to_bob 

13from bob.pipelines import wrap 

14 

15from .transformer import VideoWrapper 

16 

17logger = logging.getLogger(__name__) 

18 

19 

def video_wrap_skpipeline(sk_pipeline):
    """Wrap every step of a scikit-learn pipeline so it operates on videos.

    Each estimator of ``sk_pipeline`` is unwrapped (when it exposes an
    ``estimator`` attribute), wrapped with
    :any:`bob.bio.video.transformer.VideoWrapper`, and finally sample-wrapped
    again with :any:`bob.pipelines.wrap`.

    Parameters
    ----------
    sk_pipeline : sklearn.pipeline.Pipeline
        The pipeline to convert. Its ``steps`` list is modified in place.

    Returns
    -------
    sklearn.pipeline.Pipeline
        The same ``sk_pipeline`` object that was given, with its steps
        replaced by their video-wrapped counterparts.
    """
    # NOTE(review): ``_iter`` is a private scikit-learn API; this loop relies
    # on it yielding ``(index, name, estimator)`` triples.
    for i, name, estimator in sk_pipeline._iter():
        # 1. Unwrap the estimator
        # If the estimator is `Sample` wrapped, take `estimator.estimator`.
        inner = (
            estimator.estimator
            if hasattr(estimator, "estimator")
            else estimator
        )

        # 2. Forward the extra-argument settings of the original wrapper.
        # Only attributes that actually exist are forwarded, so a step that
        # was not sample-wrapped no longer fails with an AttributeError.
        extra_arguments = {
            key: getattr(estimator, key)
            for key in ("fit_extra_arguments", "transform_extra_arguments")
            if hasattr(estimator, key)
        }

        # 3. Do a video wrap, then sample wrap again
        transformer = wrap(
            ["sample"],
            VideoWrapper(inner),
            **extra_arguments,
        )

        sk_pipeline.steps[i] = (name, transformer)

    return sk_pipeline

49 

50 

def select_frames(
    count, max_number_of_frames=None, selection_style=None, step_size=None
):
    """Compute which frame indices of a video should be kept.

    Supported selection styles:

    * first : The first ``max_number_of_frames`` frames (or all of them when
      the video is shorter).
    * spread : Frames taken from the whole video with equal spaces in
      between, as computed by ``bob.bio.base.selected_indices``.
    * step : Every ``step_size``-th frame, starting at ``step_size // 2``,
      truncated to ``max_number_of_frames`` entries. **Think twice if you
      want to have that when giving FrameContainer data!**
    * all : Every frame, unconditionally.

    Parameters
    ----------
    count : int
        Total number of frames that are available.
    max_number_of_frames : int
        Upper bound on the number of selected frames; defaults to 20 when
        ``None``. Ignored when ``selection_style`` is ``"all"``.
    selection_style : str
        One of (``first``, ``spread``, ``step``, ``all``); defaults to
        ``spread`` when ``None``.
    step_size : int
        Distance between selected frames; defaults to 10 when ``None``. Only
        used when ``selection_style`` is ``step``.

    Returns
    -------
    range
        The selected frame indices. For ``spread`` the value is whatever
        ``selected_indices`` returns.

    Raises
    ------
    ValueError
        If ``selection_style`` is not one of the supported ones.
    """
    # resolve the defaults for everything that was left unspecified
    limit = 20 if max_number_of_frames is None else max_number_of_frames
    stride = 10 if step_size is None else step_size
    if selection_style is None:
        selection_style = "spread"

    if selection_style == "first":
        # the leading frames, bounded by the number of available frames
        return range(0, min(count, limit))

    if selection_style == "spread":
        # frames linearly spread over the whole video
        return selected_indices(count, limit)

    if selection_style == "step":
        # slicing a range yields another range, so no frame list is built
        return range(stride // 2, count, stride)[:limit]

    if selection_style == "all":
        return range(0, count)

    raise ValueError(f"Invalid selection style: {selection_style}")

110 

111 

def no_transform(x):
    """Return ``x`` unchanged.

    Identity function used as the fallback when no ``transform`` callable is
    supplied.
    """
    return x

114 

115 

def is_library_available(library):
    """Build a decorator that skips a test when ``library`` is not importable.

    Parameters
    ----------
    library : str
        Name of the module to import (through ``importlib.import_module``)
        each time the decorated function is called.

    Returns
    -------
    callable
        A decorator. The decorated function first imports ``library`` and
        then runs the original function, returning its result.

    Raises
    ------
    unittest.SkipTest
        Raised by the decorated function when an ``ImportError`` occurs,
        either while importing ``library`` or inside the decorated function
        itself.
    """

    def _is_library_available(function):
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            try:
                importlib.import_module(library)

                return function(*args, **kwargs)
            except ImportError as e:
                # unittest.SkipTest is compatible with both nose and pytest.
                # The message is a pure f-string: mixing it with %-formatting
                # broke whenever ``library`` itself contained a ``%``.
                raise unittest.SkipTest(
                    f"Skipping test since `{library}` is not available: {e}"
                ) from e

        return wrapper

    return _is_library_available

135 

136 

class VideoAsArray:
    """A memory efficient class to load only select video frames.

    It also supports efficient conversion to dask arrays.

    The object exposes array-like ``shape``, ``ndim`` and ``dtype`` attributes
    that describe only the *selected* frames; ``indices`` maps positions in
    this virtual array to frame numbers of the underlying video file.
    """

    def __init__(
        self,
        path,
        selection_style=None,
        max_number_of_frames=None,
        step_size=None,
        transform=None,
        **kwargs,
    ):
        """init

        Parameters
        ----------
        path : str
            Path to the video file
        selection_style : str, optional
            See :any:`select_frames`, by default None
        max_number_of_frames : int, optional
            See :any:`select_frames`, by default None
        step_size : int, optional
            See :any:`select_frames`, by default None
        transform : callable, optional
            A function that transforms the loaded video. This function should
            not change the video shape or its dtype. For example, you may flip
            the frames horizontally using this function, by default None
        """
        super().__init__(**kwargs)
        self.path = path
        self.reader = imageio.get_reader(path)
        self.dtype = np.uint8
        # Ask the reader for the frame count only once and reuse the value
        # for both the shape and the frame selection.
        frame_count = self.reader.count_frames()
        # (frames, 3 channels) followed by the reader's ``size`` metadata in
        # reversed order.
        shape = (frame_count, 3) + self.reader.get_meta_data()["size"][::-1]
        self.ndim = len(shape)
        self.selection_style = selection_style

        indices = select_frames(
            count=frame_count,
            max_number_of_frames=max_number_of_frames,
            selection_style=selection_style,
            step_size=step_size,
        )

        self.indices = indices
        # the first dimension only counts the selected frames
        self.shape = (len(indices),) + shape[1:]
        self.transform = transform or no_transform

    def __getstate__(self):
        """Return the instance state without the ``reader`` entry."""
        d = self.__dict__.copy()
        d.pop("reader")
        return d

    def __setstate__(self, state):
        """Restore the state and open a new reader on ``self.path``."""
        self.__dict__.update(state)
        self.reader = imageio.get_reader(self.path)

    def __len__(self):
        """Return the number of selected frames."""
        return self.shape[0]

    def __getitem__(self, index):
        """Load the requested selected frames.

        Parameters
        ----------
        index : int or tuple of slice
            Either a single integer (Python or NumPy), or a tuple holding
            exactly ``self.ndim`` slices.

        Returns
        -------
        object
            The result of ``self.transform`` applied to the loaded data: one
            frame for an integer index, a block of frames for a tuple of
            slices. An empty array is returned when every slice is
            ``slice(0, 0)``.

        Raises
        ------
        NotImplementedError
            For any other kind of index.
        """
        # logger.debug("Getting frame %s from %s", index, self.path)

        # In this method, someone is requesting indices thinking this video has
        # the shape of self.shape but self.shape is determined through
        # select_frames parameters. What we want to do here is to translate
        # ``index`` to real indices of the video file given that we want to load
        # only the selected frames. List of the selected frames are stored in
        # self.indices

        # If only one frame is requested, first translate the index to the real
        # frame number in the video file and load that. NumPy integers are
        # accepted as well as Python ints.
        if isinstance(index, (int, np.integer)):
            idx = self.indices[index]
            return self.transform(
                np.asarray([to_bob(self.reader.get_data(idx))])
            )[0]

        if not (
            isinstance(index, tuple)
            and len(index) == self.ndim
            and all(isinstance(idx, slice) for idx in index)
        ):
            raise NotImplementedError(
                f"Indexing like {index} is not supported yet!"
            )

        # dask.array.from_array sometimes requests empty arrays
        if all(i == slice(0, 0) for i in index):
            return np.array([], dtype=self.dtype)

        def _frames_generator():
            # read the frames one by one and yield them
            # NOTE(review): the early ``break`` below assumes the requested
            # frame numbers are in increasing order - confirm for slices with
            # a negative step.
            real_frame_numbers = self.indices[index[0]]
            for i, frame in enumerate(self.reader):
                if i not in real_frame_numbers:
                    continue
                # convert only the frames that are actually returned
                frame = to_bob(frame)
                # make sure arrays are loaded in C order because we reshape them
                # by C order later. Also, index into the frames here
                frame = np.ascontiguousarray(frame)[index[1:]]
                # return a tuple of flat array to match what is expected by
                # field_dtype
                yield (frame.ravel(),)
                if i == real_frame_numbers[-1]:
                    break

        iterable = _frames_generator()
        # compute the final shape given self.shape and index
        # see https://stackoverflow.com/a/36188683/1286165
        shape = [
            len(range(*idx.indices(dim))) for idx, dim in zip(index, self.shape)
        ]
        # field_dtype contains information about dtype and shape of each frame
        # numpy black magic: https://stackoverflow.com/a/12473478/1286165 allows
        # us to yield frame by frame in _frames_generator which greatly speeds
        # up loading the video
        field_dtype = [("", (self.dtype, (np.prod(shape[1:]),)))]
        total_number_of_frames = shape[0]
        video = np.fromiter(iterable, field_dtype, total_number_of_frames)
        # view the array as self.dtype to remove the field_dtype
        video = np.reshape(video.view(self.dtype), shape, order="C")

        return self.transform(video)

    def __repr__(self):
        return f"VideoAsArray: {self.path!r} {self.dtype!r} {self.ndim!r} {self.shape!r} {self.indices!r}"

269 

270 

class VideoLikeContainer:
    """Pairs array-like video ``data`` with the ``indices`` of its frames.

    ``dtype``, ``shape``, ``ndim``, ``len()`` and item access are delegated to
    ``data``. Instances can be saved to and loaded from disk, using HDF5 when
    possible and pickle otherwise.
    """

    def __init__(self, data, indices, **kwargs):
        """Store ``data`` and ``indices``; other kwargs go to the next base."""
        super().__init__(**kwargs)
        self.data = data
        self.indices = indices

    def __repr__(self) -> str:
        return f"VideoLikeContainer: {self.data!r} {self.indices!r}"

    @property
    def dtype(self):
        """The dtype of ``data``."""
        return self.data.dtype

    @property
    def shape(self):
        """The shape of ``data``."""
        return self.data.shape

    @property
    def ndim(self):
        """The number of dimensions of ``data``."""
        return self.data.ndim

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        """Return ``data[item]``.

        Raises
        ------
        IndexError
            If ``item`` is an integer outside ``[-len(self), len(self))``.
        """
        # we need to throw IndexErrors here because h5py throws ValueErrors
        # instead and this breaks loops on this class. Both ends of the valid
        # range are checked so that a too-negative index also raises
        # IndexError.
        if isinstance(item, (int, np.integer)):
            size = len(self)
            if item >= size or item < -size:
                raise IndexError(f"Index ({item}) out of range (0-{size-1})")
        return self.data[item]

    def __array__(self, dtype=None, *args, **kwargs):
        """Convert ``data`` to a NumPy array through ``np.asarray``."""
        return np.asarray(self.data, dtype, *args, **kwargs)

    def __eq__(self, o: object) -> bool:
        """Compare ``data`` and ``indices`` element-wise with those of ``o``.

        Returns ``NotImplemented`` when ``o`` has no ``data`` or ``indices``
        attribute, so comparing against unrelated objects evaluates to
        ``False`` instead of raising ``AttributeError``.
        """
        if not (hasattr(o, "data") and hasattr(o, "indices")):
            return NotImplemented
        return np.array_equal(self.data, o.data) and np.array_equal(
            self.indices, o.indices
        )

    def save(self, file):
        """Write this container to ``file``; see :any:`save_function`."""
        self.save_function(self, file)

    @staticmethod
    def save_function(other, file):
        """Write ``other.data`` and ``other.indices`` to ``file``.

        HDF5 is tried first. If that raises ``TypeError``, ``file`` is
        overwritten with a pickle of a dict holding both values.
        """
        try:
            with h5py.File(file, mode="w") as f:
                f["data"] = other.data
                f["indices"] = other.indices
        # revert to saving data in pickles when the dtype is not supported by hdf5
        except TypeError:
            with open(file, "wb") as f:
                pickle.dump({"data": other.data, "indices": other.indices}, f)

    @classmethod
    def load(cls, file):
        """Create a container from a file written by :any:`save_function`.

        HDF5 is tried first and the file is deliberately left open, so
        ``data`` stays an h5py dataset instead of being read into memory. If
        opening fails with ``OSError``, ``file`` is read as a pickle.
        """
        try:
            # weak closing of the hdf5 file so we don't load all the data into
            # memory https://docs.h5py.org/en/stable/high/file.html#closing-files
            f = h5py.File(file, mode="r")
            loaded = {"data": f["data"], "indices": list(f["indices"])}
        except OSError:
            # NOTE(review): unpickling can execute arbitrary code; only load
            # files that come from a trusted source.
            with open(file, "rb") as f:
                loaded = pickle.load(f)
        return cls(**loaded)