deepdraw.engine.trainer#

Functions

check_exist_logfile(logfile_name, arguments)

Checks for the existence of the logfile (trainlog.csv). If the logfile exists and the epoch number is still 0, the logfile is replaced.

check_gpu(device)

Check the device type and the availability of GPU.

checkpointer_process(checkpointer, ...)

Process the checkpointer, save the final model and keep track of the best model.

create_logfile_fields(valid_loader, ...)

Creation of the logfile fields that will appear in the logfile.

run(model, data_loader, valid_loader, ...)

Fits an FCN model using supervised learning and saves it to disk.

save_model_summary(output_folder, model)

Save a little summary of the model in a txt file.

static_information_to_csv(...)

Save the static information in a csv file.

torch_evaluation(model)

Context manager to turn ON/OFF model evaluation.

train_epoch(loader, model, optimizer, ...)

Trains the model for a single epoch (through all batches).

validate_epoch(loader, model, device, ...)

Processes input samples and returns the loss (scalar).

write_log_info(epoch, current_time, ...)

Write log info in trainlog.csv.

deepdraw.engine.trainer.torch_evaluation(model)[source]#

Context manager to turn ON/OFF model evaluation.

This context manager will turn evaluation mode ON on entry and turn it OFF when exiting the with statement block.

Parameters:

model (torch.nn.Module) – Network (e.g. driu, hed, unet)

Yields:

model (torch.nn.Module) – Network (e.g. driu, hed, unet)
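
The ON/OFF behaviour described above can be sketched with `contextlib.contextmanager`. This is a minimal illustration, not the deepdraw implementation: the name `evaluation_mode` and the `DummyModel` stand-in are hypothetical, and the sketch only assumes that the model offers `eval()` and `train()` methods, as `torch.nn.Module` does.

```python
import contextlib


@contextlib.contextmanager
def evaluation_mode(model):
    """Switch ``model`` to evaluation mode for the duration of a ``with`` block.

    Hypothetical helper: it only assumes ``model`` has ``eval()`` and
    ``train()`` methods, as ``torch.nn.Module`` does.
    """
    model.eval()  # evaluation mode ON on entry
    try:
        yield model
    finally:
        model.train()  # evaluation mode OFF on exit, even after an exception


class DummyModel:
    """Stand-in for a network, used only to make the sketch runnable."""

    def __init__(self):
        self.training = True

    def eval(self):
        self.training = False

    def train(self):
        self.training = True


net = DummyModel()
with evaluation_mode(net) as m:
    inside = m.training  # False while inside the block
after = net.training  # True again once the block has exited
```

The `try`/`finally` pair is a design choice of this sketch: it restores training mode even if validation raises.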

deepdraw.engine.trainer.check_gpu(device)[source]#

Check the device type and the availability of GPU.

Parameters:

device (torch.device) – device to use

deepdraw.engine.trainer.save_model_summary(output_folder, model)[source]#

Save a little summary of the model in a txt file.

Parameters:
  • output_folder (str) – output path

  • model (torch.nn.Module) – Network (e.g. driu, hed, unet)

Returns:

  • r (str) – The model summary in a text format.

  • n (int) – The number of parameters of the model.

deepdraw.engine.trainer.static_information_to_csv(static_logfile_name, device, n)[source]#

Save the static information in a csv file.

Parameters:

static_logfile_name (str) – Path to the static logfile, obtained by joining the output folder and “constant.csv”

deepdraw.engine.trainer.check_exist_logfile(logfile_name, arguments)[source]#

Checks for the existence of the logfile (trainlog.csv). If the logfile exists and the epoch number is still 0, the logfile is replaced.

Parameters:
  • logfile_name (str) – Path to the logfile, obtained by joining the output folder and “trainlog.csv”

  • arguments (dict) – start and end epochs
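
As a rough sketch of the check described above, the snippet below sets an existing logfile aside when training starts from epoch 0. The function name `reset_logfile_if_fresh`, the `"epoch"` key in `arguments`, and the `~` backup suffix are assumptions made for illustration; the source only states that the logfile is replaced.

```python
import os
import shutil
import tempfile


def reset_logfile_if_fresh(logfile_name, arguments):
    """Move a stale logfile out of the way when training starts at epoch 0.

    Hypothetical sketch: the ``"epoch"`` key and the ``~`` backup suffix are
    assumptions, not taken from the documentation above.
    Returns True if an existing logfile was set aside.
    """
    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
        backup = logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)  # keep a single backup only
        shutil.move(logfile_name, backup)
        return True
    return False


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "trainlog.csv")
    with open(path, "w") as f:
        f.write("epoch,loss\n")

    # Fresh run (epoch 0): the old log is moved to trainlog.csv~
    replaced = reset_logfile_if_fresh(path, {"epoch": 0, "max_epoch": 10})
    stale_gone = not os.path.exists(path)
    backup_kept = os.path.exists(path + "~")

    # Resumed run (epoch 5): an existing log is left untouched
    with open(path, "w") as f:
        f.write("epoch,loss\n")
    resumed = reset_logfile_if_fresh(path, {"epoch": 5, "max_epoch": 10})
    log_kept = os.path.exists(path)
```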

deepdraw.engine.trainer.create_logfile_fields(valid_loader, extra_valid_loaders, device)[source]#

Creation of the logfile fields that will appear in the logfile.

Parameters:
  • valid_loader (torch.utils.data.DataLoader) – To be used to validate the model and enable automatic checkpointing. If set to None, then do not validate it.

  • extra_valid_loaders (list of torch.utils.data.DataLoader) – To be used to validate the model, without affecting automatic checkpointing. If set to None, or empty, then nothing else is logged. Otherwise, an extra column with the loss of every dataset in this list is kept in the final training log.

  • device (torch.device) – device to use

Returns:

  • logfile_fields (tuple) – The fields that will appear in trainlog.csv

deepdraw.engine.trainer.train_epoch(loader, model, optimizer, device, criterion, batch_chunk_count)[source]#

Trains the model for a single epoch (through all batches).

Parameters:
Returns:

  • loss (float) – A floating-point value corresponding to the weighted average of this epoch’s loss
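
To make the returned value concrete, here is a small sketch of a weighted average over batch losses. It assumes that each batch is weighted by its number of samples, which the documentation does not state explicitly; `weighted_average_loss` is a hypothetical helper.

```python
def weighted_average_loss(batch_losses, batch_sizes):
    """Average per-batch losses, weighting each batch by its sample count.

    Assumption: the weights are the batch sizes, so a smaller final batch
    contributes proportionally less to the epoch loss.
    """
    total_samples = sum(batch_sizes)
    weighted_sum = sum(l * n for l, n in zip(batch_losses, batch_sizes))
    return weighted_sum / total_samples


# Three batches; the last one is smaller than the others.
epoch_loss = weighted_average_loss([0.5, 0.25, 1.0], [4, 4, 2])
plain_mean = (0.5 + 0.25 + 1.0) / 3  # differs, since it ignores batch sizes
```

With these numbers the weighted average is 0.5, while the unweighted mean is about 0.58, because the small last batch has the largest loss.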

deepdraw.engine.trainer.validate_epoch(loader, model, device, criterion, pbar_desc)[source]#

Processes input samples and returns the loss (scalar).

Parameters:
Returns:

  • loss (float) – A floating-point value corresponding to the weighted average of this epoch’s loss

deepdraw.engine.trainer.checkpointer_process(checkpointer, checkpoint_period, valid_loss, lowest_validation_loss, arguments, epoch, max_epoch)[source]#

Process the checkpointer, save the final model and keep track of the best model.

Parameters:
  • checkpointer (deepdraw.utils.checkpointer.Checkpointer) – checkpointer implementation

  • checkpoint_period (int) – save a checkpoint every n epochs. If set to 0 (zero), then do not save intermediary checkpoints

  • valid_loss (float) – Current epoch validation loss

  • lowest_validation_loss (float) – Keeps track of the best (lowest) validation loss

  • arguments (dict) – start and end epochs

  • max_epoch (int) – end epoch

Returns:

  • lowest_validation_loss (float) – The lowest validation loss currently observed
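
The three responsibilities listed above (periodic checkpoints, best-model tracking, final save) could be combined roughly as follows. This is a sketch under stated assumptions, not the deepdraw code: the checkpoint names, the `save(name, **kwargs)` method, and the `RecordingCheckpointer` stand-in are all hypothetical.

```python
def process_checkpoint(checkpointer, checkpoint_period, valid_loss,
                       lowest_validation_loss, arguments, epoch, max_epoch):
    """Save periodic, best and final checkpoints; return the best loss so far.

    Hypothetical sketch: checkpoint names and the ``save()`` signature are
    assumptions made for illustration.
    """
    # A period of 0 disables intermediary checkpoints.
    if checkpoint_period and epoch % checkpoint_period == 0:
        checkpointer.save("model_periodic_save", **arguments)

    # Keep the model with the lowest validation loss seen so far.
    if valid_loss is not None and valid_loss < lowest_validation_loss:
        lowest_validation_loss = valid_loss
        checkpointer.save("model_lowest_valid_loss", **arguments)

    # Always save the model reached at the last epoch.
    if epoch >= max_epoch:
        checkpointer.save("model_final_epoch", **arguments)

    return lowest_validation_loss


class RecordingCheckpointer:
    """Stand-in that records checkpoint names instead of writing files."""

    def __init__(self):
        self.saved = []

    def save(self, name, **kwargs):
        self.saved.append(name)


ckpt = RecordingCheckpointer()
lowest = float("inf")
for epoch, valid_loss in enumerate([0.9, 0.7, 0.8, 0.6], start=1):
    lowest = process_checkpoint(
        ckpt, 2, valid_loss, lowest, {"epoch": epoch, "max_epoch": 4}, epoch, 4
    )
```

In this run the periodic checkpoint fires at epochs 2 and 4, the best-model checkpoint at epochs 1, 2 and 4, and the final checkpoint once at epoch 4.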

deepdraw.engine.trainer.write_log_info(epoch, current_time, eta_seconds, loss, valid_loss, extra_valid_losses, optimizer, logwriter, logfile, resource_data)[source]#

Write log info in trainlog.csv.

Parameters:
  • epoch (int) – Current epoch

  • current_time (float) – Current training time

  • eta_seconds (float) – estimated time-of-arrival taking into consideration previous epoch performance

  • loss (float) – Current epoch’s training loss

  • valid_loss (float, None) – Current epoch’s validation loss

  • extra_valid_losses (list of float) – Validation losses from other validation datasets being currently tracked

  • optimizer (torch.optim) –

  • logwriter (csv.DictWriter) – Dictionary writer used to write rows to trainlog.csv

  • logfile (io.TextIOWrapper) –

  • resource_data (tuple) – Monitored resources at the machine (CPU and GPU)
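
The following sketch shows how a `csv.DictWriter` and its underlying file object can be used together to append one row per epoch. The column names and the `write_row` helper are hypothetical; the sketch also takes the learning rate as a plain number rather than reading it from an optimizer, and omits the extra validation losses and resource data.

```python
import csv
import io

# Hypothetical column names, chosen for illustration only.
FIELDS = ("epoch", "total_time", "eta", "loss", "validation_loss", "learning_rate")


def write_row(epoch, current_time, eta_seconds, loss, valid_loss,
              learning_rate, logwriter, logfile):
    """Append one epoch's figures to the log and flush them to the file."""
    logwriter.writerow(
        {
            "epoch": epoch,
            "total_time": current_time,
            "eta": eta_seconds,
            "loss": f"{loss:.6f}",
            # valid_loss may be None when no validation loader is set
            "validation_loss": "" if valid_loss is None else f"{valid_loss:.6f}",
            "learning_rate": f"{learning_rate:.6f}",
        }
    )
    logfile.flush()  # make the row visible even if training is interrupted


logfile = io.StringIO()  # stands in for the open trainlog.csv
logwriter = csv.DictWriter(logfile, fieldnames=FIELDS)
logwriter.writeheader()
write_row(1, 12.5, 112.5, 0.25, None, 0.001, logwriter, logfile)
lines = logfile.getvalue().splitlines()
```

Passing both the writer and the file object lets the function flush after each row, which is presumably why the documented signature takes `logwriter` and `logfile` separately.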

deepdraw.engine.trainer.run(model, data_loader, valid_loader, extra_valid_loaders, optimizer, criterion, scheduler, checkpointer, checkpoint_period, device, arguments, output_folder, monitoring_interval, batch_chunk_count)[source]#

Fits an FCN model using supervised learning and saves it to disk.

This method supports periodic checkpointing and the output of a CSV-formatted log with the evolution of some figures during training.

Parameters:
  • model (torch.nn.Module) – Network (e.g. driu, hed, unet)

  • data_loader (torch.utils.data.DataLoader) – To be used to train the model

  • valid_loader (list of torch.utils.data.DataLoader) – To be used to validate the model and enable automatic checkpointing. If None, then do not validate it.

  • extra_valid_loaders (list of torch.utils.data.DataLoader) – To be used to validate the model, without affecting automatic checkpointing. If empty, then nothing else is logged. Otherwise, an extra column with the loss of every dataset in this list is kept in the final training log.

  • optimizer (torch.optim) –

  • criterion (torch.nn.modules.loss._Loss) – loss function

  • scheduler (torch.optim) – learning rate scheduler

  • checkpointer (deepdraw.utils.checkpointer.Checkpointer) – checkpointer implementation

  • checkpoint_period (int) – save a checkpoint every n epochs. If set to 0 (zero), then do not save intermediary checkpoints

  • device (torch.device) – device to use

  • arguments (dict) – start and end epochs

  • output_folder (str) – output path

  • monitoring_interval (int, float) – interval, in seconds (or fractions of a second), at which resources are monitored during training.

  • batch_chunk_count (int) – If this number is different from 1, each batch is divided into this number of chunks, and gradients are accumulated over the chunks before each optimizer step. This is particularly useful when GPU memory is limited but one would like to keep training with larger batches. The trade-off is a longer processing time.
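
The idea behind `batch_chunk_count` can be illustrated without a GPU. The sketch below uses a one-parameter linear model with a mean-squared-error loss and plain Python lists; it shows that gradients accumulated over chunks match the gradient of the full batch when each chunk is weighted by its share of the batch. All function names are hypothetical, and the actual trainer may split batches and scale losses differently.

```python
def grad_mse(w, xs, ys):
    """Gradient of the mean squared error of y = w * x with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def chunked(seq, count):
    """Split ``seq`` into ``count`` contiguous chunks of near-equal size.

    Assumes ``count`` does not exceed ``len(seq)``, so no chunk is empty.
    """
    size, extra = divmod(len(seq), count)
    start = 0
    for i in range(count):
        stop = start + size + (1 if i < extra else 0)
        yield seq[start:stop]
        start = stop


def accumulated_grad(w, xs, ys, batch_chunk_count):
    """Accumulate per-chunk gradients into one full-batch gradient."""
    total = 0.0
    x_chunks = chunked(xs, batch_chunk_count)
    y_chunks = chunked(ys, batch_chunk_count)
    for cx, cy in zip(x_chunks, y_chunks):
        # Weight each chunk by its share of the batch so that the sum
        # equals the gradient computed on the whole batch at once.
        total += grad_mse(w, cx, cy) * len(cx) / len(xs)
    return total


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5

full = grad_mse(w, xs, ys)               # one pass over the whole batch
split = accumulated_grad(w, xs, ys, 2)   # two chunks, gradients accumulated
```

Only one chunk needs to be held in memory at a time, which is where the memory saving comes from; the extra passes account for the longer processing time.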