Coverage for src/bob/bio/base/pipelines/pipelines.py: 94%

63 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-12 22:34 +0200

1#!/usr/bin/env python 

2# vim: set fileencoding=utf-8 : 

3 

4""" 

5Implementation of the PipelineSimple using Dask :ref:`bob.bio.base.struct_bio_rec_sys`_ 

6 

7This file contains simple processing blocks meant to be used 

8for bob.bio experiments 

9""" 

10 

11import logging 

12 

13from sklearn.base import BaseEstimator 

14from sklearn.pipeline import Pipeline 

15 

16from bob.bio.base.pipelines.abstract_classes import BioAlgorithm 

17from bob.pipelines import SampleWrapper, is_instance_nested, wrap 

18 

19from .score_writers import FourColumnsScoreWriter 

20 

21logger = logging.getLogger(__name__) 

22import tempfile 

23 

24 

class PipelineSimple:
    """
    The simplest possible pipeline

    This is the backbone of most biometric recognition systems.
    It implements four subpipelines and they are the following:

     - :py:meth:`PipelineSimple.train_background_model`: Initializes or trains your transformer.
        It will run :py:meth:`sklearn.base.BaseEstimator.fit`

     - :py:meth:`PipelineSimple.enroll_templates`: Creates enrollment templates
        It will run :py:meth:`sklearn.base.BaseEstimator.transform` followed by a sequence of
        :py:meth:`bob.bio.base.pipelines.abstract_classes.BioAlgorithm.create_templates`

     - :py:meth:`PipelineSimple.probe_templates`: Creates probe templates
        It will run :py:meth:`sklearn.base.BaseEstimator.transform` followed by a sequence of
        :py:meth:`bob.bio.base.pipelines.abstract_classes.BioAlgorithm.create_templates`

     - :py:meth:`PipelineSimple.compute_scores`: Computes scores
        It will run :py:meth:`bob.bio.base.pipelines.abstract_classes.BioAlgorithm.compare`


    Example
    -------
       >>> from sklearn.preprocessing import FunctionTransformer
       >>> from sklearn.pipeline import make_pipeline
       >>> from bob.bio.base.algorithm import Distance
       >>> from bob.bio.base.pipelines import PipelineSimple
       >>> from bob.pipelines import wrap
       >>> import numpy
       >>> linearize = lambda samples: [numpy.reshape(x, (-1,)) for x in samples]
       >>> transformer = wrap(["sample"], FunctionTransformer(linearize))
       >>> transformer_pipeline = make_pipeline(transformer)
       >>> biometric_algorithm = Distance()
       >>> pipeline = PipelineSimple(transformer_pipeline, biometric_algorithm)
       >>> pipeline(samples_for_training_back_ground_model, samplesets_for_enroll, samplesets_for_scoring)  # doctest: +SKIP


    To run this pipeline using Dask, use the function
    :py:func:`dask_bio_pipeline`.

    Example
    -------
      >>> from bob.bio.base.pipelines import dask_bio_pipeline
      >>> pipeline = PipelineSimple(transformer_pipeline, biometric_algorithm)
      >>> pipeline = dask_bio_pipeline(pipeline)
      >>> pipeline(samples_for_training_back_ground_model, samplesets_for_enroll, samplesets_for_scoring).compute()  # doctest: +SKIP


    Parameters
    ----------

      transformer: :py:class:`sklearn.pipeline.Pipeline` or a `sklearn.base.BaseEstimator`
        Transformer that will preprocess your data

      biometric_algorithm: :py:class:`bob.bio.base.pipelines.abstract_classes.BioAlgorithm`
        Biometrics algorithm object. This pipeline calls its
        ``create_templates_from_samplesets`` and ``score_sample_templates``
        methods.

      score_writer: :any:`bob.bio.base.pipelines.ScoreWriter`
        Format to write scores. Defaults to a
        :any:`bob.bio.base.pipelines.FourColumnsScoreWriter` writing into a
        newly created temporary directory (which is not removed
        automatically).

    """

    def __init__(
        self,
        transformer: Pipeline,
        biometric_algorithm: BioAlgorithm,
        score_writer=None,
    ):
        self.transformer = transformer
        self.biometric_algorithm = biometric_algorithm
        self.score_writer = score_writer
        if self.score_writer is None:
            # Do NOT use ``tempfile.TemporaryDirectory()`` here while keeping
            # only its ``.name``: once that object is garbage collected (in
            # CPython, as soon as ``__init__`` returns) the directory is
            # deleted and a ResourceWarning is emitted, leaving the writer
            # with a dangling path. ``mkdtemp`` creates a directory that stays
            # valid for as long as the writer needs it.
            self.score_writer = FourColumnsScoreWriter(tempfile.mkdtemp())

        # Validates the transformer and the biometric algorithm; raises
        # ValueError on unsupported types.
        check_valid_pipeline(self)

    def __call__(
        self,
        background_model_samples,
        biometric_reference_samples,
        probe_samples,
        score_all_vs_all=True,
        return_templates=False,
    ):
        """
        Runs all the subpipelines in sequence: background model training,
        enroll template creation, probe template creation and scoring.

        Parameters
        ----------
        background_model_samples
            Samples used to fit the transformer (may be empty).
        biometric_reference_samples
            Samplesets used to create the enrollment templates.
        probe_samples
            Samplesets used to create the probe templates.
        score_all_vs_all: bool
            Forwarded to :py:meth:`compute_scores`.
        return_templates: bool
            If ``True``, also return the enroll and probe templates.

        Returns
        -------
        The scores, or ``(scores, enroll_templates, probe_templates)`` when
        ``return_templates`` is ``True``.
        """
        logger.info(" >> PipelineSimple: Training background model")
        self.train_background_model(background_model_samples)

        logger.info(" >> PipelineSimple: Creating enroll templates")
        enroll_templates = self.enroll_templates(biometric_reference_samples)

        logger.info(" >> PipelineSimple: Creating probe templates")
        probe_templates = self.probe_templates(probe_samples)

        logger.info(" >> PipelineSimple: Computing scores")
        scores = self.compute_scores(
            probe_templates,
            enroll_templates,
            score_all_vs_all,
        )

        if return_templates:
            return scores, enroll_templates, probe_templates
        return scores

    def train_background_model(self, background_model_samples):
        """
        Fits ``self.transformer`` on the background model samples.

        Parameters
        ----------
        background_model_samples
            A list of Samples. When it is ``None`` or empty, fitting is
            skipped and a warning is logged.

        Returns
        -------
        ``self.transformer`` (fitted, unless there was no data).
        """
        # We might have algorithms that have no data for training
        has_data = (
            background_model_samples is not None
            and len(background_model_samples) > 0
        )
        if has_data:
            self.transformer.fit(background_model_samples)
        else:
            logger.warning(
                "There's no data to train background model. "
                "For the rest of the execution it will be assumed that the pipeline does not require fit."
            )
        return self.transformer

    def enroll_templates(self, biometric_reference_samples):
        """
        Transforms the enrollment samplesets with ``self.transformer`` and
        builds templates from them with the biometric algorithm's
        ``create_templates_from_samplesets(..., enroll=True)``.

        Returns
        -------
        The enrollment templates (a list of Samples).
        """
        biometric_reference_features = self.transformer.transform(
            biometric_reference_samples
        )

        enroll_templates = (
            self.biometric_algorithm.create_templates_from_samplesets(
                biometric_reference_features, enroll=True
            )
        )

        # a list of Samples
        return enroll_templates

    def probe_templates(self, probe_samples):
        """
        Transforms the probe samplesets with ``self.transformer`` and builds
        templates from them with the biometric algorithm's
        ``create_templates_from_samplesets(..., enroll=False)``.

        Returns
        -------
        The probe templates (a list of Samples).
        """
        probe_features = self.transformer.transform(probe_samples)

        probe_templates = (
            self.biometric_algorithm.create_templates_from_samplesets(
                probe_features, enroll=False
            )
        )

        # a list of Samples
        return probe_templates

    def compute_scores(
        self,
        probe_templates,
        enroll_templates,
        score_all_vs_all,
    ):
        """
        Scores the probe templates against the enroll templates by delegating
        to the biometric algorithm's ``score_sample_templates``.

        ``score_all_vs_all`` is forwarded unchanged to that method.
        """
        return self.biometric_algorithm.score_sample_templates(
            probe_templates, enroll_templates, score_all_vs_all
        )

    def write_scores(self, scores):
        """
        Writes ``scores`` with ``self.score_writer`` and returns its result.

        Raises
        ------
        ValueError
            If ``self.score_writer`` has been set to ``None``.
        """
        if self.score_writer is None:
            raise ValueError("No score writer defined in the pipeline")
        return self.score_writer.write(scores)

    def post_process(self, score_paths, filename):
        """
        Delegates to ``self.score_writer.post_process(score_paths, filename)``.

        Raises
        ------
        ValueError
            If ``self.score_writer`` has been set to ``None``.
        """
        if self.score_writer is None:
            raise ValueError("No score writer defined in the pipeline")

        return self.score_writer.post_process(score_paths, filename)

193 

194 

def check_valid_pipeline(pipeline_simple):
    """
    Applying some checks in the PipelineSimple

    Checks that ``pipeline_simple.transformer`` is either a
    :py:class:`sklearn.pipeline.Pipeline` or an object that is (directly or
    through nested ``estimator`` attributes) a
    :py:class:`sklearn.base.BaseEstimator`. It also checks that
    ``pipeline_simple.biometric_algorithm`` is a
    :py:class:`bob.bio.base.pipelines.abstract_classes.BioAlgorithm`.

    Estimators that are not wrapped (directly or through nested ``estimator``
    attributes) in a :py:class:`bob.pipelines.SampleWrapper` are reported with
    a warning. The transformer itself is never modified. Wrapping it
    automatically could break estimators that already handle samples
    themselves, and the previous ``wrap(...)`` calls discarded their result
    anyway.

    Parameters
    ----------
    pipeline_simple: PipelineSimple
        The pipeline to validate

    Returns
    -------
    bool
        ``True`` when all checks pass

    Raises
    ------
    ValueError
        If the transformer or the biometric algorithm has an unsupported type
    """
    transformer = pipeline_simple.transformer

    # CHECKING THE TRANSFORMER
    if isinstance(transformer, Pipeline):
        # Iterating over a scikit-learn Pipeline yields its step estimators.
        # ``None`` and strings (e.g. "passthrough") are placeholders that need
        # no sample wrapping.
        estimators = [
            step
            for step in transformer
            if step is not None and not isinstance(step, str)
        ]
    elif is_instance_nested(transformer, "estimator", BaseEstimator):
        # In this case it is a single (possibly wrapped) estimator
        estimators = [transformer]
    else:
        raise ValueError(
            "pipeline_simple.transformer should be instance of either `sklearn.pipeline.Pipeline` or "
            f"sklearn.base.BaseEstimator, not {transformer}"
        )

    for estimator in estimators:
        if not is_instance_nested(estimator, "estimator", SampleWrapper):
            logger.warning(
                "%s is not wrapped with `bob.pipelines.SampleWrapper`. If it "
                "does not handle `Sample` objects itself, wrap it with "
                "`bob.pipelines.wrap(['sample'], estimator)`.",
                estimator,
            )

    # CHECKING THE BIOMETRIC ALGORITHM
    if not isinstance(pipeline_simple.biometric_algorithm, BioAlgorithm):
        raise ValueError(
            "pipeline_simple.biometric_algorithm should be instance of `BioAlgorithm`, "
            f"not {pipeline_simple.biometric_algorithm}"
        )

    return True