Source code for bob.learn.pytorch.datasets.casia_webface

#!/usr/bin/env python
# encoding: utf-8

import os
import numpy

from torch.utils.data import Dataset, DataLoader

import bob.io.base
import bob.io.image

from .utils import map_labels


class CasiaWebFaceDataset(Dataset):
    """Class representing the CASIA WebFace dataset

    Note that here the only label is the identity.

    Attributes
    ----------
    root_dir : str
        The path to the data
    transform : `torchvision.transforms`
        The transform(s) to apply to the face images
    data_files : list of :obj:`str`
        The list of data files
    id_labels : list of :obj:`int`
        The list of identities, for each data file

    """

    def __init__(self, root_dir, transform=None, start_index=0):
        """Init function

        Parameters
        ----------
        root_dir : str
            The path to the data
        transform : :py:class:`torchvision.transforms`
            The transform(s) to apply to the face images
        start_index : int
            Label of the first identity (useful if you use several databases)

        """
        self.root_dir = root_dir
        self.transform = transform

        self.data_files = []
        id_labels = []

        # the expected layout is <root_dir>/<integer identity>/<image files>
        for root, dirs, files in os.walk(self.root_dir):
            for name in files:
                path = root.split(os.sep)
                subject = int(path[-1])
                self.data_files.append(os.path.join(root, name))
                id_labels.append(subject)

        self.id_labels = map_labels(id_labels, start_index)

    def __len__(self):
        """Returns the length of the dataset (i.e. the number of examples)

        Returns
        -------
        int
            The number of examples in the dataset

        """
        return len(self.data_files)

    def __getitem__(self, idx):
        """Returns a sample from the dataset

        Returns
        -------
        dict
            An example of the dataset, containing the transformed face image
            and its identity

        """
        image = bob.io.base.load(self.data_files[idx])
        identity = self.id_labels[idx]
        sample = {"image": image, "label": identity}

        if self.transform:
            sample = self.transform(sample)

        return sample
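A minimal usage sketch follows; the directory path and batch size are placeholders, not part of the package. The constructor assumes one sub-directory per identity, named with the integer identity label and containing that subject's face images.

from torch.utils.data import DataLoader
from bob.learn.pytorch.datasets.casia_webface import CasiaWebFaceDataset

# placeholder path: expects <root_dir>/<integer identity>/<image files>
dataset = CasiaWebFaceDataset(root_dir="/path/to/casia-webface")
print(len(dataset))        # number of face images found by os.walk

sample = dataset[0]        # dict with keys "image" and "label"
print(sample["label"])     # identity label remapped by map_labels

# standard PyTorch batching; in practice a transform should first turn every
# image into a fixed-size tensor so the default collate function can stack them
loader = DataLoader(dataset, batch_size=64, shuffle=True)

Since the constructor converts each sub-directory name with int(), a non-numeric directory under root_dir will raise a ValueError.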
class CasiaDataset(Dataset):
    """Class representing the CASIA WebFace dataset

    Note that in this class, two labels are provided with each image:
    identity and pose.

    Pose labels have been automatically inferred using the ROC face
    recognition SDK from RankOne. There are 13 pose labels, corresponding
    to clusters of 15 degrees, ranging from -90 degrees (left profile)
    to 90 degrees (right profile).

    Attributes
    ----------
    root_dir : str
        The path to the data
    transform : `torchvision.transforms`
        The transform(s) to apply to the face images
    data_files : list of :obj:`str`
        The list of data files
    id_labels : list of :obj:`int`
        The list of identities, for each data file
    pose_labels : list of :obj:`int`
        The list containing the pose labels

    """

    def __init__(self, root_dir, transform=None, start_index=0):
        """Init function

        Parameters
        ----------
        root_dir : str
            The path to the data
        transform : :py:class:`torchvision.transforms`
            The transform(s) to apply to the face images
        start_index : int
            Label of the first identity (useful if you use several databases)

        """
        self.root_dir = root_dir
        self.transform = transform

        dir_to_pose_label = {
            "l90": "0",
            "l75": "1",
            "l60": "2",
            "l45": "3",
            "l30": "4",
            "l15": "5",
            "0": "6",
            "r15": "7",
            "r30": "8",
            "r45": "9",
            "r60": "10",
            "r75": "11",
            "r90": "12",
        }

        # get all the needed files, the pose labels and the id labels;
        # the expected layout is <root_dir>/<pose cluster>/<integer identity>/<image files>
        self.data_files = []
        self.pose_labels = []
        id_labels = []

        for root, dirs, files in os.walk(self.root_dir):
            for name in files:
                path = root.split(os.sep)
                subject = int(path[-1])
                cluster = path[-2]
                self.data_files.append(os.path.join(root, name))
                self.pose_labels.append(int(dir_to_pose_label[cluster]))
                id_labels.append(subject)

        self.id_labels = map_labels(id_labels, start_index)

    def __len__(self):
        """Returns the length of the dataset (i.e. the number of examples)

        Returns
        -------
        int
            The number of examples in the dataset

        """
        return len(self.data_files)

    def __getitem__(self, idx):
        """Returns a sample from the dataset

        Returns
        -------
        dict
            An example of the dataset, containing the transformed face image,
            its identity and pose information

        """
        image = bob.io.base.load(self.data_files[idx])
        identity = self.id_labels[idx]
        pose = self.pose_labels[idx]
        sample = {"image": image, "label": identity, "pose": pose}

        if self.transform:
            sample = self.transform(sample)

        return sample
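Because both classes pass the whole sample dict to the transform, torchvision transforms that expect a bare image cannot be plugged in directly; a small dict-aware callable is needed. The sketch below is one possible wrapper, not part of the package: the class name ToTorchTensor and the path are illustrative, and it assumes the layout implied by the constructor, i.e. <root_dir>/<pose cluster>/<integer identity>/<image files>, with cluster directories named as in dir_to_pose_label.

import torch
from bob.learn.pytorch.datasets.casia_webface import CasiaDataset

class ToTorchTensor:
    """Hypothetical dict-aware transform: converts the loaded image (a
    numpy array, typically in Bob's planar (C, H, W) layout) to a float
    tensor and leaves the identity and pose labels untouched."""

    def __call__(self, sample):
        sample["image"] = torch.from_numpy(sample["image"]).float()
        return sample

# placeholder path
dataset = CasiaDataset(root_dir="/path/to/casia-webface-poses",
                       transform=ToTorchTensor())

sample = dataset[0]
print(sample["image"].shape)   # e.g. torch.Size([3, H, W])
print(sample["label"])         # identity, remapped by map_labels
print(sample["pose"])          # pose cluster index, 0 (l90) to 12 (r90)

The pose index comes straight from dir_to_pose_label, so the cluster directories must be named l90 ... l15, 0, r15 ... r90; any other name raises a KeyError.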