Source code for bob.ip.binseg.engine.ssltrainer

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import csv
import datetime
import logging
import os
import shutil
import sys
import time

import numpy
import torch

from tqdm import tqdm

from ..utils.measure import SmoothedValue
from ..utils.resources import cpu_constants, cpu_log, gpu_constants, gpu_log
from ..utils.summary import summary
from .trainer import PYTORCH_GE_110, torch_evaluation

logger = logging.getLogger(__name__)


def sharpen(x, T):
    """Sharpens a probability distribution along its second (class) dimension"""
    temp = x ** (1 / T)
    return temp / temp.sum(dim=1, keepdim=True)
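
# Example (an illustrative sketch, not part of the original module): applying
# ``sharpen`` with a temperature below 1 makes the per-pixel class
# distribution peakier while keeping it normalised over the channel
# dimension.  Shapes and the temperature value are assumptions chosen for
# demonstration only.
def _sharpen_example():
    probs = torch.softmax(torch.randn(2, 2, 8, 8), dim=1)  # [n, c, h, w]
    sharpened = sharpen(probs, T=0.5)
    # each pixel still sums to one over the class dimension
    assert torch.allclose(sharpened.sum(dim=1), torch.ones(2, 8, 8))
    return sharpened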
def mix_up(alpha, input, target, unlabelled_input, unlabled_target):
    """Applies mix up as described in [MIXMATCH_19].

    Parameters
    ----------

    alpha : float
        parameter of the Beta distribution from which the mixing coefficient
        is drawn

    input : :py:class:`torch.Tensor`
        labelled images

    target : :py:class:`torch.Tensor`
        ground truths for the labelled images

    unlabelled_input : :py:class:`torch.Tensor`
        unlabelled images

    unlabled_target : :py:class:`torch.Tensor`
        targets for the unlabelled images (e.g. guessed labels)


    Returns
    -------

    tuple
        mixed-up labelled inputs, mixed-up labelled targets, mixed-up
        unlabelled inputs and mixed-up unlabelled targets

    """

    with torch.no_grad():
        l = numpy.random.beta(alpha, alpha)  # Eq (8)
        l = max(l, 1 - l)  # Eq (9)
        # Shuffle and concat. Alg. 1 Line: 12
        w_inputs = torch.cat([input, unlabelled_input], 0)
        w_targets = torch.cat([target, unlabled_target], 0)
        idx = torch.randperm(w_inputs.size(0))  # get random index

        # Apply MixUp to labelled data and entries from W. Alg. 1 Line: 13
        input_mixedup = l * input + (1 - l) * w_inputs[idx[len(input) :]]
        target_mixedup = l * target + (1 - l) * w_targets[idx[len(target) :]]

        # Apply MixUp to unlabelled data and entries from W. Alg. 1 Line: 14
        unlabelled_input_mixedup = (
            l * unlabelled_input
            + (1 - l) * w_inputs[idx[: len(unlabelled_input)]]
        )
        unlabled_target_mixedup = (
            l * unlabled_target
            + (1 - l) * w_targets[idx[: len(unlabled_target)]]
        )
        return (
            input_mixedup,
            target_mixedup,
            unlabelled_input_mixedup,
            unlabled_target_mixedup,
        )
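
# Example (an illustrative sketch, not part of the original module): mixing a
# labelled batch with an unlabelled batch.  In MixMatch the unlabelled targets
# would come from ``guess_labels`` below; the shapes and ``alpha`` value here
# are assumptions for demonstration only.
def _mix_up_example():
    input = torch.rand(4, 3, 64, 64)  # labelled images
    target = torch.randint(0, 2, (4, 1, 64, 64)).float()  # their ground truths
    unlabelled_input = torch.rand(4, 3, 64, 64)  # unlabelled images
    unlabled_target = torch.rand(4, 1, 64, 64)  # e.g. guessed labels
    return mix_up(0.75, input, target, unlabelled_input, unlabled_target)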
def square_rampup(current, rampup_length=16):
    """Slowly ramps up ``lambda_u``

    Parameters
    ----------

    current : int
        current epoch

    rampup_length : :obj:`int`, optional
        how long to ramp up, by default 16

    Returns
    -------

    factor : float
        ramp-up factor
    """

    if rampup_length == 0:
        return 1.0
    else:
        current = numpy.clip((current / float(rampup_length)) ** 2, 0.0, 1.0)
    return float(current)
def linear_rampup(current, rampup_length=16):
    """Slowly ramps up ``lambda_u``

    Parameters
    ----------

    current : int
        current epoch

    rampup_length : :obj:`int`, optional
        how long to ramp up, by default 16

    Returns
    -------

    factor : float
        ramp-up factor
    """

    if rampup_length == 0:
        return 1.0
    else:
        current = numpy.clip(current / rampup_length, 0.0, 1.0)
    return float(current)
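
# Example (an illustrative sketch, not part of the original module): both
# schedules grow from 0 to 1 over ``rampup_length`` epochs and stay at 1
# afterwards; the squared variant starts more slowly.  The epochs printed
# below are arbitrary.
def _rampup_example():
    for epoch in (0, 4, 8, 16, 32):
        # e.g. at epoch 4 with length 16: square -> 0.0625, linear -> 0.25
        print(epoch, square_rampup(epoch, 16), linear_rampup(epoch, 16))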
def guess_labels(unlabelled_images, model):
    """
    Calculates the average prediction over the original image and 2
    augmentations: horizontal and vertical flips

    Parameters
    ----------

    unlabelled_images : :py:class:`torch.Tensor`
        ``[n,c,h,w]``

    model : :py:class:`torch.nn.Module`
        network used to generate the predictions

    Returns
    -------

    avg_guess : :py:class:`torch.Tensor`
        averaged predictions, ``[n,c,h,w]``
    """
    with torch.no_grad():
        guess1 = torch.sigmoid(model(unlabelled_images)).unsqueeze(0)
        # Horizontal flip and unsqueeze to work with batches (increase flip
        # dimension by 1)
        hflip = torch.sigmoid(model(unlabelled_images.flip(2))).unsqueeze(0)
        guess2 = hflip.flip(3)
        # Vertical flip and unsqueeze to work with batches (increase flip
        # dimension by 1)
        vflip = torch.sigmoid(model(unlabelled_images.flip(3))).unsqueeze(0)
        guess3 = vflip.flip(4)
        # Concat
        concat = torch.cat([guess1, guess2, guess3], 0)
        avg_guess = torch.mean(concat, 0)
        return avg_guess
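
# Example (an illustrative sketch, not part of the original module): guessing
# labels for a small unlabelled batch.  Any network returning a single logit
# map per input would do; the single-convolution "model" below is an
# assumption for demonstration only.
def _guess_labels_example():
    model = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)
    unlabelled_images = torch.rand(2, 3, 32, 32)  # [n, c, h, w]
    guesses = guess_labels(unlabelled_images, model)
    assert guesses.shape == (2, 1, 32, 32)  # averaged sigmoid predictions
    return guesses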
def run(
    model,
    data_loader,
    valid_loader,
    optimizer,
    criterion,
    scheduler,
    checkpointer,
    checkpoint_period,
    device,
    arguments,
    output_folder,
    rampup_length,
):
    """
    Fits an FCN model using semi-supervised learning and saves it to disk.

    This method supports periodic checkpointing and the output of a
    CSV-formatted log with the evolution of some figures during training.

    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    data_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to train the model

    valid_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model and enable automatic checkpointing.
        If set to ``None``, then do not validate it.

    optimizer : :py:mod:`torch.optim`

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    scheduler : :py:mod:`torch.optim`
        learning rate scheduler

    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.Checkpointer`
        checkpointer implementation

    checkpoint_period : int
        save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
        not save intermediary checkpoints

    device : str
        device to use ``'cpu'`` or ``cuda:0``

    arguments : dict
        start and end epochs

    output_folder : str
        output path

    rampup_length : int
        rampup epochs

    """

    start_epoch = arguments["epoch"]
    max_epoch = arguments["max_epoch"]

    if device != "cpu":
        # asserts we do have a GPU
        assert bool(gpu_constants()), (
            f"Device set to '{device}', but cannot "
            f"find a GPU (maybe nvidia-smi is not installed?)"
        )

    os.makedirs(output_folder, exist_ok=True)

    # Save model summary
    summary_path = os.path.join(output_folder, "model_summary.txt")
    logger.info(f"Saving model summary at {summary_path}...")
    with open(summary_path, "wt") as f:
        r, n = summary(model)
        logger.info(f"Model has {n} parameters...")
        f.write(r)

    # write static information to a CSV file
    static_logfile_name = os.path.join(output_folder, "constants.csv")
    if os.path.exists(static_logfile_name):
        backup = static_logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(static_logfile_name, backup)
    with open(static_logfile_name, "w", newline="") as f:
        logdata = cpu_constants()
        if device != "cpu":
            logdata += gpu_constants()
        logdata += (("model_size", n),)
        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
        logwriter.writeheader()
        logwriter.writerow(dict(k for k in logdata))

    # Log continuous information to (another) file
    logfile_name = os.path.join(output_folder, "trainlog.csv")

    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
        backup = logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(logfile_name, backup)

    logfile_fields = (
        "epoch",
        "total_time",
        "eta",
        "average_loss",
        "median_loss",
        "labelled_median_loss",
        "unlabelled_median_loss",
        "learning_rate",
    )
    if valid_loader is not None:
        logfile_fields += ("validation_average_loss", "validation_median_loss")
    logfile_fields += tuple([k[0] for k in cpu_log()])
    if device != "cpu":
        logfile_fields += tuple([k[0] for k in gpu_log()])

    # the lowest validation loss obtained so far - this value is updated only
    # if a validation set is available
    lowest_validation_loss = sys.float_info.max

    with open(logfile_name, "a+", newline="") as logfile:
        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)

        if arguments["epoch"] == 0:
            logwriter.writeheader()

        model.train()  # set training mode

        model.to(device)  # set/cast parameters to device
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

        # Total training timer
        start_training_time = time.time()

        for epoch in tqdm(
            range(start_epoch, max_epoch),
            desc="epoch",
            leave=False,
            disable=None,
        ):
            if not PYTORCH_GE_110:
                scheduler.step()
            losses = SmoothedValue(len(data_loader))
            labelled_loss = SmoothedValue(len(data_loader))
            unlabelled_loss = SmoothedValue(len(data_loader))
            epoch = epoch + 1
            arguments["epoch"] = epoch

            # Epoch time
            start_epoch_time = time.time()

            # progress bar only on interactive jobs
            for samples in tqdm(
                data_loader, desc="batch", leave=False, disable=None
            ):
                # data forwarding on the existing network

                # labelled
                images = samples[1].to(
                    device=device, non_blocking=torch.cuda.is_available()
                )
                ground_truths = samples[2].to(
                    device=device, non_blocking=torch.cuda.is_available()
                )
                unlabelled_images = samples[4].to(
                    device=device, non_blocking=torch.cuda.is_available()
                )
                # labelled outputs
                outputs = model(images)
                unlabelled_outputs = model(unlabelled_images)
                # guessed unlabelled outputs
                unlabelled_ground_truths = guess_labels(
                    unlabelled_images, model
                )

                # loss evaluation and learning (backward step)
                ramp_up_factor = square_rampup(
                    epoch, rampup_length=rampup_length
                )

                # note: no support for masks...
                loss, ll, ul = criterion(
                    outputs,
                    ground_truths,
                    unlabelled_outputs,
                    unlabelled_ground_truths,
                    ramp_up_factor,
                )
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses.update(loss)
                labelled_loss.update(ll)
                unlabelled_loss.update(ul)
                logger.debug(f"batch loss: {loss.item()}")

            if PYTORCH_GE_110:
                scheduler.step()

            # calculates the validation loss if necessary
            # note: validation does not comprise "unlabelled" losses
            valid_losses = None
            if valid_loader is not None:

                with torch.no_grad(), torch_evaluation(model):

                    valid_losses = SmoothedValue(len(valid_loader))
                    for samples in tqdm(
                        valid_loader, desc="valid", leave=False, disable=None
                    ):
                        # data forwarding on the existing network
                        images = samples[1].to(
                            device=device,
                            non_blocking=torch.cuda.is_available(),
                        )
                        ground_truths = samples[2].to(
                            device=device,
                            non_blocking=torch.cuda.is_available(),
                        )
                        masks = (
                            torch.ones_like(ground_truths)
                            if len(samples) < 4
                            else samples[3].to(
                                device=device,
                                non_blocking=torch.cuda.is_available(),
                            )
                        )

                        outputs = model(images)
                        loss = criterion(outputs, ground_truths, masks)
                        valid_losses.update(loss)

            if checkpoint_period and (epoch % checkpoint_period == 0):
                checkpointer.save(f"model_{epoch:03d}", **arguments)

            if (
                valid_losses is not None
                and valid_losses.avg < lowest_validation_loss
            ):
                lowest_validation_loss = valid_losses.avg
                logger.info(
                    f"Found new low on validation set:"
                    f" {lowest_validation_loss:.6f}"
                )
                checkpointer.save("model_lowest_valid_loss", **arguments)

            if epoch >= max_epoch:
                checkpointer.save("model_final", **arguments)

            # computes ETA (estimated time-of-arrival; end of training) taking
            # into consideration previous epoch performance
            epoch_time = time.time() - start_epoch_time
            eta_seconds = epoch_time * (max_epoch - epoch)
            current_time = time.time() - start_training_time

            logdata = (
                ("epoch", f"{epoch}"),
                (
                    "total_time",
                    f"{datetime.timedelta(seconds=int(current_time))}",
                ),
                ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
                ("average_loss", f"{losses.avg:.6f}"),
                ("median_loss", f"{losses.median:.6f}"),
                ("labelled_median_loss", f"{labelled_loss.median:.6f}"),
                ("unlabelled_median_loss", f"{unlabelled_loss.median:.6f}"),
                ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
            )
            if valid_losses is not None:
                logdata += (
                    ("validation_average_loss", f"{valid_losses.avg:.6f}"),
                    ("validation_median_loss", f"{valid_losses.median:.6f}"),
                )
            logdata += cpu_log()
            if device != "cpu":
                logdata += gpu_log()

            logwriter.writerow(dict(k for k in logdata))
            logfile.flush()
            tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))

        total_training_time = time.time() - start_training_time
        logger.info(
            f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)"
        )
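
# A minimal sketch (not part of the original package) of the loss-object
# contract assumed by ``run()`` above: during training it is called with five
# positional arguments and must return the combined loss together with its
# labelled and unlabelled parts; during validation it is called with three
# arguments, the last being a mask.  The class name and the BCE/MSE choices
# below are illustrative assumptions only.
class _ExampleSSLLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bce = torch.nn.BCEWithLogitsLoss()

    def forward(self, outputs, ground_truths, *extra):
        labelled = self.bce(outputs, ground_truths)
        if len(extra) == 1:
            # validation path: ``extra[0]`` is a mask (ignored in this sketch)
            return labelled
        unlabelled_outputs, unlabelled_ground_truths, ramp_up_factor = extra
        # consistency term between predictions and guessed labels, weighted by
        # the ramp-up factor
        unlabelled = torch.nn.functional.mse_loss(
            torch.sigmoid(unlabelled_outputs), unlabelled_ground_truths
        )
        return labelled + ramp_up_factor * unlabelled, labelled, unlabelled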