Source code for bob.learn.pytorch.architectures.MCCNN

#!/usr/bin/env python
# encoding: utf-8

import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import numpy as np

import pkg_resources
import bob.extension.download

import bob.io.base

from .utils import MaxFeatureMap
from .utils import group
from .utils import resblock

import logging

logger = logging.getLogger("bob.learn.pytorch")


class MCCNN(nn.Module):
    """The class defining the MCCNN

    This class implements the MCCNN for multi-channel PAD. Every input
    channel gets its own copy of the LightCNN (29 layers) feature extractor.
    All copies start from the same pretrained weights: they are loaded into
    channel 0 and then copied to the other channels. The per-channel
    embeddings are concatenated and scored by two newly added fully
    connected layers.

    Attributes
    ----------
    num_channels: int
        The number of channels present in the input
    use_sigmoid: bool
        Whether the final sigmoid is applied in eval mode (it is always
        applied in training mode).
    lcnn_layers: list
        The adaptable layers present in the base LightCNN model
    layer_dict: :py:class:`torch.nn.ModuleDict`
        Pytorch class containing the per-channel modules as a dictionary.
        Keys are ``"ch_<i>_<layer>"`` for every ``<layer>`` in
        ``lcnn_layers`` and every channel index ``<i>``.
    pool1, pool2, pool3, pool4: :py:class:`torch.nn.MaxPool2d`
        2x2 max-pooling layers shared by all channels (they hold no
        parameters).
    linear1fc, linear2fc: :py:class:`torch.nn.Linear`
        Newly added FC layers applied on the merged embeddings.
    """

    # Where the pretrained LightCNN checkpoint is downloaded from, and the
    # file name it is stored under inside ``get_mccnnpath()``.
    _LIGHTCNN_URL = "http://www.idiap.ch/software/bob/data/bob/bob.learn.pytorch/master/LightCNN_29Layers_checkpoint.pth.tar"
    _LIGHTCNN_FILENAME = "LightCNN_29Layers_checkpoint.pth.tar"

    def __init__(
        self,
        block=resblock,
        layers=(1, 2, 3, 4),
        num_channels=4,
        verbosity_level=2,
        use_sigmoid=True,
    ):
        """Init function

        Parameters
        ----------
        block: callable
            Block class replicated by ``_make_layer``; ``block(in_channels,
            out_channels)`` must return a :py:class:`torch.nn.Module`.
        layers: sequence of int
            Number of ``block`` copies in ``block1`` ... ``block4``
            (only the first four entries are used).
        num_channels: int
            The number of channels present in the input
        verbosity_level: int
            Level passed to ``setLevel`` of the ``bob.learn.pytorch`` logger.
        use_sigmoid: bool
            Whether to use sigmoid in eval phase. If set to `False` do not use
            sigmoid in eval phase. Training phase is not affected.
        """
        super(MCCNN, self).__init__()

        self.num_channels = num_channels
        self.use_sigmoid = use_sigmoid
        self.lcnn_layers = [
            "conv1",
            "block1",
            "group1",
            "block2",
            "group2",
            "block3",
            "group3",
            "block4",
            "group4",
            "fc",
        ]

        # NOTE: this changes the level of the package-wide logger, not of a
        # logger private to this instance. Standard logging levels are
        # multiples of 10 (DEBUG == 10), so the default of 2 lets every
        # message through.
        logger.setLevel(verbosity_level)

        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        # Newly added FC layers. They are created before the per-channel
        # layers so that parameter creation (and thus random
        # initialisation) happens in the same order as before.
        self.linear1fc = nn.Linear(256 * num_channels, 10)
        self.linear2fc = nn.Linear(10, 1)

        # Add the per-channel modules. ch_0 is the anchor: it receives the
        # pretrained weights, which are then copied to every other channel.
        module_dict = {}
        for i in range(self.num_channels):
            channel_layers = self._make_channel_layers(block, layers)
            for layer in self.lcnn_layers:
                module_dict["ch_{}_".format(i) + layer] = channel_layers[layer]
        self.layer_dict = nn.ModuleDict(module_dict)

        light_cnn_model_file = self._get_pretrained_model_file()

        # Loading the pretrained model for ch_0. strict=False because the
        # converted state dict only addresses ``layer_dict.ch_0_*`` modules:
        # the other channels and linear1fc/linear2fc are left as they are.
        self.load_state_dict(
            self.get_model_state_dict(light_cnn_model_file), strict=False
        )

        # Copy the anchor weights to all other channels (except channel 0).
        for layer in self.lcnn_layers:
            anchor_state = self.layer_dict["ch_0_" + layer].state_dict()
            for i in range(1, self.num_channels):
                self.layer_dict["ch_{}_".format(i) + layer].load_state_dict(
                    anchor_state
                )

    def _make_channel_layers(self, block, layers):
        """Build the LightCNN layers for one input channel.

        Returns a dict mapping each name of ``lcnn_layers`` to a freshly
        created module. Modules are created in ``lcnn_layers`` order.
        """
        channel_layers = {}
        channel_layers["conv1"] = MaxFeatureMap(1, 48, 5, 1, 2)
        channel_layers["block1"] = self._make_layer(block, layers[0], 48, 48)
        channel_layers["group1"] = group(48, 96, 3, 1, 1)
        channel_layers["block2"] = self._make_layer(block, layers[1], 96, 96)
        channel_layers["group2"] = group(96, 192, 3, 1, 1)
        channel_layers["block3"] = self._make_layer(block, layers[2], 192, 192)
        channel_layers["group3"] = group(192, 128, 3, 1, 1)
        channel_layers["block4"] = self._make_layer(block, layers[3], 128, 128)
        channel_layers["group4"] = group(128, 128, 3, 1, 1)
        # 8 * 8 spatial size after four 2x2 poolings of a 128x128 input.
        channel_layers["fc"] = MaxFeatureMap(8 * 8 * 128, 256, type=0)
        return channel_layers

    @classmethod
    def _get_pretrained_model_file(cls):
        """Return the local path of the LightCNN checkpoint.

        The checkpoint is downloaded from ``_LIGHTCNN_URL`` first if the file
        does not exist yet.
        """
        light_cnn_model_file = os.path.join(
            cls.get_mccnnpath(), cls._LIGHTCNN_FILENAME
        )
        logger.info("Light_cnn_model_file path: %s", light_cnn_model_file)

        if not os.path.exists(light_cnn_model_file):
            os.makedirs(os.path.dirname(light_cnn_model_file), exist_ok=True)
            logger.info("Downloading the LightCNN model")
            bob.extension.download.download_file(
                cls._LIGHTCNN_URL, light_cnn_model_file
            )
            logger.info(
                "Downloaded LightCNN model to location: %s", light_cnn_model_file
            )
        return light_cnn_model_file

    def _make_layer(self, block, num_blocks, in_channels, out_channels):
        """makes multiple copies of the same base module

        Parameters
        ----------
        block: :py:class:`torch.nn.Module`
            The base block to replicate
        num_blocks: int
            Number of copies of the block to be made
        in_channels: int
            Number of input channels for a block
        out_channels: int
            Number of output channels for a block

        Returns
        -------
        :py:class:`torch.nn.Sequential`
            The ``num_blocks`` freshly created blocks, chained.
        """
        return nn.Sequential(
            *[block(in_channels, out_channels) for _ in range(num_blocks)]
        )

    def forward(self, img):
        """Propagate data through the network

        Parameters
        ----------
        img: :py:class:`torch.Tensor`
            The data to forward through the network, indexed as
            (batch, channel, height, width). Image of size
            num_channelsx128x128.

        Returns
        -------
        output: :py:class:`torch.Tensor`
            score, one value per sample (shape ``(batch, 1)``)
        """
        embeddings = []
        for i in range(self.num_channels):
            prefix = "ch_{}_".format(i)
            layer = self.layer_dict

            # the image for the specific channel, kept 4-D: (batch, 1, H, W)
            x = img[:, i, :, :].unsqueeze(1)

            x = layer[prefix + "conv1"](x)
            x = self.pool1(x)

            x = layer[prefix + "block1"](x)
            x = layer[prefix + "group1"](x)
            x = self.pool2(x)

            x = layer[prefix + "block2"](x)
            x = layer[prefix + "group2"](x)
            x = self.pool3(x)

            x = layer[prefix + "block3"](x)
            x = layer[prefix + "group3"](x)
            x = layer[prefix + "block4"](x)
            x = layer[prefix + "group4"](x)
            x = self.pool4(x)

            x = x.view(x.size(0), -1)
            fc = layer[prefix + "fc"](x)
            fc = F.dropout(fc, training=self.training)
            embeddings.append(fc)

        merged = torch.cat(embeddings, 1)
        output = self.linear1fc(merged)
        output = torch.sigmoid(output)
        output = self.linear2fc(output)

        # The final sigmoid is always used in training; in eval mode it is
        # optional so raw scores can be returned.
        if self.training or self.use_sigmoid:
            output = torch.sigmoid(output)
        return output

    @staticmethod
    def get_mccnnpath():
        """Return the directory of the ``models`` resource of bob.learn.pytorch."""
        import pkg_resources

        return pkg_resources.resource_filename("bob.learn.pytorch", "models")

    def get_model_state_dict(self, pretrained_model_path):
        """The class to load pretrained LightCNN model

        Loads the checkpoint on CPU, and maps every key of its ``state_dict``
        onto the channel-0 modules of this network.

        Parameters
        ----------
        pretrained_model_path: str
            Absolute path to the LightCNN model file

        Returns
        -------
        new_state_dict: dict
            Dictionary with LightCNN weights, keyed as
            ``layer_dict.ch_0_<original key without "module.">``.
        """
        from collections import OrderedDict

        # map_location keeps every tensor on the storage it was deserialized
        # to (CPU), so no GPU is needed to load the checkpoint.
        checkpoint = torch.load(
            pretrained_model_path, map_location=lambda storage, loc: storage
        )
        state_dict = checkpoint["state_dict"]

        # Strip the `module.` prefix (present when the checkpoint was saved
        # from a wrapped model) only where it actually appears.
        prefix = "module."
        new_state_dict = OrderedDict()
        for key, value in state_dict.items():
            if key.startswith(prefix):
                key = key[len(prefix):]
            new_state_dict["layer_dict.ch_0_" + key] = value
        return new_state_dict