bob.ip.binseg.engine.ssltrainer

Functions

guess_labels(unlabelled_images, model)

Calculates the average prediction over two augmentations: horizontal and vertical flips.

linear_rampup(current[, rampup_length])

Slowly ramps up lambda_u (linear schedule).

mix_up(alpha, input, target, ...)

Applies MixUp as described in [MIXMATCH_19].

run(model, data_loader, valid_loader, ...)

Fits an FCN model using semi-supervised learning and saves it to disk.

sharpen(x, T)

square_rampup(current[, rampup_length])

Slowly ramps up lambda_u (square schedule).

bob.ip.binseg.engine.ssltrainer.sharpen(x, T)

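The entry above gives no description of sharpen. Because the module cites [MIXMATCH_19], a plausible reading is the temperature sharpening used in MixMatch: each probability is raised to the power 1/T and renormalised, so T < 1 pushes a prediction towards a one-hot distribution. The following is a minimal NumPy sketch under that assumption (the library itself works on torch.Tensor objects); the axis parameter is a hypothetical addition, and how the library treats single-channel outputs is not documented here.

```python
import numpy as np


def sharpen(x: np.ndarray, T: float, axis: int = 1) -> np.ndarray:
    """Temperature sharpening (assumed MixMatch-style), not the library's code.

    Raises every probability to the power 1/T, then renormalises so the
    values along ``axis`` (the channel axis of an [n, c, h, w] array by
    default) sum to one again. T < 1 sharpens; T = 1 leaves x unchanged
    when x is already normalised along ``axis``.
    """
    if T <= 0:
        raise ValueError("temperature T must be positive")
    powered = x ** (1.0 / T)
    return powered / powered.sum(axis=axis, keepdims=True)


# One pixel with two class probabilities, shape [n=1, c=2, h=1, w=1].
p = np.array([[[[0.6]], [[0.4]]]])
q = sharpen(p, T=0.5)
print(q[0, :, 0, 0])  # 0.36/0.52 and 0.16/0.52, i.e. about 0.692 and 0.308
```

With T = 0.5 the larger probability grows from 0.6 to roughly 0.69, which is the "lower entropy" effect MixMatch relies on for its guessed labels.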
bob.ip.binseg.engine.ssltrainer.mix_up(alpha, input, target, unlabelled_input, unlabled_target)

Applies MixUp as described in [MIXMATCH_19].

Parameters
Returns

Return type

list
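A minimal NumPy sketch of the MixUp step from MixMatch, assuming mix_up follows Algorithm 1 of [MIXMATCH_19]: a weight λ is drawn from Beta(alpha, alpha) and replaced by max(λ, 1 - λ), the labelled and unlabelled batches are pooled and shuffled, and each batch is mixed with its own share of the shuffled pool. The rng parameter is a hypothetical addition for reproducibility, drawing a single λ per call is a simplification, and the real function operates on torch.Tensor objects.

```python
import numpy as np


def mix_up(alpha, input, target, unlabelled_input, unlabled_target, rng=None):
    """MixUp step of MixMatch, sketched with NumPy (not the library's code).

    Parameter names mirror the documented signature; ``rng`` is hypothetical.
    Returns [mixed_input, mixed_target, mixed_unlabelled_input,
    mixed_unlabelled_target].
    """
    rng = np.random.default_rng() if rng is None else rng
    lam = rng.beta(alpha, alpha)
    lam = max(lam, 1.0 - lam)  # keep each mix closer to its own sample

    # Pool labelled and unlabelled samples (W in MixMatch) and shuffle.
    w_inputs = np.concatenate([input, unlabelled_input], axis=0)
    w_targets = np.concatenate([target, unlabled_target], axis=0)
    idx = rng.permutation(len(w_inputs))
    n = len(input)

    # The first n shuffled entries are mixed into the labelled batch.
    mixed_input = lam * input + (1.0 - lam) * w_inputs[idx[:n]]
    mixed_target = lam * target + (1.0 - lam) * w_targets[idx[:n]]
    # The remaining entries are mixed into the unlabelled batch.
    mixed_u_input = lam * unlabelled_input + (1.0 - lam) * w_inputs[idx[n:]]
    mixed_u_target = lam * unlabled_target + (1.0 - lam) * w_targets[idx[n:]]
    return [mixed_input, mixed_target, mixed_u_input, mixed_u_target]


rng = np.random.default_rng(0)
x = np.ones((2, 3, 4, 4))       # labelled images
y = np.ones((2, 1, 4, 4))       # their ground-truth maps
u = np.zeros((3, 3, 4, 4))      # unlabelled images
g = np.full((3, 1, 4, 4), 0.5)  # guessed labels for u
out = mix_up(0.75, x, y, u, g, rng=rng)
print([o.shape for o in out])   # each output keeps the shape of its input
```

Because λ ≥ 0.5, each mixed sample stays closer to its original than to the sample it was mixed with, which keeps labelled and unlabelled examples distinguishable for their separate loss terms.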

bob.ip.binseg.engine.ssltrainer.square_rampup(current, rampup_length=16)

Slowly ramps up lambda_u (square schedule).

Parameters
  • current (int) – current epoch

  • rampup_length (int, optional) – number of epochs over which to ramp up; defaults to 16

Returns

factor – ramp-up factor

Return type

float
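The exact formula is not given above. A minimal sketch, assuming the factor is the progress ratio current / rampup_length clipped to [0, 1] and then squared, with rampup_length = 0 meaning no ramp-up; the returned factor would then scale lambda_u.

```python
def square_rampup(current, rampup_length=16):
    """Quadratic ramp-up factor in [0, 1] (assumed formula, not the library's code)."""
    if rampup_length == 0:
        return 1.0  # no ramp-up: full weight from the first epoch
    # Clip the progress ratio to [0, 1], then square it.
    ratio = min(max(current / float(rampup_length), 0.0), 1.0)
    return ratio ** 2


print([square_rampup(epoch) for epoch in (0, 4, 8, 16, 32)])
# → [0.0, 0.0625, 0.25, 1.0, 1.0]
```

Compared with a linear schedule, squaring keeps the factor smaller early on: halfway through the ramp it is 0.25 rather than 0.5.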

bob.ip.binseg.engine.ssltrainer.linear_rampup(current, rampup_length=16)

Slowly ramps up lambda_u (linear schedule).

Parameters
  • current (int) – current epoch

  • rampup_length (int, optional) – number of epochs over which to ramp up; defaults to 16

Returns

factor – ramp-up factor

Return type

float
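As with the square variant, the formula is not stated above. A minimal sketch, assuming the factor is simply current / rampup_length clipped to [0, 1], with rampup_length = 0 meaning no ramp-up.

```python
def linear_rampup(current, rampup_length=16):
    """Linear ramp-up factor in [0, 1] (assumed formula, not the library's code)."""
    if rampup_length == 0:
        return 1.0  # no ramp-up: full weight from the first epoch
    # Fraction of the ramp completed, clipped to [0, 1].
    return min(max(current / float(rampup_length), 0.0), 1.0)


print([linear_rampup(epoch) for epoch in (0, 4, 8, 16, 32)])
# → [0.0, 0.25, 0.5, 1.0, 1.0]
```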

bob.ip.binseg.engine.ssltrainer.guess_labels(unlabelled_images, model)

Calculates the average prediction over two augmentations: horizontal and vertical flips.

Parameters
Returns

averaged predictions, of shape [n, c, h, w]

Return type

torch.Tensor
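A minimal NumPy sketch of the flip averaging described above, assuming model is any callable that returns logits of shape [n, c, h, w] and that predictions pass through a sigmoid (binary segmentation). Each flipped input's prediction is flipped back before averaging so all maps are pixel-aligned. Including the unflipped prediction in the average is an assumption of this sketch, as is the helper sigmoid; the library itself works on torch.Tensor objects.

```python
import numpy as np


def sigmoid(z):
    """Hypothetical helper: element-wise logistic function."""
    return 1.0 / (1.0 + np.exp(-z))


def guess_labels(unlabelled_images, model):
    """Flip-averaged label guess (NumPy sketch, not the library's code).

    ``unlabelled_images`` has shape [n, c, h, w]; ``model`` is any callable
    mapping such a batch to logits of shape [n, c_out, h, w].
    """
    plain = sigmoid(model(unlabelled_images))
    # Horizontal flip (width axis): predict, then flip the prediction back.
    hflip = np.flip(sigmoid(model(np.flip(unlabelled_images, axis=3))), axis=3)
    # Vertical flip (height axis): predict, then flip the prediction back.
    vflip = np.flip(sigmoid(model(np.flip(unlabelled_images, axis=2))), axis=2)
    # Pixel-aligned average; including ``plain`` is an assumption.
    return np.mean(np.stack([plain, hflip, vflip]), axis=0)


batch = np.random.default_rng(0).normal(size=(2, 1, 5, 7))
# With an identity "model" the flips cancel, so the guess equals sigmoid(batch).
guess = guess_labels(batch, model=lambda images: images)
print(guess.shape, np.allclose(guess, sigmoid(batch)))
```

The identity-model check is a cheap way to confirm that every flip is undone on the matching axis; a mismatched axis would make the three maps disagree.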

bob.ip.binseg.engine.ssltrainer.run(model, data_loader, valid_loader, optimizer, criterion, scheduler, checkpointer, checkpoint_period, device, arguments, output_folder, rampup_length)

Fits an FCN model using semi-supervised learning and saves it to disk.

This function supports periodic checkpointing and writes a CSV-formatted log that tracks the evolution of selected metrics during training.

Parameters
  • model (torch.nn.Module) – Network (e.g. driu, hed, unet)

  • data_loader (torch.utils.data.DataLoader) – loader used to train the model

  • valid_loader (torch.utils.data.DataLoader) – loader used to validate the model and enable automatic checkpointing. If set to None, no validation is performed.

  • optimizer (torch.optim) – optimizer used to update the model parameters

  • criterion (torch.nn.modules.loss._Loss) – loss function

  • scheduler (torch.optim) – learning rate scheduler

  • checkpointer (bob.ip.binseg.utils.checkpointer.Checkpointer) – checkpointer implementation

  • checkpoint_period (int) – save a checkpoint every checkpoint_period epochs. If set to 0 (zero), no intermediate checkpoints are saved.

  • device (str) – device to use, e.g. 'cpu' or 'cuda:0'

  • arguments (dict) – start and end epochs

  • output_folder (str) – output path

  • rampup_length (int) – number of epochs over which lambda_u is ramped up
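The control flow described above can be pictured with a pure-Python skeleton. It is hypothetical throughout: run_sketch, train_one_epoch, validate, save_checkpoint and rampup_fn stand in for the model, loaders, optimizer, criterion, scheduler and checkpointer that run() receives, and the checkpoint names and CSV columns are illustrative only. It shows how checkpoint_period = 0 disables intermediate checkpoints, how a missing validation set (valid_loader=None) skips validation and best-model saving, and where a ramp-up factor for lambda_u enters each epoch.

```python
import csv
import io


def run_sketch(train_one_epoch, validate, save_checkpoint, rampup_fn,
               start_epoch, end_epoch, checkpoint_period, rampup_length,
               log_file):
    """Hypothetical skeleton of a semi-supervised training loop.

    train_one_epoch(epoch, ramp_up_factor) -> training loss (float)
    validate(epoch) -> validation loss (float); pass None to skip validation
    save_checkpoint(name) -> None, persists the current model state
    """
    writer = csv.writer(log_file)
    writer.writerow(["epoch", "ramp_up_factor", "train_loss", "valid_loss"])
    best_valid = float("inf")

    for epoch in range(start_epoch, end_epoch):
        # The weight of the unlabelled loss term grows with this factor.
        factor = rampup_fn(epoch, rampup_length)
        train_loss = train_one_epoch(epoch, factor)

        valid_loss = None
        if validate is not None:
            valid_loss = validate(epoch)
            # Automatic checkpointing: keep the model with the lowest validation loss.
            if valid_loss < best_valid:
                best_valid = valid_loss
                save_checkpoint("model_lowest_valid_loss")

        # A period of 0 disables intermediate checkpoints.
        if checkpoint_period and (epoch + 1) % checkpoint_period == 0:
            save_checkpoint(f"model_{epoch + 1:03d}")

        writer.writerow([epoch, factor, train_loss,
                         "" if valid_loss is None else valid_loss])

    save_checkpoint("model_final")


saved = []
log = io.StringIO()
run_sketch(
    train_one_epoch=lambda epoch, factor: 1.0 / (epoch + 1),
    validate=None,
    save_checkpoint=saved.append,
    rampup_fn=lambda current, length: min(current / length, 1.0),
    start_epoch=0, end_epoch=6, checkpoint_period=2, rampup_length=4,
    log_file=log,
)
print(saved)  # → ['model_002', 'model_004', 'model_006', 'model_final']
```

The ramp-up schedule is passed in as rampup_fn because the documentation above does not say which of linear_rampup or square_rampup run() uses.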