Coverage for src/bob/pad/face/config/lbp_svm.py: 0%

30 statements  

« prev     ^ index     » next       coverage.py v7.6.5, created at 2024-11-14 23:14 +0100

1import dask_ml.model_selection as dcv 

2 

3from sklearn.model_selection import StratifiedGroupKFold 

4from sklearn.pipeline import Pipeline 

5from sklearn.svm import SVC 

6 

7from bob.bio.face.annotator import MTCNN 

8from bob.bio.face.preprocessor import INormLBP 

9from bob.bio.face.utils import make_cropper, pad_default_cropping 

10from bob.pad.face.transformer import VideoToFrames 

11from bob.pad.face.transformer.histogram import SpatialHistogram 

12from bob.pipelines.wrappers import SampleWrapper 

13 

14 

def _init_pipeline(database, crop_size=(112, 112), grid_size=(3, 3)):
    """Build the LBP + SVM face presentation-attack-detection pipeline.

    The returned pipeline chains, in order: video-to-frames conversion,
    MTCNN-based face cropping, LBP image extraction, spatial histogram
    pooling, and an RBF SVC whose ``C``/``gamma`` are chosen by grid search.

    Parameters
    ----------
    database
        Database object. Only its ``annotation_type`` and ``fixed_positions``
        attributes are read here, to configure the face cropper.
    crop_size : tuple
        Output size of the cropped face image. It is forwarded to both
        ``pad_default_cropping`` and ``make_cropper``.
    grid_size : tuple
        Cell grid passed to ``SpatialHistogram``.

    Returns
    -------
    sklearn.pipeline.Pipeline
        Unfitted pipeline with steps ``video2frames``, ``cropper``, ``lbp``,
        ``spatial_histogram`` and ``classifier``.
    """
    # --- Face detection and cropping ------------------------------------
    face_detector = MTCNN(thresholds=(0.1, 0.2, 0.2))
    cropper_spec = make_cropper(
        cropped_image_size=crop_size,
        cropped_positions=pad_default_cropping(
            crop_size, database.annotation_type
        ),
        fixed_positions=database.fixed_positions,
        color_channel="rgb",
        annotator=face_detector,
    )
    # make_cropper's result is indexed as (transformer, extra transform args).
    face_cropper = SampleWrapper(
        cropper_spec[0],
        transform_extra_arguments=cropper_spec[1],
        delayed_output=False,
    )

    # --- LBP images ------------------------------------------------------
    # Cropping is done by the preceding "cropper" step, so INormLBP is not
    # given a cropper of its own.
    lbp_extractor = SampleWrapper(
        INormLBP(face_cropper=None, color_channel="gray"),
        delayed_output=False,
    )

    # --- Spatial histograms ----------------------------------------------
    histogram = SampleWrapper(
        SpatialHistogram(grid_size=grid_size, nbins=256),
        delayed_output=False,
    )

    # --- Classifier ------------------------------------------------------
    estimator = SVC()
    # Log2-spaced search: C = 2^-3 .. 2^13, gamma = 2^-15 .. 2^-1 (odd exponents).
    search_space = {
        "C": [2**exp for exp in range(-3, 14, 2)],
        "gamma": [2**exp for exp in range(-15, 0, 2)],
        "kernel": ["rbf"],
    }
    grid_search = dcv.GridSearchCV(
        estimator,
        param_grid=[search_space],
        cv=StratifiedGroupKFold(n_splits=3),
    )
    # At fit time, y comes from each sample's ``is_bonafide`` attribute and
    # groups from its ``video_key``. StratifiedGroupKFold keeps samples that
    # share a group in the same fold.
    classifier = SampleWrapper(
        grid_search,
        delayed_output=False,
        fit_extra_arguments=[("y", "is_bonafide"), ("groups", "video_key")],
    )

    # --- Full pipeline ---------------------------------------------------
    steps = [
        ("video2frames", VideoToFrames()),
        ("cropper", face_cropper),
        ("lbp", lbp_extractor),
        ("spatial_histogram", histogram),
        ("classifier", classifier),
    ]
    return Pipeline(steps)

74 

75 

# Get database information, needed for the face cropper.
# ``database`` is not defined in this module. Presumably it is injected into
# the module namespace by the config loader (e.g. from a previously loaded
# config resource) — confirm against the loader.
# Use ``.get`` so that an undefined ``database`` raises the explicit
# ValueError below instead of an uninformative KeyError.
db = globals().get("database")
if db is None:
    raise ValueError("Missing database!")

# Pipeline
pipeline = _init_pipeline(database=db)