mednet.engine.trainer
Functions
- run: Fit a CNN model using supervised learning and save it to disk.
- mednet.engine.trainer.run(model, datamodule, validation_period, device_manager, max_epochs, output_folder, monitoring_interval, batch_chunk_count, checkpoint)
Fit a CNN model using supervised learning and save it to disk.
This function supports periodic checkpointing and writes a CSV-formatted log that tracks the evolution of selected figures during training.
- Parameters:
  - model (LightningModule) – Neural network model (e.g. pasa).
  - datamodule (LightningDataModule) – The Lightning DataModule to use for training and validation.
  - validation_period (int) – Number of epochs between validation runs. By default, validation runs after every training epoch (period=1); increase this value to validate less often. Note that this affects checkpoint saving: checkpoints are created after every training step, independently of validation runs (the last training step always overwrites the latest checkpoint), but the selection of the ‘best’ model obtained so far relies on validation results and is therefore influenced by this setting.
  - device_manager (DeviceManager) – An internal device representation used for training and validation. It can be converted into a PyTorch device or a Lightning accelerator setup.
  - max_epochs (int) – The maximum number of epochs to train for.
  - output_folder (Path) – Folder in which the results are saved.
  - monitoring_interval (int | float) – Interval, in seconds (fractions allowed), at which resource usage is monitored during training.
  - batch_chunk_count (int) – If different from 1, each batch is split into this many chunks, and gradients are accumulated over the chunks before each mini-batch update. This is useful when GPU memory is limited but you want to keep training with larger batches; the trade-off is longer processing time.
  - checkpoint (Path | None) – Path to an optional checkpoint file to load.
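To illustrate how validation_period interacts with best-model selection, the following plain-Python sketch computes which epochs run validation and which epoch would be picked as ‘best’ from a set of made-up validation losses. It is not mednet's code: the helpers validation_epochs and best_epoch are hypothetical, and the sketch assumes that validation fires at the end of every validation_period-th epoch (the convention of Lightning's check_val_every_n_epoch) and that ‘best’ means lowest validation loss.

```python
"""Hypothetical sketch: how a validation period changes which epochs can
produce a 'best' checkpoint. Not mednet's implementation."""


def validation_epochs(max_epochs: int, validation_period: int) -> list[int]:
    """Return the 0-based epochs after which validation runs.

    Assumption: validation happens at the end of every
    ``validation_period``-th epoch, i.e. when (epoch + 1) is a multiple
    of the period (Lightning's ``check_val_every_n_epoch`` convention).
    """
    if validation_period < 1:
        raise ValueError("validation_period must be >= 1")
    return [e for e in range(max_epochs) if (e + 1) % validation_period == 0]


def best_epoch(val_losses: dict[int, float]) -> int:
    """Pick the validated epoch with the lowest validation loss."""
    return min(val_losses, key=val_losses.__getitem__)


# Made-up per-epoch validation losses, for illustration only.
losses = [0.90, 0.70, 0.55, 0.60, 0.50, 0.52]

for period in (1, 2, 3):
    epochs = validation_epochs(len(losses), period)
    # Only epochs that were actually validated can be selected as 'best'.
    observed = {e: losses[e] for e in epochs}
    print(f"period={period} validated={epochs} best={best_epoch(observed)}")
```

With these toy numbers, validating every epoch selects epoch 4, whereas validating every second or third epoch never observes epoch 4 and settles on epoch 5 instead.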
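The gradient accumulation behind batch_chunk_count can be sketched in plain Python, assuming a one-parameter linear model with a mean-squared-error loss. Each chunk's gradient is weighted by its share of the batch and summed, so a single update sees the same gradient as the full batch while only one chunk is processed at a time. The helpers mse_grad, split and accumulated_grad are hypothetical; mednet may delegate this to Lightning rather than implement it this way.

```python
"""Hypothetical sketch of gradient accumulation over batch chunks.

Pure Python, 1-parameter linear model y = w * x with mean-squared-error
loss. Not mednet's implementation.
"""


def mse_grad(w: float, xs: list[float], ys: list[float]) -> float:
    """Gradient of mean((w*x - y)**2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def split(seq: list[float], chunks: int) -> list[list[float]]:
    """Split seq into `chunks` contiguous, nearly equal parts."""
    size, rem = divmod(len(seq), chunks)
    out, start = [], 0
    for i in range(chunks):
        # The first `rem` chunks take one extra element.
        end = start + size + (1 if i < rem else 0)
        out.append(seq[start:end])
        start = end
    return out


def accumulated_grad(
    w: float, xs: list[float], ys: list[float], batch_chunk_count: int
) -> float:
    """Sum per-chunk gradients, each weighted by its share of the batch."""
    if not 1 <= batch_chunk_count <= len(xs):
        raise ValueError("batch_chunk_count must be in [1, batch size]")
    total = 0.0
    for cx, cy in zip(split(xs, batch_chunk_count), split(ys, batch_chunk_count)):
        # Only this chunk is "in memory" while its gradient is computed.
        total += mse_grad(w, cx, cy) * len(cx) / len(xs)
    return total


xs = [1.0, 2.0, 3.0, 4.0, 5.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1]
w = 0.5

full = mse_grad(w, xs, ys)
for k in (1, 2, 5):
    print(f"chunks={k} accumulated={accumulated_grad(w, xs, ys, k):.6f} full={full:.6f}")
```

In this toy model the accumulated gradient matches the full-batch gradient up to floating-point rounding for every chunk count; what grows is the number of forward and backward passes per update, which is where the extra processing time comes from.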