deepdraw.engine.trainer
Functions
| Function | Description |
|---|---|
| check_exist_logfile | Check the existence of the logfile (trainlog.csv). If the logfile exists and the epoch number is still 0, the logfile will be replaced. |
| check_gpu | Check the device type and the availability of a GPU. |
| checkpointer_process | Process the checkpointer, save the final model and keep track of the best model. |
| create_logfile_fields | Create the fields that will appear in the logfile. |
| run | Fits an FCN model using supervised learning and saves it to disk. |
| save_model_summary | Save a short summary of the model in a txt file. |
| static_information_to_csv | Save the static information in a csv file. |
| torch_evaluation | Context manager to turn ON/OFF model evaluation. |
| train_epoch | Trains the model for a single epoch (through all batches). |
| validate_epoch | Processes input samples and returns loss (scalar). |
| write_log_info | Write log info in trainlog.csv. |
- deepdraw.engine.trainer.torch_evaluation(model)
Context manager to turn ON/OFF model evaluation.
This context manager will turn evaluation mode ON on entry and turn it OFF when exiting the with statement block.
- Parameters:
model (torch.nn.Module) – Network (e.g. driu, hed, unet)
- Yields:
model (torch.nn.Module) – Network (e.g. driu, hed, unet)
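A minimal sketch of this context-manager pattern, assuming that "evaluation mode ON" corresponds to model.eval() and "OFF" to model.train(), the standard torch.nn.Module methods. The names evaluation_mode and DummyModel are hypothetical, and the code is duck-typed so it runs without PyTorch. The try/finally is a choice made in this sketch so that training mode is restored even if the block raises; the text above does not say whether deepdraw does the same.

```python
import contextlib


@contextlib.contextmanager
def evaluation_mode(model):
    """Hypothetical sketch: keep ``model`` in evaluation mode inside a ``with`` block.

    Works with any object exposing ``eval()`` and ``train()``, such as a
    ``torch.nn.Module``.
    """
    model.eval()  # evaluation mode ON (e.g. dropout disabled)
    try:
        yield model
    finally:
        model.train()  # evaluation mode OFF, even if the block raised


class DummyModel:
    """Stand-in for torch.nn.Module that only tracks the ``training`` flag."""

    def __init__(self):
        self.training = True

    def eval(self):
        self.training = False
        return self

    def train(self, mode=True):
        self.training = mode
        return self


model = DummyModel()
with evaluation_mode(model) as m:
    inside = m.training  # False while evaluating
after = model.training  # True again once the block exits
```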
- deepdraw.engine.trainer.check_gpu(device)
Check the device type and the availability of a GPU.
- Parameters:
device (torch.device) – device to use
- deepdraw.engine.trainer.save_model_summary(output_folder, model)
Save a short summary of the model in a txt file.
- Parameters:
output_folder (str) – output path
model (torch.nn.Module) – Network (e.g. driu, hed, unet)
- Returns:
r (str) – The model summary in text format.
n (int) – The number of parameters of the model.
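The two return values can be sketched as follows, assuming the summary is the printable form of the model (str(model)) and the parameter count uses the common PyTorch idiom sum(p.numel() for p in model.parameters()). The function name, the file name model_summary.txt and the file layout are hypothetical. TinyNet is a stand-in so the sketch runs without PyTorch.

```python
import os
import tempfile


def save_model_summary_sketch(output_folder, model):
    """Hypothetical sketch: write str(model) to a text file and count parameters.

    Assumes ``model`` behaves like ``torch.nn.Module``: ``str(model)`` describes
    the architecture and ``model.parameters()`` yields tensors with ``numel()``.
    """
    os.makedirs(output_folder, exist_ok=True)
    r = str(model)  # textual summary of the architecture
    n = sum(p.numel() for p in model.parameters())  # total number of parameters
    # hypothetical file name and layout
    with open(os.path.join(output_folder, "model_summary.txt"), "w") as f:
        f.write(r + "\n")
        f.write(f"nb_parameters: {n}\n")
    return r, n


class _Param:
    """Stand-in for a tensor: only ``numel()`` is needed here."""

    def __init__(self, count):
        self._count = count

    def numel(self):
        return self._count


class TinyNet:
    """Stand-in for torch.nn.Module with two parameter tensors."""

    def __str__(self):
        return "TinyNet(conv: 3x3, head: 1x1)"

    def parameters(self):
        return iter([_Param(27), _Param(4)])


out_dir = tempfile.mkdtemp()
summary, n_params = save_model_summary_sketch(out_dir, TinyNet())
```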
- deepdraw.engine.trainer.static_information_to_csv(static_logfile_name, device, n)
Save the static information in a csv file.
- Parameters:
static_logfile_name (str) – The static file name, formed by joining the output folder and “constant.csv”
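The shape of such a file, one header row and one row of values written once per training run, can be sketched with csv.DictWriter. The column names are assumptions. The sketch also assumes that n is the model parameter count returned by save_model_summary and reduces device to its type string (as in torch.device.type) to stay PyTorch-free. The function name is hypothetical.

```python
import csv
import os
import tempfile


def static_information_to_csv_sketch(static_logfile_name, device_type, n):
    """Hypothetical sketch: write run constants as one CSV header + one row.

    Column names are assumptions; ``n`` is assumed to be the parameter count.
    """
    static = {"device_type": device_type, "model_size": n}
    with open(static_logfile_name, "w", newline="") as f:  # newline="" as csv recommends
        writer = csv.DictWriter(f, fieldnames=list(static))
        writer.writeheader()
        writer.writerow(static)


path = os.path.join(tempfile.mkdtemp(), "constant.csv")
static_information_to_csv_sketch(path, "cpu", 31)
with open(path, newline="") as f:
    rows = list(csv.DictReader(f))
```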
- deepdraw.engine.trainer.check_exist_logfile(logfile_name, arguments)
Check the existence of the logfile (trainlog.csv). If the logfile exists and the epoch number is still 0, the logfile will be replaced.
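The rule is that a stale log is discarded only when training starts from scratch. A sketch of it follows, assuming the current epoch is stored under arguments["epoch"]. That key name is an assumption, as is the function name. A non-zero epoch means training is resuming, so the existing log is kept.

```python
import os
import tempfile


def check_exist_logfile_sketch(logfile_name, arguments):
    """Hypothetical sketch: drop an existing trainlog.csv when a run starts at epoch 0.

    ``arguments["epoch"]`` (assumed key) holds the current/start epoch.
    """
    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
        os.unlink(logfile_name)  # fresh run: the new log replaces the old one


folder = tempfile.mkdtemp()
log = os.path.join(folder, "trainlog.csv")

open(log, "w").close()
check_exist_logfile_sketch(log, {"epoch": 3})  # resuming: log is kept
kept = os.path.exists(log)

check_exist_logfile_sketch(log, {"epoch": 0})  # new run: log is removed
removed = not os.path.exists(log)
```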
- deepdraw.engine.trainer.create_logfile_fields(valid_loader, extra_valid_loaders, device)
Create the fields that will appear in the logfile.
- Parameters:
valid_loader (torch.utils.data.DataLoader) – To be used to validate the model and enable automatic checkpointing. If set to None, then do not validate it.
extra_valid_loaders (list of torch.utils.data.DataLoader) – To be used to validate the model; however, it does not affect automatic checkpointing. If set to None, or empty, then nothing else is logged. Otherwise, an extra column with the loss of every dataset in this list is kept in the final training log.
device (torch.device) – device to use
- Returns:
logfile_fields (tuple) – The fields that will appear in trainlog.csv
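The conditional logic described above (an optional validation column and an optional extra-validation column) can be sketched as follows. Every column name is hypothetical, including the CPU and GPU resource columns. Reading the device parameter as a switch for GPU columns is a guess. SimpleNamespace stands in for torch.device so the sketch runs without PyTorch.

```python
from types import SimpleNamespace


def create_logfile_fields_sketch(valid_loader, extra_valid_loaders, device):
    """Hypothetical sketch: assemble the trainlog.csv column names.

    All names are assumptions; only the optional validation columns follow
    the documented behaviour.
    """
    fields = ["epoch", "total_time", "eta", "loss", "learning_rate"]
    if valid_loader is not None:
        fields.append("validation_loss")
    if extra_valid_loaders:  # None or [] adds nothing
        fields.append("extra_validation_losses")  # one column holding all extra losses
    fields += ["cpu_percent", "cpu_rss_GB"]  # hypothetical CPU monitoring columns
    if device.type == "cuda":
        fields += ["gpu_percent", "gpu_memory_used_GB"]  # hypothetical GPU columns
    return tuple(fields)


cpu = SimpleNamespace(type="cpu")  # stand-in for torch.device("cpu")
gpu = SimpleNamespace(type="cuda")  # stand-in for torch.device("cuda")
no_valid = create_logfile_fields_sketch(None, None, cpu)
with_valid = create_logfile_fields_sketch(object(), [object()], cpu)
with_gpu = create_logfile_fields_sketch(None, [], gpu)
```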
- deepdraw.engine.trainer.train_epoch(loader, model, optimizer, device, criterion, batch_chunk_count)
Trains the model for a single epoch (through all batches).
- Parameters:
loader (torch.utils.data.DataLoader) – To be used to train the model
model (torch.nn.Module) – Network (e.g. driu, hed, unet)
optimizer (torch.optim) – optimizer used to update the model weights
device (torch.device) – device to use
criterion (torch.nn.modules.loss._Loss) – loss function
batch_chunk_count (int) – If this number is different from 1, then each batch will be divided into this number of chunks. Gradients will be accumulated to perform each mini-batch. This is particularly useful when GPU memory is limited but one would still like to train with larger batches, at the cost of longer processing times. To better understand gradient accumulation, read https://stackoverflow.com/questions/62067400/understanding-accumulated-gradients-in-pytorch.
- Returns:
loss (float) – A floating-point value corresponding to the weighted average of this epoch’s loss
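Why splitting a batch with batch_chunk_count need not change the update can be checked with plain arithmetic. The sketch below fits a scalar model y ≈ w·x with a mean-squared-error loss. It shows that summing per-chunk gradients, each scaled by the chunk's share of the batch, reproduces the full-batch gradient. The helper names are hypothetical, and the size-proportional scaling is an assumption: the text above does not say how deepdraw scales each chunk's loss.

```python
def mse_grad(xs, ys, w):
    """Gradient of mean((w*x - y)**2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def chunked_grad(xs, ys, w, batch_chunk_count):
    """Accumulate gradients over chunks, as if calling backward() per chunk."""
    n = len(xs)
    size = -(-n // batch_chunk_count)  # ceiling division: samples per chunk
    total = 0.0
    for start in range(0, n, size):
        cx, cy = xs[start:start + size], ys[start:start + size]
        # weight each chunk's mean gradient by its share of the full batch
        total += mse_grad(cx, cy, w) * len(cx) / n
    return total


xs = [1.0, 2.0, 3.0, 4.0, 5.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8]
full = mse_grad(xs, ys, w=1.5)
accumulated = chunked_grad(xs, ys, w=1.5, batch_chunk_count=2)
```

In PyTorch the same effect comes from calling loss.backward() once per chunk, since gradients add up in each parameter's .grad. optimizer.step() and optimizer.zero_grad() are then called once per full batch.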
- deepdraw.engine.trainer.validate_epoch(loader, model, device, criterion, pbar_desc)
Processes input samples and returns loss (scalar).
- Parameters:
loader (torch.utils.data.DataLoader) – To be used to validate the model
model (torch.nn.Module) – Network (e.g. driu, hed, unet)
device (torch.device) – device to use
criterion (torch.nn.modules.loss._Loss) – loss function
pbar_desc (str) – A string for the progress bar description
- Returns:
loss (float) – A floating-point value corresponding to the weighted average of this epoch’s loss
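A weighted average matters when the last batch is smaller than the others. A sketch follows, assuming each batch's mean loss is weighted by its number of samples. The text above says "weighted average" without naming the weights, so this weighting and the function name are assumptions.

```python
def weighted_epoch_loss(batch_losses, batch_sizes):
    """Average per-batch mean losses, weighting each batch by its sample count."""
    total_samples = sum(batch_sizes)
    return sum(loss * n for loss, n in zip(batch_losses, batch_sizes)) / total_samples


# the short last batch (2 samples) counts half as much as each full batch
epoch_loss = weighted_epoch_loss([0.50, 0.40, 0.90], [4, 4, 2])
```

Here the weighted loss is 0.54, while a plain mean of the three batch losses would give 0.6.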
- deepdraw.engine.trainer.checkpointer_process(checkpointer, checkpoint_period, valid_loss, lowest_validation_loss, arguments, epoch, max_epoch)
Process the checkpointer, save the final model and keep track of the best model.
- Parameters:
checkpointer (deepdraw.utils.checkpointer.Checkpointer) – checkpointer implementation
checkpoint_period (int) – save a checkpoint every n epochs. If set to 0 (zero), then do not save intermediary checkpoints
valid_loss (float) – Current epoch validation loss
lowest_validation_loss (float) – Keeps track of the best (lowest) validation loss
arguments (dict) – start and end epochs
epoch (int) – Current epoch
max_epoch (int) – end epoch
- Returns:
lowest_validation_loss (float) – The lowest validation loss currently observed
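The three responsibilities named above (periodic checkpoints, best-model tracking and the final save) can be sketched as one policy function. Several details are assumptions: the checkpoint names, the save(name, **arguments) call shape of the recording stand-in, skipping best-model tracking when valid_loss is None, and treating epoch as 1-based so the final save happens when epoch reaches max_epoch. None of these are taken from deepdraw.utils.checkpointer.Checkpointer.

```python
class RecordingCheckpointer:
    """Stand-in checkpointer that only records (name, epoch) of each save."""

    def __init__(self):
        self.saved = []

    def save(self, name, **arguments):
        self.saved.append((name, arguments["epoch"]))


def checkpointer_process_sketch(checkpointer, checkpoint_period, valid_loss,
                                lowest_validation_loss, arguments, epoch, max_epoch):
    """Hypothetical sketch of the documented checkpointing policy."""
    # periodic checkpoint every `checkpoint_period` epochs (0 disables it)
    if checkpoint_period and epoch % checkpoint_period == 0:
        checkpointer.save("model_periodic_save", **arguments)
    # best model so far, judged by validation loss (skipped without validation)
    if valid_loss is not None and valid_loss < lowest_validation_loss:
        lowest_validation_loss = valid_loss
        checkpointer.save("model_lowest_valid_loss", **arguments)
    # final model once the last epoch is reached
    if epoch >= max_epoch:
        checkpointer.save("model_final_epoch", **arguments)
    return lowest_validation_loss


ckpt = RecordingCheckpointer()
lowest = float("inf")
for epoch, vloss in enumerate([0.9, 0.7, 0.8, 0.6], start=1):
    arguments = {"epoch": epoch, "max_epoch": 4}
    lowest = checkpointer_process_sketch(ckpt, 2, vloss, lowest, arguments, epoch, 4)
```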
- deepdraw.engine.trainer.write_log_info(epoch, current_time, eta_seconds, loss, valid_loss, extra_valid_losses, optimizer, logwriter, logfile, resource_data)
Write log info in trainlog.csv.
- Parameters:
epoch (int) – Current epoch
current_time (float) – Current training time
eta_seconds (float) – estimated time remaining until the end of training (ETA), taking into consideration the previous epoch’s performance
loss (float) – Current epoch’s training loss
valid_loss (float or None) – Current epoch’s validation loss
extra_valid_losses (list of float) – Validation losses from other validation datasets being currently tracked
optimizer (torch.optim) – optimizer used for training
logwriter (csv.DictWriter) – Dictionary writer used to write rows to trainlog.csv
logfile (io.TextIOWrapper) – Open file handle of trainlog.csv
resource_data (tuple) – Monitored resources on the machine (CPU and GPU)
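Writing one log row with csv.DictWriter, plus an ETA derived from the previous epoch, can be sketched as follows. The ETA rule (last epoch duration times remaining epochs), the column names and the number formatting are assumptions. For brevity the sketch omits extra_valid_losses and resource_data. optimizer.param_groups[0]["lr"] is the standard way to read a torch.optim optimizer's learning rate; a SimpleNamespace stands in for the optimizer here.

```python
import csv
import io
from types import SimpleNamespace


def eta_from_last_epoch(epoch_seconds, epoch, max_epoch):
    """Hypothetical ETA: assume each remaining epoch takes as long as the last."""
    return epoch_seconds * (max_epoch - epoch)


def write_log_info_sketch(epoch, current_time, eta_seconds, loss, valid_loss,
                          optimizer, logwriter, logfile):
    """Hypothetical sketch: append one row to trainlog.csv and flush it."""
    row = {
        "epoch": epoch,
        "total_time": f"{current_time:.1f}",
        "eta": f"{eta_seconds:.1f}",
        "loss": f"{loss:.6f}",
        # torch optimizers expose the current learning rate via param_groups
        "learning_rate": optimizer.param_groups[0]["lr"],
        "validation_loss": "" if valid_loss is None else f"{valid_loss:.6f}",
    }
    logwriter.writerow(row)
    logfile.flush()  # make the row visible even if training crashes later


logfile = io.StringIO()
fields = ["epoch", "total_time", "eta", "loss", "learning_rate", "validation_loss"]
logwriter = csv.DictWriter(logfile, fieldnames=fields)
logwriter.writeheader()
optimizer = SimpleNamespace(param_groups=[{"lr": 0.001}])  # stand-in for torch.optim
eta = eta_from_last_epoch(epoch_seconds=12.0, epoch=3, max_epoch=10)
write_log_info_sketch(3, 36.0, eta, 0.25, None, optimizer, logwriter, logfile)
last_row = logfile.getvalue().splitlines()[-1]
```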
- deepdraw.engine.trainer.run(model, data_loader, valid_loader, extra_valid_loaders, optimizer, criterion, scheduler, checkpointer, checkpoint_period, device, arguments, output_folder, monitoring_interval, batch_chunk_count)
Fits an FCN model using supervised learning and saves it to disk.
This method supports periodic checkpointing and the output of a CSV-formatted log that tracks the evolution of losses and other figures during training.
- Parameters:
model (torch.nn.Module) – Network (e.g. driu, hed, unet)
data_loader (torch.utils.data.DataLoader) – To be used to train the model
valid_loader (torch.utils.data.DataLoader) – To be used to validate the model and enable automatic checkpointing. If None, then do not validate it.
extra_valid_loaders (list of torch.utils.data.DataLoader) – To be used to validate the model; however, it does not affect automatic checkpointing. If empty, then nothing else is logged. Otherwise, an extra column with the loss of every dataset in this list is kept in the final training log.
optimizer (torch.optim) – optimizer used to update the model weights
criterion (torch.nn.modules.loss._Loss) – loss function
scheduler (torch.optim) – learning rate scheduler
checkpointer (deepdraw.utils.checkpointer.Checkpointer) – checkpointer implementation
checkpoint_period (int) – save a checkpoint every n epochs. If set to 0 (zero), then do not save intermediary checkpoints
device (torch.device) – device to use
arguments (dict) – start and end epochs
output_folder (str) – output path
monitoring_interval (int, float) – interval, in seconds (or fractions of a second), at which resources are monitored during training.
batch_chunk_count (int) – If this number is different from 1, then each batch will be divided into this number of chunks. Gradients will be accumulated to perform each mini-batch. This is particularly useful when GPU memory is limited but one would still like to train with larger batches, at the cost of longer processing times.
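monitoring_interval implies a sampler that runs alongside training. One common way to build such a sampler is a daemon thread that calls a probe every interval seconds until it is stopped. The sketch below shows that pattern; it is not deepdraw's actual monitor. The class name is hypothetical, and time.process_time is only a placeholder probe. A real monitor would read CPU and GPU usage, which is not shown.

```python
import threading
import time


class ResourceSampler:
    """Hypothetical sketch: call ``probe`` every ``interval`` seconds in a thread."""

    def __init__(self, probe, interval):
        self.probe = probe
        self.interval = interval
        self.samples = []
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._loop, daemon=True)

    def _loop(self):
        # Event.wait returns False on timeout and True once the event is set
        while not self._stop.wait(self.interval):
            self.samples.append(self.probe())

    def __enter__(self):
        self._thread.start()
        return self

    def __exit__(self, *exc):
        self._stop.set()  # wake the loop and end it
        self._thread.join()
        return False  # do not swallow exceptions


with ResourceSampler(probe=time.process_time, interval=0.01) as sampler:
    time.sleep(0.1)  # stands in for one training epoch
```

The samples collected during an epoch could then be summarised (for example averaged) into the resource_data passed to write_log_info. That aggregation step is likewise an assumption.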