Coverage for src/bob/fusion/base/algorithm/Algorithm.py: 96%
47 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-13 01:00 +0200
1#!/usr/bin/env python
3from __future__ import absolute_import, division
5import logging
6import pickle
8import numpy as np
# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(name=__name__)
class Algorithm(object):
    """A class to be used in score fusion.

    By default, training and fusion are delegated to ``classifier``
    (``fit`` in :meth:`train`, ``decision_function`` in :meth:`fuse`), and
    the optional ``preprocessors`` chain can be trained and applied with
    :meth:`train_preprocessors` and :meth:`preprocess`.

    Attributes
    ----------
    classifier
        The object whose ``fit`` / ``decision_function`` are called by
        :meth:`train` / :meth:`fuse`.
    preprocessors : list or None
        Optional chain of preprocessors, applied in order.
    str : dict
        A dictionary whose content will be printed in the ``__str__`` method.
        Entries whose value is ``None`` are omitted from the output.
    """

    def __init__(self, preprocessors=None, classifier=None, **kwargs):
        """
        Parameters
        ----------
        preprocessors : :any:`list`
            An optional list of preprocessors that follow the API of
            :any:`sklearn.preprocessing.StandardScaler`. Especially
            `fit_transform` and `transform` must be implemented.
        classifier
            An instance of a class that implements `fit(X[, y])` and
            `decision_function(X)` like:
            :any:`sklearn.linear_model.LogisticRegression`
        **kwargs
            All extra keyword arguments are forwarded unchanged to
            ``super().__init__``.
        """
        super(Algorithm, self).__init__(**kwargs)
        self.classifier = classifier
        self.preprocessors = preprocessors
        self.str = {"preprocessors": preprocessors}
        # A subclass may pass itself as the classifier. Listing it in
        # self.str would make __str__ format self again and recurse forever.
        if classifier is not self:
            self.str["classifier"] = classifier

    def train_preprocessors(self, X, y=None):
        """Train preprocessors in order.

        Each preprocessor is fitted with ``fit_transform`` on the output of
        the previous one, so later preprocessors are trained on already
        transformed data. The final transformed array is not returned.

        X: numpy.ndarray with the shape of (n_samples, n_systems).
        y: optional labels, passed to every ``fit_transform`` call.
        """
        if self.preprocessors is not None:
            for preprocessor in self.preprocessors:
                X = preprocessor.fit_transform(X, y)

    def preprocess(self, scores):
        """Apply the (already trained) preprocessors in order.

        scores: numpy.ndarray with the shape of (n_samples, n_systems).
        returns the transformed scores. An empty array (``size == 0``) is
        returned as-is, without calling any preprocessor.
        """
        if scores.size == 0:
            return scores
        if self.preprocessors is not None:
            for preprocessor in self.preprocessors:
                scores = preprocessor.transform(scores)
        return scores

    def train(self, train_neg, train_pos, devel_neg=None, devel_pos=None):
        """Fit the classifier on stacked negative and positive scores.

        If you use development data for training you need to override this
        method: this implementation ignores ``devel_neg`` and ``devel_pos``.
        It also does not call :meth:`train_preprocessors` or
        :meth:`preprocess` itself.

        train_neg: numpy.ndarray
            Negatives training data should be numpy.ndarray with the shape of
            (n_samples, n_systems). Labelled ``False``.
        train_pos: numpy.ndarray
            Positives training data should be numpy.ndarray with the shape of
            (n_samples, n_systems). Labelled ``True``.
        devel_neg, devel_pos: numpy.ndarray
            Same as ``train`` but used for development (validation).
        """
        train_scores = np.vstack((train_neg, train_pos))
        neg_len = train_neg.shape[0]
        # Boolean labels: the first neg_len rows are negatives, the rest positives.
        y = np.zeros((train_scores.shape[0],), dtype="bool")
        y[neg_len:] = True
        self.classifier.fit(train_scores, y)

    def fuse(self, scores):
        """
        scores: numpy.ndarray
            A numpy.ndarray with the shape of (n_samples, n_systems).

        **Returns:**

        fused_score: numpy.ndarray
            The fused scores in shape of (n_samples,), as returned by
            ``classifier.decision_function``.
        """
        return self.classifier.decision_function(scores)

    def __str__(self):
        """Return all parameters of this class (and its derived class) in string.

        Only entries of ``self.str`` whose value is not ``None`` are shown.

        **Returns:**

        info: str
            A string containing the full information of all parameters of this
            (and the derived) class.
        """
        return "%s(%s)" % (
            str(self.__class__),
            ", ".join(
                [
                    "%s=%s" % (key, value)
                    for key, value in self.str.items()
                    if value is not None
                ]
            ),
        )

    def save(self, model_file):
        """Save the instance of the algorithm.

        model_file: str
            A path to save the file. Please note that file objects
            are not accepted. The filename MUST end with ".pkl".
            Also, an algorithm may save itself in multiple files with different
            extensions such as model_file and model_file[:-3]+'hdf5'.
        """
        # support for bob machines
        if hasattr(self, "custom_save"):
            self.custom_save(model_file)
        else:
            with open(model_file, "wb") as f:
                # The class is pickled first so that load() can instantiate it
                # before deciding how to read the rest of the file.
                pickle.dump(type(self), f)
                pickle.dump(self, f)

    def load(self, model_file):
        """Load the algorithm the same way it was saved.
        A new instance will be returned.

        Classes that define ``custom_save`` must also override ``load``.
        This base implementation delegates to that override.

        **Returns:**

        loaded_algorithm: Algorithm
            A new instance of the loaded algorithm.

        **Raises:**

        NotImplementedError
            If the saved class defines ``custom_save`` but does not override
            ``load``. Delegating in that case would only call this method
            again and recurse without end.
        """
        # SECURITY: pickle.load can execute arbitrary code. Only load model
        # files that come from a trusted source.
        with open(model_file, "rb") as f:
            algo_class = pickle.load(f)
            algo = algo_class()
            if not hasattr(algo, "custom_save"):
                return pickle.load(f)
        # Delegate only after our handle is closed, because the custom loader
        # opens the same file again.
        if type(algo).load is Algorithm.load:
            raise NotImplementedError(
                "%s defines custom_save but does not override load; "
                "cannot load %r" % (algo_class.__name__, model_file)
            )
        return algo.load(model_file)