Coverage for src/bob/bio/spear/transformer/preprocessing.py: 100%

13 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-06 22:04 +0100

1import numpy as np 

2 

3from sklearn.preprocessing import OrdinalEncoder 

4 

5 

class ReferenceIdEncoder(OrdinalEncoder):
    """Ordinal encoder for a single column of identifier values.

    Thin wrapper around ``OrdinalEncoder`` that:

    * uses different constructor defaults than the base class (integer
      output, and unknown categories encoded as ``-1``);
    * accepts a flat sequence as input and reshapes it to the ``(N, 1)``
      layout before delegating to the base class;
    * returns a flat array from ``transform`` instead of a 2d one.
    """

    def __init__(
        self,
        *,
        categories="auto",
        dtype=int,
        handle_unknown="use_encoded_value",
        unknown_value=-1,
        **kwargs,
    ):
        """Forward all options to ``OrdinalEncoder.__init__``.

        Only the defaults differ from the base class; any additional keyword
        argument is passed through untouched.
        """
        # NOTE(review): options supplied only through ``**kwargs`` might not
        # be visible to scikit-learn's get_params()/clone() machinery, which
        # inspects the ``__init__`` signature -- confirm before relying on it.
        overridden_defaults = dict(
            categories=categories,
            dtype=dtype,
            handle_unknown=handle_unknown,
            unknown_value=unknown_value,
        )
        super().__init__(**overridden_defaults, **kwargs)

    def fit(self, X, y=None):
        """Fit the encoder on a flat collection of identifiers.

        ``X`` is a SampleBatch or a list of template_id strings. ``y`` is
        accepted for API compatibility and is not forwarded to the base
        class. Returns whatever the base class ``fit`` returns.
        """
        # The base class works on 2d data: one sample per row, one column.
        as_column = np.asarray(X).reshape(-1, 1)
        return super().fit(as_column)

    def transform(self, X):
        """Encode a flat collection of identifiers and return a flat array."""
        as_column = np.asarray(X).reshape(-1, 1)
        encoded = super().transform(as_column)
        # Collapse the (N, 1) result of the base class back to one dimension.
        return encoded.flatten()

    def _more_tags(self):
        """Return the extra estimator tags declared by this class.

        Presumably consumed by the bob pipeline wrappers to select the sample
        attribute to read and the one to write -- verify against the caller.
        """
        return dict(bob_input="subject_id", bob_output="subject_id_int")