Source code for bob.bio.base.database.atnt

#!/usr/bin/env python
"""
ATNT database implementation
"""

from pathlib import Path

from sklearn.pipeline import make_pipeline

import bob.io.base

from bob.bio.base.database.utils import download_file, md5_hash

from . import CSVDatabase, FileSampleLoader


class AtntBioDatabase(CSVDatabase):
    """
    The AT&T (aka ORL) database of faces
    (http://www.cl.cam.ac.uk/research/dtg/attarchive/facedatabase.html). This
    class defines a simple protocol for training, enrollment and probe by
    splitting the few images of the database in a reasonable manner. Due to the
    small size of the database, there is only a 'dev' group, and I did not
    define an 'eval' group.
    """

    category = "base"
    dataset_protocols_urls = [
        "https://www.idiap.ch/software/bob/databases/latest/base/atnt-f529acef.tar.gz",
        "http://www.idiap.ch/software/bob/databases/latest/base/atnt-f529acef.tar.gz",
    ]
    dataset_protocols_checksum = "f529acef"

    def __init__(
        self,
        protocol="idiap_protocol",
        dataset_original_directory=None,
        **kwargs,
    ):
        """Custom init to download the raw data files."""

        # Download the raw data (or use a cached local version)
        if dataset_original_directory is None:
            path = download_file(
                urls=[
                    "http://www.idiap.ch/software/bob/data/bob/att_faces.zip"
                ],
                destination_sub_directory="datasets/atnt",
                destination_filename="atnt_faces.zip",
                checksum="6efb25cb0d40755e9492b9c012e3348d",
                checksum_fct=md5_hash,
                extract=True,
            )
            dataset_original_directory = Path(path).as_posix()

        super().__init__(
            name="atnt",
            protocol=protocol,
            transformer=make_pipeline(
                FileSampleLoader(
                    data_loader=bob.io.base.load,
                    dataset_original_directory=dataset_original_directory,
                    extension=".pgm",
                ),
            ),
            **kwargs,
        )
        # just expose original_directory for backward compatibility of tests
        self.original_directory = dataset_original_directory
        self.original_extension = ".pgm"

    # define an objects method for compatibility with the old tests
[docs] def objects( self, model_ids=None, groups=None, purposes=None, protocol=None ): samples = [] if groups is None: groups = self.groups() if "train" in groups or "world" in groups: samples += self.background_model_samples() if purposes is None: purposes = ("enroll", "probe") if "enroll" in purposes: samples += self.references() if "probe" in purposes: samples += self.probes() if model_ids: samples = [s for s in samples if s.template_id in model_ids] # create the old attributes for s in samples: s.client_id = s.template_id s.path = s.id = s.key return samples