bob.med.tb.engine.trainer

Functions

run(model, data_loader, valid_loader, ...[, ...])

Fits a CNN model using supervised learning and saves it to disk.

torch_evaluation(model)

Context manager that turns model evaluation mode ON/OFF.

bob.med.tb.engine.trainer.torch_evaluation(model)

Context manager that turns model evaluation mode ON/OFF.

This context manager turns evaluation mode ON (model.eval()) on entry and turns it OFF again (back to training mode) when the with statement block exits.

Parameters

model (torch.nn.Module) – Network (e.g. driu, hed, unet)

Yields

model (torch.nn.Module) – Network (e.g. driu, hed, unet)
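The entry/exit behaviour described above can be sketched with contextlib.contextmanager. This is a minimal illustration, not necessarily the library's own implementation: it relies only on the eval() and train() methods that torch.nn.Module provides, and the try/finally (which restores training mode even if the block raises) is a design choice of this sketch.

```python
import contextlib


@contextlib.contextmanager
def torch_evaluation(model):
    """Sketch: put ``model`` in evaluation mode for the duration of a ``with`` block.

    Only ``model.eval()`` and ``model.train()`` are used, so any object
    exposing those methods (e.g. a ``torch.nn.Module``) works.
    """
    model.eval()  # evaluation mode ON: dropout off, batch-norm uses running statistics
    try:
        yield model
    finally:
        model.train()  # evaluation mode OFF: back to training mode, even on error
```

Typical use with PyTorch would be `with torch_evaluation(model), torch.no_grad(): ...`. Evaluation mode by itself does not disable gradient tracking, which is why torch.no_grad() is usually combined with it for inference.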

bob.med.tb.engine.trainer.run(model, data_loader, valid_loader, optimizer, criterion, checkpointer, checkpoint_period, device, arguments, output_folder, criterion_valid=None)

Fits a CNN model using supervised learning and saves it to disk.

This method supports periodic checkpointing and writes a CSV-formatted log that tracks the evolution of selected figures of merit during training.
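A per-epoch CSV log of this kind can be kept by appending one row per epoch with the standard csv module. The sketch below is illustrative only: the helper name append_log_row, the column names and the log file name are hypothetical and do not describe the actual columns or file written by run().

```python
import csv
import os


def append_log_row(path, row):
    """Hypothetical helper: append one epoch's figures to a CSV log.

    ``row`` maps column names to values; its key order fixes the column
    order. The header is written only when the file is new or empty.
    """
    write_header = not os.path.exists(path) or os.path.getsize(path) == 0
    with open(path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(row))
        if write_header:
            writer.writeheader()
        writer.writerow(row)
```

Called once per epoch, for example as append_log_row(os.path.join(output_folder, "trainlog.csv"), {"epoch": epoch, "train_loss": loss}) with a hypothetical file name, this yields a log that can be reloaded with csv.DictReader or a spreadsheet tool.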

Parameters
  • model (torch.nn.Module) – Network (e.g. pasa)

  • data_loader (torch.utils.data.DataLoader) – To be used to train the model

  • valid_loader (torch.utils.data.DataLoader) – To be used to validate the model and enable automatic checkpointing. If set to None, no validation is performed.

  • optimizer (torch.optim.Optimizer) – Optimizer used to update the model parameters

  • criterion (torch.nn.modules.loss._Loss) – Loss function used for training

  • checkpointer (bob.med.tb.utils.checkpointer.Checkpointer) – Checkpointer implementation

  • checkpoint_period (int) – Save a checkpoint every checkpoint_period epochs. If set to 0 (zero), no intermediary checkpoints are saved.

  • device (str) – Device to use, e.g. 'cpu' or 'cuda:0'

  • arguments (dict) – Dictionary holding the start and end epochs

  • output_folder (str) – Path of the output folder

  • criterion_valid (torch.nn.modules.loss._Loss) – Optional loss function specific to the validation set
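How the documented parameters interact (optional validation, periodic checkpoints disabled by a period of 0, and a final save) can be sketched as a framework-free epoch loop. This is a hypothetical outline, not the actual body of run(): the callables train_one_epoch, validate_one_epoch, save and log stand in for the real PyTorch work, the checkpoint names are invented, and treating the end epoch as exclusive is an assumption.

```python
def run_sketch(start_epoch, end_epoch, checkpoint_period,
               train_one_epoch, validate_one_epoch, save, log):
    """Hypothetical epoch loop mirroring the documented behaviour of run().

    - train_one_epoch(epoch) -> training loss
    - validate_one_epoch(epoch) -> validation loss; pass None to skip
      validation (the analogue of valid_loader=None)
    - save(name) persists a checkpoint; log(row) records one log row
    ASSUMPTION: end_epoch is exclusive, as in range().
    """
    for epoch in range(start_epoch, end_epoch):
        row = {"epoch": epoch, "train_loss": train_one_epoch(epoch)}
        if validate_one_epoch is not None:
            row["valid_loss"] = validate_one_epoch(epoch)
        # checkpoint_period == 0 disables intermediary checkpoints
        if checkpoint_period > 0 and (epoch + 1) % checkpoint_period == 0:
            save(f"model_{epoch:03d}")  # hypothetical checkpoint name
        log(row)
    save("model_final")  # hypothetical name for the final model
```

With start_epoch=0, end_epoch=5 and checkpoint_period=2, this sketch saves intermediary checkpoints after epochs 1 and 3 (0-based) plus the final model; with checkpoint_period=0 only the final model is saved.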