Coverage for src/bob/bio/spear/transformer/path_to_audio.py: 93%
30 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-06 22:04 +0100
1#!/usr/bin/env python
2# @author: Yannick Dayer <yannick.dayer@idiap.ch>
3# @date: Thu 01 Jul 2021 10:41:55 UTC+02
5import logging
7from functools import partial
8from typing import Optional
10import numpy
12from sklearn.base import BaseEstimator, TransformerMixin
14from bob.bio.spear.audio_processing import read as read_audio
15from bob.pipelines import DelayedSample
17logger = logging.getLogger(__name__)
def get_audio_sample_rate(path: str, forced_sr: Optional[int] = None) -> int:
    """Give the sample rate to associate with the audio file at ``path``.

    Parameters
    ----------
    path:
        Location of the audio file; only used when ``forced_sr`` is None.
    forced_sr:
        If not None, this value is returned as-is and the file is not read.

    Returns
    -------
    ``forced_sr`` when it is set, otherwise the second element of the result of
    ``read_audio(path, None, None)``.
    """
    if forced_sr is not None:
        return forced_sr
    # No rate was imposed: ask the reader, passing None for its two other
    # arguments, and keep the second element of what it returns.
    return read_audio(path, None, None)[1]
def get_audio_data(
    path: str,
    channel: Optional[int] = None,
    forced_sr: Optional[int] = None,
) -> numpy.ndarray:
    """Load the audio signal stored at ``path``.

    ``path``, ``channel`` and ``forced_sr`` are forwarded unchanged to
    ``read_audio``; only the first element of its result is returned.
    """
    loaded = read_audio(path, channel, forced_sr)
    return loaded[0]
class PathToAudio(BaseEstimator, TransformerMixin):
    """Turns Samples holding an audio file path into lazily-loaded audio Samples.

    Each output is a ``DelayedSample`` whose parent is the input sample and which
    carries a delayed ``rate`` attribute.

    Note:
        The audio processing functions expect int16 audio (range
        [-32768, 32767]) stored in float format. Hence the audio is loaded as
        int16 and cast to float, giving values in the range
        [-32768.0, 32767.0].
    """

    def __init__(
        self,
        forced_channel: Optional[int] = None,
        forced_sr: Optional[int] = None,
    ) -> None:
        """
        Parameters
        ----------
        forced_channel:
            Channel to load for samples that have no ``channel`` attribute. If
            None and a sample has no ``channel`` attribute, all the channels will
            be loaded in a 2D array.
        forced_sr:
            If not None, every sample rate will be forced to this value
            (resampling if needed).
        """
        super().__init__()
        self.forced_channel = forced_channel
        self.forced_sr = forced_sr

    def _to_delayed(self, sample):
        """Wrap one path-holding sample into a ``DelayedSample``."""
        # The sample's own ``channel`` attribute wins; ``forced_channel`` is only
        # the fallback when that attribute is absent.
        channel = getattr(sample, "channel", self.forced_channel)
        if channel is not None:
            channel = int(channel)

        audio_loader = partial(
            get_audio_data, sample.data, channel, self.forced_sr
        )
        rate_loader = partial(
            get_audio_sample_rate, sample.data, self.forced_sr
        )
        return DelayedSample(
            load=audio_loader,
            parent=sample,
            delayed_attributes={"rate": rate_loader},
        )

    def transform(self, samples: list) -> list:
        """Return one ``DelayedSample`` per input sample, in the same order.

        Only ``partial`` objects are created here; ``get_audio_data`` and
        ``get_audio_sample_rate`` are not called by this method.
        """
        return [self._to_delayed(sample) for sample in samples]

    def fit(self, X, y=None):
        """Do nothing and return ``self``."""
        return self

    def _more_tags(self):
        """Declare the transformer as stateless and not requiring a fit."""
        tags = {
            "stateless": True,
            "requires_fit": False,
        }
        return tags