#!/usr/bin/env python
# encoding: utf-8
import os
import numpy
from torch.utils.data import Dataset, DataLoader
import bob.io.base
import bob.io.image
from .utils import map_labels
class CasiaWebFaceDataset(Dataset):
    """Class representing the CASIA WebFace dataset.

    Note that here the only label is identity.

    Attributes
    ----------
    root_dir : str
        The path to the data
    transform : :py:class:`torchvision.transforms`
        The transform(s) to apply to the face images
    data_files : list of :obj:`str`
        The list of data files
    id_labels : list of :obj:`int`
        The list of identities, for each data file
    """

    def __init__(self, root_dir, transform=None, start_index=0):
        """Init function

        Parameters
        ----------
        root_dir : str
            The path to the data
        transform : :py:class:`torchvision.transforms`
            The transform(s) to apply to the face images
        start_index : int
            label of the first identity (useful if you use
            several databases)
        """
        self.root_dir = root_dir
        self.transform = transform
        self.data_files = []
        id_labels = []
        # Expected layout: <root_dir>/<subject_id>/<image_file>, where the
        # immediate parent directory name is an integer subject id
        # (int() below raises ValueError otherwise) — TODO confirm on data.
        for root, _dirs, files in os.walk(self.root_dir):
            for name in files:
                subject = int(root.split(os.sep)[-1])
                self.data_files.append(os.path.join(root, name))
                id_labels.append(subject)
        # Remap raw subject ids to consecutive labels starting at start_index
        self.id_labels = map_labels(id_labels, start_index)

    def __len__(self):
        """Returns the length of the dataset (i.e. nb of examples)

        Returns
        -------
        int
            the number of examples in the dataset
        """
        return len(self.data_files)

    def __getitem__(self, idx):
        """Returns a sample from the dataset

        Parameters
        ----------
        idx : int
            index of the example to retrieve

        Returns
        -------
        dict
            an example of the dataset, containing the
            transformed face image and its identity
        """
        image = bob.io.base.load(self.data_files[idx])
        identity = self.id_labels[idx]
        sample = {'image': image, 'label': identity}
        if self.transform:
            sample = self.transform(sample)
        return sample
class CasiaDataset(Dataset):
    """Class representing the CASIA WebFace dataset.

    Note that in this class, two labels are provided
    with each image: identity and pose.

    Pose labels have been automatically inferred using
    the ROC face recognition SDK from RankOne.

    There are 13 pose labels, corresponding to clusters
    of 15 degrees, ranging from -90 degrees (left profile)
    to 90 degrees (right profile).

    Attributes
    ----------
    root_dir : str
        The path to the data
    transform : :py:class:`torchvision.transforms`
        The transform(s) to apply to the face images
    data_files : list of :obj:`str`
        The list of data files
    id_labels : list of :obj:`int`
        The list of identities, for each data file
    pose_labels : list of :obj:`int`
        The list containing the pose labels
    """

    def __init__(self, root_dir, transform=None, start_index=0):
        """Init function

        Parameters
        ----------
        root_dir : str
            The path to the data
        transform : :py:class:`torchvision.transforms`
            The transform(s) to apply to the face images
        start_index : int
            label of the first identity (useful if you use
            several databases)
        """
        self.root_dir = root_dir
        self.transform = transform
        # Pose cluster directory name -> integer pose label (0..12),
        # from left profile (l90) through frontal (0) to right profile (r90).
        dir_to_pose_label = {'l90': 0,
                             'l75': 1,
                             'l60': 2,
                             'l45': 3,
                             'l30': 4,
                             'l15': 5,
                             '0': 6,
                             'r15': 7,
                             'r30': 8,
                             'r45': 9,
                             'r60': 10,
                             'r75': 11,
                             'r90': 12,
                             }
        # get all the needed files, the pose labels, and the id labels.
        # Expected layout: <root_dir>/.../<pose_cluster>/<subject_id>/<image>
        # — TODO confirm on data.
        self.data_files = []
        self.pose_labels = []
        id_labels = []
        for root, _dirs, files in os.walk(self.root_dir):
            for name in files:
                path = root.split(os.sep)
                subject = int(path[-1])
                cluster = path[-2]
                self.data_files.append(os.path.join(root, name))
                self.pose_labels.append(dir_to_pose_label[cluster])
                id_labels.append(subject)
        # Remap raw subject ids to consecutive labels starting at start_index
        self.id_labels = map_labels(id_labels, start_index)

    def __len__(self):
        """Returns the length of the dataset (i.e. nb of examples)

        Returns
        -------
        int
            the number of examples in the dataset
        """
        return len(self.data_files)

    def __getitem__(self, idx):
        """Returns a sample from the dataset

        Parameters
        ----------
        idx : int
            index of the example to retrieve

        Returns
        -------
        dict
            an example of the dataset, containing the
            transformed face image, its identity and pose information
        """
        image = bob.io.base.load(self.data_files[idx])
        identity = self.id_labels[idx]
        pose = self.pose_labels[idx]
        sample = {'image': image, 'label': identity, 'pose': pose}
        if self.transform:
            sample = self.transform(sample)
        return sample