Coverage for src/bob/bio/spear/transformer/preprocessing.py: 100%
13 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-06 22:04 +0100
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-06 22:04 +0100
1import numpy as np
3from sklearn.preprocessing import OrdinalEncoder
class ReferenceIdEncoder(OrdinalEncoder):
    """Ordinal encoder that maps identifier values to integer codes.

    This is a thin layer over ``sklearn.preprocessing.OrdinalEncoder`` that:

    * changes the constructor defaults (integer output, and unknown values
      encoded with ``unknown_value`` instead of raising);
    * accepts a flat sequence of identifiers and reshapes it to the
      ``(N, 1)`` column layout the base class works on;
    * returns a flat 1-D array from :meth:`transform`.
    """

    # The defaults below deliberately differ from those of the base class.
    def __init__(
        self,
        *,
        categories="auto",
        dtype=int,
        handle_unknown="use_encoded_value",
        unknown_value=-1,
        **kwargs,
    ):
        """Forward all options to ``OrdinalEncoder`` with adjusted defaults.

        Parameters
        ----------
        categories, dtype, handle_unknown, unknown_value
            Same meaning as in ``OrdinalEncoder``; only the defaults differ.
        **kwargs
            Any additional keyword argument is handed to the base class
            unchanged.
        """
        overridden_defaults = {
            "categories": categories,
            "dtype": dtype,
            "handle_unknown": handle_unknown,
            "unknown_value": unknown_value,
        }
        super().__init__(**overridden_defaults, **kwargs)

    @staticmethod
    def _as_column(values):
        """Return ``values`` as a 2-D array with a single column."""
        return np.asarray(values).reshape(-1, 1)

    def fit(self, X, y=None):
        """Fit the encoder on the identifiers in ``X``.

        ``X`` is expected to be a SampleBatch or a list of identifier
        strings (per the original author's note); it is converted to an
        array of shape ``(N, 1)`` before being given to the base class.
        ``y`` is accepted for API compatibility and is not used.

        Returns whatever the base-class ``fit`` returns.
        """
        column = self._as_column(X)
        return super().fit(column)

    def transform(self, X):
        """Encode the identifiers in ``X`` and return a flat array.

        The input is reshaped to ``(N, 1)`` for the base class, and the
        resulting 2-D output is flattened to one dimension.
        """
        column = self._as_column(X)
        encoded = super().transform(column)
        # Callers get a flat array rather than a single-column matrix.
        return encoded.flatten()

    def _more_tags(self):
        """Return the extra estimator tags naming the input/output fields.

        NOTE(review): the ``bob_*`` keys are presumably read by the bob
        pipeline wrappers to select sample attributes -- confirm with caller.
        """
        tags = {}
        tags["bob_input"] = "subject_id"
        tags["bob_output"] = "subject_id_int"
        return tags