# Source code for bob.ip.pytorch_extractor.MultiNetPatchExtractor
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: Olegs Nikisins
"""
# =============================================================================
# Import what is needed here:
from bob.bio.base.extractor import Extractor
import numpy as np
from bob.ip.pytorch_extractor.utils import reshape_flat_patches
from bob.ip.pytorch_extractor.utils import transform_and_net_forward
from bob.ip.pytorch_extractor.utils import load_pretrained_model
# =============================================================================
# Main body:
class MultiNetPatchExtractor(Extractor, object):
    """
    This class is designed to pass a set of patches through a possibly multiple
    networks and compute a feature vector combining outputs of all networks.

    The functional work-flow is the following:

    First, an array of **flattened** input patches is converted to a list
    of patches with original dimensions (2D or 3D arrays).

    Second, each patch is passed through an individual network, for example
    an auto-encoder pre-trained for each patch type (left eye, for example).

    Third, outputs of all networks are concatenated into a single feature
    vector.

    Attributes
    -----------
    transform : object
        Function namely ``transform``, which is a Compose transformation of
        torchvision package, to be applied to the input samples.

    network : object
        An instance of the Network to be used for feature extraction.
        Note: in current extractor the ``forward()`` method of the Network
        is used for feature extraction. For example, if you want to use the
        latent embeddings of the autoencoder class, initialize the network
        accordingly.

    model_file : [str]
        A list of paths to the model files to be used for ``network``
        initialization.

    patches_num : [int]
        A list of indices specifying which patches will be selected for
        processing/feature vector extraction.

    patch_reshape_parameters : [int] or None
        The parameters to be used for patch reshaping. The loaded patch is
        vectorized. Example:
        ``patch_reshape_parameters = [4, 8, 8]``, then the patch of the
        size (256,) will be reshaped to (4,8,8) dimensions. Only 2D and 3D
        patches are supported.
        Default: None.

    color_input_flag : bool
        If set to ``True``, the input is considered to be a color image of the
        size ``(3, H, W)``. The tensor to be passed through the net will be
        of the size ``(1, 3, H, W)``.
        If set to ``False``, the input is considered to be a set of BW images
        of the size ``(n_samples, H, W)``. The tensor to be passed through
        the net will be of the size ``(n_samples, 1, H, W)``.
        Default: ``False``.

    urls : [str]
        List of URLs to download the pretrained models from.
        If models are not available in the locations specified in the
        ``model_file`` list, the system will try to download them from
        ``urls``. The downloaded models **will be placed to the locations**
        specified in ``model_file`` list.
        For example, a model for an autoencoder pre-trained on
        RGB faces of the size (3(channels) x 128 x 128) and fine-tuned
        on the BW-NIR-D data can be found here:
        ["https://www.idiap.ch/software/bob/data/bob/bob.ip.pytorch_extractor/master/"
        "conv_ae_model_pretrain_celeba_tune_batl_full_face.pth.tar.gz"]
        Default: None

    archive_extension : str
        Extension of the archived files to download from above ``urls``.
        Default: '.tar.gz'
    """

    # =========================================================================
    def __init__(self,
                 transform,
                 network,
                 model_file,
                 patches_num,
                 patch_reshape_parameters=None,
                 color_input_flag=False,
                 urls=None,
                 archive_extension='.tar.gz'):
        """
        Init method. See the class docstring for the meaning of the arguments.
        """
        # Forward all settings to the bob Extractor base class so they are
        # recorded as part of the extractor configuration.
        super(MultiNetPatchExtractor, self).__init__(
            transform=transform,
            network=network,
            model_file=model_file,
            patches_num=patches_num,
            patch_reshape_parameters=patch_reshape_parameters,
            color_input_flag=color_input_flag,
            urls=urls,
            archive_extension=archive_extension)

        self.transform = transform
        self.network = network
        self.model_file = model_file
        self.patches_num = patches_num
        self.patch_reshape_parameters = patch_reshape_parameters
        self.color_input_flag = color_input_flag
        self.urls = urls
        self.archive_extension = archive_extension

    # =========================================================================
    def __call__(self, patches):
        """
        Extract features combining outputs of multiple networks.

        Parameters
        -----------
        patches : 2D :py:class:`numpy.ndarray`
            An array containing flattened patches. The dimensions are:
            ``num_patches x len_of_flat_patch``

        Returns
        --------
        features : :py:class:`numpy.ndarray`
            Feature vector.
        """
        # select patches specified by indices
        # (list, to make it iterable even for a single index):
        patches_selected = [patches[idx] for idx in self.patches_num]

        # convert flattened patches back to their original 2D/3D dimensions:
        patches_3d = reshape_flat_patches(patches_selected,
                                          self.patch_reshape_parameters)

        # Use local fall-back lists instead of mutating self.model_file /
        # self.urls, so calling the extractor leaves its configuration intact:
        model_file = self.model_file
        if model_file is None:
            model_file = [None] * len(self.patches_num)

        urls = self.urls
        if urls is None:
            urls = [None] * len(self.patches_num)

        features_all_patches = []

        for idx, patch in enumerate(patches_3d):
            # try to download/load the model if not available,
            # do nothing if available:
            load_pretrained_model(
                model_path=model_file[self.patches_num[idx]],
                url=urls[self.patches_num[idx]],
                archive_extension=self.archive_extension)

            # A single model file means all patches share one network;
            # otherwise each patch gets its own model:
            if len(model_file) == 1:
                current_model_file = model_file[0]
            else:
                current_model_file = model_file[idx]

            features = transform_and_net_forward(
                feature=patch,
                transform=self.transform,
                network=self.network,
                model_file=current_model_file,
                color_input_flag=self.color_input_flag)

            features_all_patches.append(features)

        # concatenate per-patch outputs into a single feature vector:
        features = np.hstack(features_all_patches)

        return features