Coverage for src/bob/bio/spear/transformer/path_to_audio.py: 93%

30 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-06 22:04 +0100

1#!/usr/bin/env python 

2# @author: Yannick Dayer <yannick.dayer@idiap.ch> 

3# @date: Thu 01 Jul 2021 10:41:55 UTC+02 

4 

5import logging 

6 

7from functools import partial 

8from typing import Optional 

9 

10import numpy 

11 

12from sklearn.base import BaseEstimator, TransformerMixin 

13 

14from bob.bio.spear.audio_processing import read as read_audio 

15from bob.pipelines import DelayedSample 

16 

17logger = logging.getLogger(__name__) 

18 

19 

def get_audio_sample_rate(path: str, forced_sr: Optional[int] = None) -> int:
    """Returns the sample rate to associate with the audio file at ``path``.

    Parameters
    ----------
    path:
        Location of the audio file.
    forced_sr:
        When not None, this value is returned directly and the file is not
        touched at all (note: any non-None value, including 0, short-circuits).

    Returns
    -------
    The forced rate, or otherwise the second element of what ``read_audio``
    returns for the file (read with no channel selection and no resampling).
    """
    if forced_sr is not None:
        return forced_sr
    # NOTE(review): this reads the whole file only to obtain its rate.
    read_result = read_audio(path, None, None)
    return read_result[1]

25 

26 

def get_audio_data(
    path: str,
    channel: Optional[int] = None,
    forced_sr: Optional[int] = None,
) -> numpy.ndarray:
    """Returns the audio data from the given path.

    Parameters
    ----------
    path:
        Location of the audio file.
    channel:
        Channel index forwarded unchanged to ``read_audio``.
    forced_sr:
        Target sample rate forwarded unchanged to ``read_audio``.

    Returns
    -------
    The first element of the value returned by ``read_audio``.
    """
    read_result = read_audio(path, channel, forced_sr)
    return read_result[0]

34 

35 

class PathToAudio(BaseEstimator, TransformerMixin):
    """Transforms a Sample's data containing a path to an audio signal.

    Every input sample is wrapped in a ``DelayedSample`` (with the input as its
    parent). Its data is loaded lazily, through ``get_audio_data``, from the
    path stored in the input's ``data``. The Sample's metadata are updated
    (rate): ``rate`` is attached as a delayed attribute computed by
    ``get_audio_sample_rate``.

    Note:
        audio processing functions expect int16 audio (range [-32768, 32767]), but in
        float format. Hence the loading as int16 and the cast to float. (values will be
        in the range [-32768.0, 32767.0])
        This class performs no dtype conversion itself. Reading (and any
        conversion) is delegated to ``bob.bio.spear.audio_processing.read``.
    """

    def __init__(
        self,
        forced_channel: Optional[int] = None,
        forced_sr: Optional[int] = None,
    ) -> None:
        """
        Parameters
        ----------
        forced_channel:
            Forces the loading of a specific channel for each audio file, if the samples
            don't have a ``channel`` attribute. If None and the samples don't have a
            ``channel`` attribute, all the channels will be loaded in a 2D array.
        forced_sr:
            If not None, every sample rate will be forced to this value (resampling if
            needed).
        """
        super().__init__()
        self.forced_channel = forced_channel
        self.forced_sr = forced_sr

    def _wrap_sample(self, sample):
        """Build the lazily-loading ``DelayedSample`` for a single input sample."""
        # A ``channel`` attribute on the sample (even if None) takes precedence
        # over ``forced_channel``.
        requested_channel = getattr(sample, "channel", self.forced_channel)
        loader = partial(
            get_audio_data,
            sample.data,
            None if requested_channel is None else int(requested_channel),
            self.forced_sr,
        )
        rate_loader = partial(get_audio_sample_rate, sample.data, self.forced_sr)
        return DelayedSample(
            load=loader,
            parent=sample,
            delayed_attributes={"rate": rate_loader},
        )

    def transform(self, samples: list) -> list:
        """Return one lazily-loading ``DelayedSample`` per input sample, in order."""
        return [self._wrap_sample(sample) for sample in samples]

    def fit(self, X, y=None):
        """No-op: nothing is learned. Returns ``self``."""
        return self

    def _more_tags(self):
        # Estimator tags: this transformer keeps no state and needs no fitting.
        return dict(stateless=True, requires_fit=False)