Coverage for src/bob/fusion/base/algorithm/Algorithm.py: 96%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-13 01:00 +0200

1#!/usr/bin/env python 

2 

3from __future__ import absolute_import, division 

4 

5import logging 

6import pickle 

7 

8import numpy as np 

9 

# Module-level logger named after this module (``__name__``); nothing in the
# visible part of this file calls it.
10logger = logging.getLogger(__name__) 

11 

12 

class Algorithm(object):
    """Base class for score-fusion algorithms.

    Attributes
    ----------
    classifier
        The object passed to the constructor. ``train`` calls its ``fit``
        method and ``fuse`` calls its ``decision_function`` method.
    preprocessors
        The list passed to the constructor (or ``None``). Its items are
        applied in order by ``train_preprocessors`` and ``preprocess``.
    str : dict
        A dictionary whose content is printed by the ``__str__`` method.
    """

    def __init__(self, preprocessors=None, classifier=None, **kwargs):
        """Store the preprocessors and the classifier.

        Parameters
        ----------
        preprocessors : :any:`list`
            An optional list of preprocessors that follow the API of
            :any:`sklearn.preprocessing.StandardScaler`. In particular,
            `fit_transform` and `transform` must be implemented.
        classifier
            An instance of a class that implements `fit(X[, y])` and
            `decision_function(X)`, for example
            :any:`sklearn.linear_model.LogisticRegression`.
        **kwargs
            All extra keyword arguments are forwarded to the next
            ``__init__`` in the method resolution order.
        """
        super(Algorithm, self).__init__(**kwargs)
        self.classifier = classifier
        self.preprocessors = preprocessors

        # The classifier is left out of the printable description when the
        # algorithm was handed itself as its own classifier.
        description = {"preprocessors": preprocessors}
        if classifier is not self:
            description["classifier"] = classifier
        self.str = description

    def train_preprocessors(self, X, y=None):
        """Fit the preprocessors one after another.

        Each preprocessor is fitted on the output of the previous one.
        Nothing is returned.

        X: numpy.ndarray with the shape of (n_samples, n_systems).
        y: optional second argument passed to every ``fit_transform`` call.
        """
        if self.preprocessors is None:
            return
        for stage in self.preprocessors:
            X = stage.fit_transform(X, y)

    def preprocess(self, scores):
        """Run ``scores`` through every preprocessor, in order.

        scores: numpy.ndarray with the shape of (n_samples, n_systems).
        returns the transformed scores. Empty input (``scores.size == 0``)
        is returned untouched, as is any input when there are no
        preprocessors.
        """
        nothing_to_transform = scores.size == 0
        if nothing_to_transform or self.preprocessors is None:
            return scores
        transformed = scores
        for stage in self.preprocessors:
            transformed = stage.transform(transformed)
        return transformed

    def train(self, train_neg, train_pos, devel_neg=None, devel_pos=None):
        """Fit the classifier on the stacked negative and positive scores.

        If you use development data for training you need to override this
        method; ``devel_neg`` and ``devel_pos`` are not used here.

        train_neg: numpy.ndarray
          Negatives training data should be numpy.ndarray with the shape of
          (n_samples, n_systems).
        train_pos: numpy.ndarray
          Positives training data should be numpy.ndarray with the shape of
          (n_samples, n_systems).
        devel_neg, devel_pos: numpy.ndarray
          Same as ``train`` but used for development (validation).
        """
        stacked = np.vstack((train_neg, train_pos))
        n_negatives = train_neg.shape[0]
        # Boolean labels: False for the leading negative rows, True for the
        # positive rows stacked after them.
        labels = np.arange(stacked.shape[0]) >= n_negatives
        self.classifier.fit(stacked, labels)

    def fuse(self, scores):
        """Fuse the scores with the classifier's ``decision_function``.

        scores: numpy.ndarray
          A numpy.ndarray with the shape of (n_samples, n_systems).

        **Returns:**

        fused_score: numpy.ndarray
          The fused scores in shape of (n_samples,).
        """
        fused_score = self.classifier.decision_function(scores)
        return fused_score

    def __str__(self):
        """Describe this instance as ``<class>(key=value, ...)``.

        The pairs come from ``self.str``; entries whose value is ``None``
        are omitted.

        **Returns:**

        info: str
          A string containing the information of all parameters of this
          (and the derived) class.
        """
        class_repr = str(self.__class__)
        shown = []
        for key, value in self.str.items():
            if value is None:
                continue
            shown.append("%s=%s" % (key, value))
        return "%s(%s)" % (class_repr, ", ".join(shown))

    def save(self, model_file):
        """Save the instance of the algorithm.

        model_file: str
          A path to save the file. File objects are not accepted. The
          original documentation asks for a name ending in ".pkl"; this
          method does not check it. An algorithm may also save itself in
          multiple files with different extensions such as model_file and
          model_file[:-3]+'hdf5'.
        """
        # Support for bob machines: an instance that provides
        # ``custom_save`` takes over the whole saving procedure.
        if hasattr(self, "custom_save"):
            self.custom_save(model_file)
            return
        # Default format: the class first, then the pickled instance.
        with open(model_file, "wb") as stream:
            pickle.dump(type(self), stream)
            pickle.dump(self, stream)

    def load(self, model_file):
        """Load the algorithm the same way it was saved.

        The class stored first in the file is instantiated without
        arguments. If that instance has no ``custom_save`` attribute, the
        pickled instance that follows is returned; otherwise loading is
        delegated to that instance's own ``load``.

        **Returns:**

        loaded_algorithm: Algorithm
          A new instance of the loaded algorithm.
        """
        # NOTE(review): unpickling can execute arbitrary code; only load
        # model files that come from a trusted source.
        with open(model_file, "rb") as stream:
            algo_class = pickle.load(stream)
            instance = algo_class()
            saved_with_pickle = not hasattr(instance, "custom_save")
            if saved_with_pickle:
                return pickle.load(stream)
        return instance.load(model_file)