Coverage for src/bob/pad/face/config/lbp_svm.py: 0%
30 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-13 01:19 +0200
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-13 01:19 +0200
1import dask_ml.model_selection as dcv
3from sklearn.model_selection import StratifiedGroupKFold
4from sklearn.pipeline import Pipeline
5from sklearn.svm import SVC
7from bob.bio.face.annotator import MTCNN
8from bob.bio.face.preprocessor import INormLBP
9from bob.bio.face.utils import make_cropper, pad_default_cropping
10from bob.pad.face.transformer import VideoToFrames
11from bob.pad.face.transformer.histogram import SpatialHistogram
12from bob.pipelines.wrappers import SampleWrapper
def _init_pipeline(database, crop_size=(112, 112), grid_size=(3, 3)):
    """Build the LBP + SVM presentation-attack-detection pipeline.

    Parameters
    ----------
    database
        Database object; its ``annotation_type`` and ``fixed_positions``
        attributes configure the face cropper.
    crop_size
        Size of the cropped face image, forwarded to ``pad_default_cropping``
        and ``make_cropper``.
    grid_size
        Grid used by ``SpatialHistogram`` to compute the per-cell LBP
        histograms.

    Returns
    -------
    sklearn.pipeline.Pipeline
        Steps, in order: ``video2frames``, ``cropper``, ``lbp``,
        ``spatial_histogram`` and ``classifier``.
    """
    # Face Crop
    # --------------------------
    # NOTE(review): the thresholds are presumably lowered to make MTCNN
    # detection more permissive — confirm the intent.
    annotator = MTCNN(thresholds=(0.1, 0.2, 0.2))
    cropped_pos = pad_default_cropping(crop_size, database.annotation_type)
    # make_cropper returns the cropping transformer together with the
    # extra-arguments spec that SampleWrapper must forward to transform().
    cropper_transformer, cropper_extra_args = make_cropper(
        cropped_image_size=crop_size,
        cropped_positions=cropped_pos,
        fixed_positions=database.fixed_positions,
        color_channel="rgb",
        annotator=annotator,
    )
    face_cropper = SampleWrapper(
        cropper_transformer,
        transform_extra_arguments=cropper_extra_args,
        delayed_output=False,
    )

    # Extract LBP
    # --------------------------
    # Faces are already cropped by the previous step, so no cropper here.
    lbp_extractor = INormLBP(face_cropper=None, color_channel="gray")
    lbp_extractor = SampleWrapper(lbp_extractor, delayed_output=False)

    # Histogram
    # --------------------------
    histo = SpatialHistogram(grid_size=grid_size, nbins=256)
    histo = SampleWrapper(histo, delayed_output=False)

    # Classifier
    # --------------------------
    # RBF-SVM grid search: C in 2^-3 .. 2^13, gamma in 2^-15 .. 2^-1
    # (odd exponents only).
    sk_classifier = SVC()
    param_grid = [
        {
            "C": [2**p for p in range(-3, 14, 2)],
            "gamma": [2**p for p in range(-15, 0, 2)],
            "kernel": ["rbf"],
        }
    ]
    # Group folds by video so that frames from one video never end up in
    # both the training and the validation split of the same fold.
    cv = StratifiedGroupKFold(n_splits=3)
    sk_classifier = dcv.GridSearchCV(
        sk_classifier, param_grid=param_grid, cv=cv
    )
    # Labels and CV groups are read from the samples' ``is_bonafide`` and
    # ``video_key`` attributes when the wrapped estimator is fitted.
    fit_extra_arguments = [("y", "is_bonafide"), ("groups", "video_key")]
    classifier = SampleWrapper(
        sk_classifier,
        delayed_output=False,
        fit_extra_arguments=fit_extra_arguments,
    )

    # Full Pipeline
    # --------------------------
    return Pipeline(
        [
            ("video2frames", VideoToFrames()),
            ("cropper", face_cropper),
            ("lbp", lbp_extractor),
            ("spatial_histogram", histo),
            ("classifier", classifier),
        ]
    )
# Get database information, needed for the face cropper.
# NOTE(review): ``database`` is presumably injected into this config's
# namespace by a previously loaded database config — not visible here.
# ``.get`` makes a missing name fall through to the explicit error below
# instead of surfacing as a bare KeyError.
db = globals().get("database")
if db is None:
    raise ValueError("Missing database!")
# Pipeline #
pipeline = _init_pipeline(database=db)