Source code for bob.learn.pytorch.trainers.GenericTrainer

#!/usr/bin/env python
# encoding: utf-8

import copy
import os
import time
import numpy as np
import torch
import torch.nn as nn
from .tflog import Logger

import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")


class GenericTrainer(object):
    """Class to train a generic NN; all the parameters are provided in configs

    Attributes
    ----------
    network: :py:class:`torch.nn.Module`
        The network to train
    optimizer: :py:class:`torch.optim.Optimizer`
        Optimizer object to be used. Initialized in the config file.
    device: str
        Device which will be used for training the model
    verbosity_level: int
        The level of verbosity output to stdout
    """

    def __init__(self, network, optimizer, compute_loss, learning_rate=0.0001,
                 device='cpu', verbosity_level=2, tf_logdir='tf_logs',
                 do_crossvalidation=False, save_interval=5):
        """Init function

        The layers to be adapted in the network are selected, and gradients
        are set to ``True`` for the layers which need to be adapted.

        Parameters
        ----------
        network: :py:class:`torch.nn.Module`
            The network to train
        optimizer: :py:class:`torch.optim.Optimizer`
            Optimizer object to be used. Initialized in the config file.
        compute_loss: callable
            Function computing the loss; called as
            ``compute_loss(network, img, labels, device)``
        learning_rate: float
            The learning rate (the optimizer itself is initialized in the
            config file)
        device: str
            Device which will be used for training the model
        verbosity_level: int
            The level of verbosity output to stdout
        tf_logdir: str
            Directory to write the tensorflow logs to
        do_crossvalidation: bool
            If set to ``True``, performs validation in each epoch and stores
            the best model based on validation loss.
        save_interval: int
            Save a model every ``save_interval`` epochs (the final model is
            always saved)
        """
        self.network = network
        self.optimizer = optimizer
        self.compute_loss = compute_loss
        self.device = device
        self.learning_rate = learning_rate
        self.save_interval = save_interval
        self.do_crossvalidation = do_crossvalidation

        if self.do_crossvalidation:
            phases = ['train', 'val']
        else:
            phases = ['train']
        self.phases = phases

        # move the network to the selected device
        self.network.to(self.device)

        bob.core.log.set_verbosity_level(logger, verbosity_level)
        self.tf_logger = Logger(tf_logdir)
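    # A minimal sketch of the ``compute_loss`` callable this trainer expects
    # (not part of the original class); its signature is inferred from the
    # call ``self.compute_loss(self.network, img, labels, self.device)`` in
    # ``train`` below, and the cross-entropy criterion is an illustrative
    # assumption:
    #
    #   def compute_loss(network, img, labels, device):
    #       output = network(img.to(device))
    #       return nn.CrossEntropyLoss()(output, labels.to(device))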
    def load_model(self, model_filename):
        """Loads an existing model

        Parameters
        ----------
        model_filename: str
            The filename of the model to load

        Returns
        -------
        start_epoch: int
            The epoch to start with
        start_iteration: int
            The iteration to start with
        losses: list(float)
            The list of losses from previous training
        """
        cp = torch.load(model_filename)
        self.network.load_state_dict(cp['state_dict'])
        start_epoch = cp['epoch']
        start_iter = cp['iteration']
        losses = cp['loss']
        return start_epoch, start_iter, losses
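    # Resuming is normally driven through ``train(model=...)`` below, which
    # calls this method internally; the checkpoint path here is hypothetical:
    #
    #   trainer.train(dataloader, n_epochs=25, output_dir='out',
    #                 model='out/model_5_0.pth')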
    def save_model(self, output_dir, epoch=0, iteration=0, losses=None):
        """Save the trained network

        Parameters
        ----------
        output_dir: str
            The directory to write the models to
        epoch: int
            the current epoch
        iteration: int
            the current (last) iteration
        losses: list(float)
            The list of losses since the beginning of training
        """
        saved_filename = 'model_{}_{}.pth'.format(epoch, iteration)
        saved_path = os.path.join(output_dir, saved_filename)
        logger.info('Saving model to {}'.format(saved_path))
        cp = {'epoch': epoch,
              'iteration': iteration,
              'loss': losses,
              'state_dict': self.network.cpu().state_dict()}
        torch.save(cp, saved_path)
        # the state dict is saved from the CPU; move the network back to the
        # training device afterwards
        self.network.to(self.device)
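    # Round-trip sketch (hypothetical path): the dict written by
    # ``save_model`` carries exactly the keys read back by ``load_model``:
    #
    #   cp = torch.load('out/model_10_0.pth')
    #   sorted(cp)  # -> ['epoch', 'iteration', 'loss', 'state_dict']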
    def train(self, dataloader, n_epochs=25, output_dir='out', model=None):
        """Performs the training.

        Parameters
        ----------
        dataloader: dict of :py:class:`torch.utils.data.DataLoader`
            The dataloaders for your data, indexed by phase name
            ('train' and, optionally, 'val')
        n_epochs: int
            The number of epochs you would like to train for
        output_dir: str
            The directory where you would like to save models
        model: str
            The path to a pretrained model file to start training from;
            this is the PAD model, not the LightCNN model
        """
        # if a model exists, load it
        if model is not None:
            start_epoch, start_iter, losses = self.load_model(model)
            logger.info('Starting training at epoch {}, iteration {} - '
                        'last loss value is {}'.format(start_epoch, start_iter, losses[-1]))
        else:
            start_epoch = 0
            start_iter = 0
            losses = []
            logger.info('Starting training from scratch')

        for name, param in self.network.named_parameters():
            if param.requires_grad:
                logger.info('Layer to be adapted from grad check : {}'.format(name))

        self.network.train(True)

        best_model_wts = copy.deepcopy(self.network.state_dict())
        best_loss = float("inf")

        # let's go
        for epoch in range(start_epoch, n_epochs):

            train_loss_history = []
            val_loss_history = []

            for phase in self.phases:

                if phase == 'train':
                    self.network.train()  # set model to training mode
                else:
                    self.network.eval()   # set model to evaluation mode

                for i, data in enumerate(dataloader[phase], 0):
                    if i >= start_iter:
                        start = time.time()

                        # get data from dataset
                        img, labels = data

                        self.optimizer.zero_grad()

                        # track gradients only in the training phase
                        with torch.set_grad_enabled(phase == 'train'):
                            loss = self.compute_loss(self.network, img, labels, self.device)

                            if phase == 'train':
                                loss.backward()
                                self.optimizer.step()
                                train_loss_history.append(loss.item())
                            else:
                                val_loss_history.append(loss.item())

                        end = time.time()
                        logger.info("[{}/{}][{}/{}] => Loss = {} (time spent: {}), Phase {}".format(
                            epoch, n_epochs, i, len(dataloader[phase]), loss.item(), (end - start), phase))
                        losses.append(loss.item())

            epoch_train_loss = np.mean(train_loss_history)
            logger.info("Train Loss : {} epoch : {}".format(epoch_train_loss, epoch))

            if self.do_crossvalidation:
                epoch_val_loss = np.mean(val_loss_history)
                logger.info("Val Loss : {} epoch : {}".format(epoch_val_loss, epoch))

                # keep a copy of the weights with the best validation loss
                if phase == 'val' and epoch_val_loss < best_loss:
                    logger.debug("New val loss : {} is better than old: {}, "
                                 "copying over the new weights".format(epoch_val_loss, best_loss))
                    best_loss = epoch_val_loss
                    best_model_wts = copy.deepcopy(self.network.state_dict())

            ######################################## <Logging> ###################################

            if self.do_crossvalidation:
                info = {'train_loss': epoch_train_loss, 'val_loss': epoch_val_loss}
            else:
                info = {'train_loss': epoch_train_loss}

            # scalar logs
            for tag, value in info.items():
                self.tf_logger.scalar_summary(tag, value, epoch + 1)

            # log values and gradients of the parameters (histogram summary)
            for tag, value in self.network.named_parameters():
                tag = tag.replace('.', '/')
                try:
                    self.tf_logger.histo_summary(tag, value.data.cpu().numpy(), epoch + 1)
                    self.tf_logger.histo_summary(tag + '/grad', value.grad.data.cpu().numpy(), epoch + 1)
                except AttributeError:
                    # parameters that were never updated have no gradient
                    pass

            ######################################## </Logging> ###################################

            logger.info("EPOCH {} DONE".format(epoch + 1))

            # save the last model, and the ones in the specified interval
            if (epoch + 1) == n_epochs or epoch % self.save_interval == 0:
                self.save_model(output_dir, epoch=(epoch + 1), iteration=0, losses=losses)

        if self.do_crossvalidation:
            # load the weights with the best validation loss and save them
            # as the final model (stored with epoch index 0)
            self.network.load_state_dict(best_model_wts)
            self.save_model(output_dir, epoch=0, iteration=0, losses=losses)
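# ---------------------------------------------------------------------------
# A minimal usage sketch, not part of the original module: the toy network,
# dataset, loss, and output paths below are illustrative assumptions (in
# practice these objects are provided through a Bob configuration file).

if __name__ == '__main__':

    import torch.optim as optim
    from torch.utils.data import DataLoader, TensorDataset

    os.makedirs('out', exist_ok=True)  # save_model expects the directory to exist

    # toy stand-ins for the real PAD model and dataset
    net = nn.Sequential(nn.Flatten(), nn.Linear(8 * 8, 2))
    data = TensorDataset(torch.randn(32, 1, 8, 8), torch.randint(0, 2, (32,)))

    # `train` indexes the dataloaders by phase name, hence the dict
    loaders = {'train': DataLoader(data, batch_size=8, shuffle=True)}

    def compute_loss(network, img, labels, device):
        # contract inferred from GenericTrainer.train: take the network, a
        # batch and the device; return a scalar loss tensor
        output = network(img.to(device))
        return nn.CrossEntropyLoss()(output, labels.to(device))

    trainer = GenericTrainer(net, optim.Adam(net.parameters(), lr=1e-4),
                             compute_loss)
    trainer.train(loaders, n_epochs=2, output_dir='out')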