Coverage for src/bob/pad/face/deep_pix_bis.py: 91%
89 statements
« prev ^ index » next coverage.py v7.6.5, created at 2024-11-14 23:14 +0100
« prev ^ index » next coverage.py v7.6.5, created at 2024-11-14 23:14 +0100
1import logging
3import numpy as np
4import torch
5import torchvision.transforms as vision_transforms
7from sklearn.base import BaseEstimator, ClassifierMixin
8from torch import nn
9from torchvision import models
11from bob.bio.base.database.utils import download_file
12from bob.io.image import to_matplotlib
14logger = logging.getLogger(__name__)
# Every pre-trained checkpoint lives under the same remote directory; only the
# file name differs per protocol.
_DEEP_PIX_BIS_BASE_URL = "http://www.idiap.ch/software/bob/data/bob/bob.pad.face/"

# (model key, checkpoint file name) pairs, in the order they are exposed.
# The suffix after the last "-" in each file name is the checksum that the
# classifier passes to ``download_file``.
_DEEP_PIX_BIS_WEIGHT_FILES = (
    ("oulu-npu-p1", "deep_pix_bis_OULU_Protocol_1_model_0_0-24844429.pth"),
    ("oulu-npu-p2", "deep_pix_bis_OULU_Protocol_2_model_0_0-4aae2f3a.pth"),
    ("oulu-npu-p3-1", "deep_pix_bis_OULU_Protocol_3_1_model_0_0-f0e70cf3.pth"),
    ("oulu-npu-p3-2", "deep_pix_bis_OULU_Protocol_3_2_model_0_0-92594797.pth"),
    ("oulu-npu-p3-3", "deep_pix_bis_OULU_Protocol_3_3_model_0_0-71e18149.pth"),
    ("oulu-npu-p3-4", "deep_pix_bis_OULU_Protocol_3_4_model_0_0-d7f666e5.pth"),
    ("oulu-npu-p3-5", "deep_pix_bis_OULU_Protocol_3_5_model_0_0-fc40ba69.pth"),
    ("oulu-npu-p3-6", "deep_pix_bis_OULU_Protocol_3_6_model_0_0-123a6c92.pth"),
    ("oulu-npu-p4-1", "deep_pix_bis_OULU_Protocol_4_1_model_0_0-5f8dc7cf.pth"),
    ("oulu-npu-p4-2", "deep_pix_bis_OULU_Protocol_4_2_model_0_0-168f2644.pth"),
    ("oulu-npu-p4-3", "deep_pix_bis_OULU_Protocol_4_3_model_0_0-db57e3b5.pth"),
    ("oulu-npu-p4-4", "deep_pix_bis_OULU_Protocol_4_4_model_0_0-e999b7e8.pth"),
    ("oulu-npu-p4-5", "deep_pix_bis_OULU_Protocol_4_5_model_0_0-dcd13b8b.pth"),
    ("oulu-npu-p4-6", "deep_pix_bis_OULU_Protocol_4_6_model_0_0-96a1ab92.pth"),
    ("replay-mobile", "deep_pix_bis_RM_grandtest_model_0_0-6761ca7e.pth"),
)

DEEP_PIX_BIS_PRETRAINED_MODELS = {
    model_key: [_DEEP_PIX_BIS_BASE_URL + weight_file]
    for model_key, weight_file in _DEEP_PIX_BIS_WEIGHT_FILES
}
"A dictionary with the url paths to pre-trained weights of the DeepPixBis model."
class DeepPixBiS(nn.Module):
    """Deep Pixel-wise Binary Supervision network for face presentation
    attack detection.

    Reference: Anjith George and Sébastien Marcel. "Deep Pixel-wise Binary Supervision for
    Face Presentation Attack Detection." In 2019 International Conference on Biometrics (ICB). IEEE, 2019.

    Parameters
    ----------
    pretrained: bool
        If set to `True` uses the pretrained DenseNet model as the base. If set to `False`, the network
        will be trained from scratch.

    Attributes
    ----------
    enc: :py:class:`torch.nn.Sequential`
        The first eight child modules of the DenseNet-161 feature extractor.
    dec: :py:class:`torch.nn.Conv2d`
        1x1 convolution reducing the 384 encoder channels to a single-channel map.
    linear: :py:class:`torch.nn.Linear`
        Maps the flattened 14x14 map to a single score.
    """

    def __init__(self, pretrained=True, **kwargs):
        """
        Parameters
        ----------
        pretrained: bool
            If set to `True` uses the pretrained densenet model as the base. Else, it uses the default network
        **kwargs
            Forwarded unchanged to :py:class:`torch.nn.Module`.
        """
        super().__init__(**kwargs)

        # NOTE(review): recent torchvision releases deprecate ``pretrained=`` in
        # favour of ``weights=``; kept as-is to stay compatible with the
        # torchvision version this package targets -- confirm before changing.
        dense = models.densenet161(pretrained=pretrained)

        features = list(dense.features.children())

        # Only the first eight feature modules are used as the encoder.
        self.enc = nn.Sequential(*features[0:8])

        self.dec = nn.Conv2d(384, 1, kernel_size=1, padding=0)

        self.linear = nn.Linear(14 * 14, 1)

    def forward(self, x):
        """Propagate data through the network

        Parameters
        ----------
        x: :py:class:`torch.Tensor`
            The data to forward through the network. Expects RGB image of size 3x224x224

        Returns
        -------
        dec: :py:class:`torch.Tensor`
            Binary map of size 1x14x14
        op: :py:class:`torch.Tensor`
            Final binary score.

        Raises
        ------
        RuntimeError
            If the pixel map does not flatten to 14*14 values per sample
            (i.e. the input does not have the expected spatial size).
        """
        enc = self.enc(x)

        dec = torch.sigmoid(self.dec(enc))

        # Flatten per sample instead of ``view(-1, 14 * 14)``: with an
        # unexpected input size the latter silently folds the extra pixels
        # into the batch dimension, whereas this makes the linear layer fail
        # loudly on the shape mismatch.
        dec_flat = dec.flatten(start_dim=1)

        op = torch.sigmoid(self.linear(dec_flat))

        return dec, op
class DeepPixBisClassifier(BaseEstimator, ClassifierMixin):
    """The class implementing the DeepPixBiS score computation

    The network weights are loaded lazily on the first call to
    :py:meth:`predict_proba` (through :py:meth:`load_model`) and are dropped
    when the estimator is pickled.
    """

    def __init__(
        self,
        model_file=None,
        transforms=None,
        scoring_method="pixel_mean",
        device=None,
        threshold=0.8,
        **kwargs,
    ):
        """Init method

        Parameters
        ----------
        model_file: str
            The path of the trained PAD network to load or one of the keys to :py:attr:`DEEP_PIX_BIS_PRETRAINED_MODELS`
        transforms: :py:mod:`torchvision.transforms`
            Transform to be applied on the image. When ``None``, ``ToTensor``
            followed by a fixed mean/std ``Normalize`` is used.
        scoring_method: str
            The scoring method to be used to get the final score,
            available methods are ['pixel_mean','binary','combined'].
            Matching is case-insensitive.
        device:
            Device handed to ``.to(...)`` for the model and the inputs. When
            ``None`` it is chosen when the model is placed on a device
            ("cuda" if available, else "cpu").
        threshold: float
            The threshold to be used to binarize the output of the DeepPixBiS model.
            This is not used in the normal bob.pad.base pipeline.

        Raises
        ------
        ValueError
            If ``scoring_method`` is not one of the supported methods.
        """
        super().__init__(**kwargs)

        if transforms is None:
            transforms = vision_transforms.Compose(
                [
                    vision_transforms.ToTensor(),
                    vision_transforms.Normalize(
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                    ),
                ]
            )

        self.transforms = transforms
        # The network itself is created on demand by ``load_model``.
        self.model = None
        self.scoring_method = scoring_method.lower()
        if self.scoring_method not in ("pixel_mean", "binary", "combined"):
            raise ValueError(
                "Scoring method {} is not implemented.".format(
                    self.scoring_method
                )
            )
        self.device = device
        self.threshold = threshold

        logger.debug("Scoring method is : %s", self.scoring_method.upper())

        if model_file in DEEP_PIX_BIS_PRETRAINED_MODELS:
            # A known key: fetch the checkpoint and use the local path instead.
            model_urls = DEEP_PIX_BIS_PRETRAINED_MODELS[model_file]
            filename = model_urls[0].split("/")[-1]
            # The checksum is embedded in the file name: "<name>-<hash>.pth".
            file_hash = filename.split("-")[-1].split(".")[0]
            model_file = download_file(
                urls=model_urls,
                destination_filename=filename,
                destination_sub_directory="models",
                checksum=file_hash,
                extract=False,
            )

        logger.debug("Using pretrained model %s", model_file)
        self.model_file = model_file

    def load_model(self):
        """Load the network weights from ``self.model_file`` if not loaded yet.

        Does nothing when ``self.model`` is already set.

        Raises
        ------
        ValueError
            If no model is loaded and ``self.model_file`` is ``None``.
        """
        if self.model is not None:
            return

        if self.model_file is None:
            raise ValueError(
                "model_file is not set; provide a checkpoint path or one of "
                "the keys of DEEP_PIX_BIS_PRETRAINED_MODELS."
            )

        # Keep the deserialized storages where they are; the model is moved
        # to the target device afterwards by ``place_model_on_device``.
        cp = torch.load(
            self.model_file, map_location=lambda storage, loc: storage
        )

        self.model = DeepPixBiS(pretrained=False)
        self.model.load_state_dict(cp["state_dict"])
        self.place_model_on_device()
        self.model.eval()
        logger.debug("Loaded the pretrained PAD model")

    def predict_proba(self, images):
        """Scores face images for PAD

        Parameters
        ----------
        images : iterable of 3D :py:class:`numpy.ndarray`
            The images to extract the scores from. The size of each image
            must be 3x224x224;

        Returns
        -------
        output : :py:class:`numpy.ndarray`
            One score per input image. The output score is close to 1 for
            bonafide and 0 for PAs.
        """
        self.load_model()

        with torch.no_grad():
            tensor_images = [
                self.transforms(to_matplotlib(img)) for img in images
            ]
            batch = torch.stack(tensor_images).to(self.device)
            # Call the module (not ``.forward``) so registered hooks run.
            pixel_maps, binary_scores = self.model(batch)

        output_pixel = pixel_maps.cpu().detach().numpy().mean(axis=(1, 2, 3))
        output_binary = binary_scores.cpu().detach().numpy().mean(axis=1)

        score = {
            "pixel_mean": output_pixel,
            "binary": output_binary,
            "combined": (output_binary + output_pixel) / 2,
        }[self.scoring_method]

        logger.debug("Scores: %s", score)
        return score

    def predict(self, X):
        """Binarize the scores of ``X`` with ``self.threshold``.

        Returns
        -------
        :py:class:`numpy.ndarray`
            Integer array with one entry per sample: 1 where the score is
            strictly above the threshold, 0 otherwise.
        """
        scores = self.predict_proba(X)
        # ``np.int`` no longer exists in current NumPy and could never convert
        # a batch of more than one score; cast element-wise instead.
        return (scores > self.threshold).astype(int)

    def fit(self, X, y=None):
        """No training required for this model"""
        return self

    def __getstate__(self):
        """Return the instance state with the loaded network removed."""
        # Handling unpicklable objects
        d = self.__dict__.copy()
        d["model"] = None
        return d

    def _more_tags(self):
        """Extra estimator tags: this estimator does not need fitting."""
        return {"requires_fit": False}

    def place_model_on_device(self):
        """Pick a default device if none was given and move the model to it."""
        if self.device is None:
            self.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )
        if self.model is not None:
            self.model.to(self.device)