#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: Olegs Nikisins
"""

# =============================================================================
# Import what is needed here:
from bob.bio.base.extractor import Extractor

import numpy as np

from bob.ip.pytorch_extractor.utils import reshape_flat_patches
from bob.ip.pytorch_extractor.utils import transform_and_net_forward
from bob.ip.pytorch_extractor.utils import load_pretrained_model

# =============================================================================
# Main body:
class MultiNetPatchExtractor(Extractor, object):
    """
    This class is designed to pass a set of patches through possibly multiple
    networks and to compute a feature vector combining the outputs of all
    networks.

    The functional work-flow is as follows:

    First, an array of **flattened** input patches is converted to a list
    of patches with their original dimensions (2D or 3D arrays); a short
    sketch of this step is given after this work-flow description.

    Second, each patch is passed through an individual network, e.g. an
    auto-encoder pre-trained for the corresponding patch type (the left
    eye, for example).

    Third, outputs of all networks are concatenated into a single feature
    vector.
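
    For illustration, the first step is roughly equivalent to the following
    (a minimal sketch; the patch geometry is a placeholder):

    .. code-block:: python

        import numpy as np

        # two flattened patches of 256 values each:
        flat_patches = np.random.rand(2, 256)

        # restore the original dimensions, here (4, 8, 8):
        patches_3d = [p.reshape(4, 8, 8) for p in flat_patches]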

    Attributes
    ----------

    transform : object
        A function, typically a ``Compose`` transformation from the
        ``torchvision`` package, to be applied to the input samples.

    network : object
        An instance of the Network to be used for feature extraction.
        Note: in the current extractor the ``forward()`` method of the
        Network is used for feature extraction. For example, if you want
        to use the latent embeddings of the autoencoder class, initialize
        the network accordingly.

    model_file : [str]
        A list of paths to the model files to be used for ``network``
        initialization.

    patches_num : [int]
        A list of indices specifying which patches will be selected for
        processing / feature vector extraction.

    patch_reshape_parameters : [int] or None
        The parameters to be used for patch reshaping. The loaded patch is
        vectorized. For example, with ``patch_reshape_parameters = [4, 8, 8]``
        a patch of the size (256,) will be reshaped to (4, 8, 8). Only 2D
        and 3D patches are supported.
        Default: None.

    color_input_flag : bool
        If set to ``True``, the input is considered to be a color image of the
        size ``(3, H, W)``. The tensor to be passed through the net will be
        of the size ``(1, 3, H, W)``.
        If set to ``False``, the input is considered to be a set of BW images
        of the size ``(n_samples, H, W)``. The tensor to be passed through
        the net will be of the size ``(n_samples, 1, H, W)``.
        Default: ``False``.

    urls : [str]
        List of URLs to download the pretrained models from.
        If models are not available in the locations specified in the
        ``model_file`` list, the system will try to download them from
        ``urls``. The downloaded models **will be placed in the locations**
        specified in the ``model_file`` list.

        For example, a model for an autoencoder pre-trained on RGB faces
        of the size (3 channels x 128 x 128) and fine-tuned on the
        BW-NIR-D data can be found here:
        ["https://www.idiap.ch/software/bob/data/bob/bob.ip.pytorch_extractor/master/"
        "conv_ae_model_pretrain_celeba_tune_batl_full_face.pth.tar.gz"]

        Default: None

    archive_extension : str
        Extension of the archived files to download from above ``urls``.

        Default: '.tar.gz'
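
    Examples
    --------

    A minimal usage sketch (illustrative only: the network class, model
    path and patch geometry below are placeholders, not shipped defaults):

    .. code-block:: python

        from torchvision import transforms

        network = MyAutoencoder()  # hypothetical network instance

        extractor = MultiNetPatchExtractor(
            transform=transforms.Compose([transforms.ToTensor()]),
            network=network,
            model_file=["/path/to/model.pth"],  # placeholder path
            patches_num=[0],  # process the first patch only
            patch_reshape_parameters=[3, 128, 128],
            color_input_flag=True)

        # ``patches`` is a 2D array: num_patches x len_of_flat_patch
        features = extractor(patches)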
    """

    # =========================================================================
    def __init__(self,
                 transform,
                 network,
                 model_file,
                 patches_num,
                 patch_reshape_parameters=None,
                 color_input_flag=False,
                 urls=None,
                 archive_extension='.tar.gz'):
        """
        Init method.
        """

        super(MultiNetPatchExtractor, self).__init__(
            transform=transform,
            network=network,
            model_file=model_file,
            patches_num=patches_num,
            patch_reshape_parameters=patch_reshape_parameters,
            color_input_flag=color_input_flag,
            urls=urls,
            archive_extension=archive_extension)

        self.transform = transform
        self.network = network
        self.model_file = model_file
        self.patches_num = patches_num
        self.patch_reshape_parameters = patch_reshape_parameters
        self.color_input_flag = color_input_flag
        self.urls = urls
        self.archive_extension = archive_extension
# =========================================================================
    def __call__(self, patches):
        """
        Extract features combining the outputs of multiple networks.

        Parameters
        ----------
        patches : 2D :py:class:`numpy.ndarray`
            An array containing flattened patches. The dimensions are:
            ``num_patches x len_of_flat_patch``.

        Returns
        -------
        features : :py:class:`numpy.ndarray`
            Feature vector.
        """

        # select the patches specified by the indices in ``patches_num``:
        patches_selected = [patches[idx] for idx in self.patches_num]

        # convert the flattened patches back to their original dimensions:
        patches_3d = reshape_flat_patches(patches_selected,
                                          self.patch_reshape_parameters)

        # make sure ``model_file`` and ``urls`` are lists, not None:
        if self.model_file is None:
            self.model_file = [None] * len(self.patches_num)
        if self.urls is None:
            self.urls = [None] * len(self.patches_num)

        features_all_patches = []

        for idx, patch in enumerate(patches_3d):

            # if a single model is given, share it across all patches;
            # otherwise use the model matching the current patch:
            model_idx = 0 if len(self.model_file) == 1 else idx

            # download the model if it is not available locally;
            # do nothing if it is already in place:
            load_pretrained_model(model_path=self.model_file[model_idx],
                                  url=self.urls[model_idx],
                                  archive_extension=self.archive_extension)

            features = transform_and_net_forward(
                feature=patch,
                transform=self.transform,
                network=self.network,
                model_file=self.model_file[model_idx],
                color_input_flag=self.color_input_flag)

            features_all_patches.append(features)

        # concatenate the per-patch outputs into a single feature vector:
        features = np.hstack(features_all_patches)

        return features
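
# =============================================================================
# The block below is an illustrative, self-contained sketch (not part of the
# public API of this module): it reproduces the reshape -> per-patch forward
# -> hstack flow of ``__call__`` using a hypothetical identity network and
# random data, assuming per-patch outputs are flattened to 1D before
# concatenation.
if __name__ == "__main__":

    import torch

    class _IdentityNet(torch.nn.Module):
        """Hypothetical stand-in network: returns its input unchanged."""

        def forward(self, x):
            return x

    # two flattened 8x8 BW patches, as produced by a patch extractor:
    flat_patches = np.random.rand(2, 64).astype("float32")

    # step 1: restore the original patch dimensions (here 2D, 8x8):
    patches_2d = [p.reshape(8, 8) for p in flat_patches]

    # steps 2-3: pass each patch through the net, then concatenate:
    net = _IdentityNet()
    features_all_patches = []
    for patch in patches_2d:
        tensor = torch.from_numpy(patch)[None, None, :, :]  # (1, 1, H, W)
        with torch.no_grad():
            output = net(tensor)
        features_all_patches.append(output.numpy().flatten())

    features = np.hstack(features_all_patches)
    print("combined feature vector of length:", features.size)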