Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# coding=utf-8
4import logging
5import multiprocessing
6import sys
8import click
9import torch
11from torch.utils.data import DataLoader
13from bob.extension.scripts.click_helper import (
14 ConfigCommand,
15 ResourceOption,
16 verbosity_option,
17)
19from ..utils.checkpointer import Checkpointer
20from .binseg import set_seeds, setup_pytorch_device
22logger = logging.getLogger(__name__)
@click.command(
    entry_point_group="bob.ip.binseg.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Trains a U-Net model (VGG-16 backbone) with DRIVE (vessel segmentation),
       on a GPU (``cuda:0``):

       $ bob binseg train -vv unet drive --batch-size=4 --device="cuda:0"

    2. Trains a HED model with HRF on a GPU (``cuda:0``):

       $ bob binseg train -vv hed hrf --batch-size=8 --device="cuda:0"

    3. Trains a M2U-Net model on the COVD-DRIVE dataset on the CPU:

       $ bob binseg train -vv m2unet covd-drive --batch-size=8

    4. Trains a DRIU model with SSL on the COVD-HRF dataset on the CPU:

       $ bob binseg train -vv --ssl driu-ssl covd-drive-ssl --batch-size=1

""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the generated model (created if does not exist)",
    required=True,
    type=click.Path(),
    default="results",
    cls=ResourceOption,
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be trained",
    required=True,
    cls=ResourceOption,
)
# NOTE: ``--dataset`` deliberately has no short flag.  It used to declare
# ``-d``, which is also declared by ``--device`` below; click keeps only the
# last option registered for a given short flag, so ``-d`` always resolved to
# ``--device`` and the alias advertised here never worked.
@click.option(
    "--dataset",
    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
    "to be used for training the model, possibly including all pre-processing "
    "pipelines required or, optionally, a dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances.  At least one key "
    "named ``train`` must be available.  This dataset will be used for "
    "training the network model.  The dataset description must include all "
    "required pre-processing, including eventual data augmentation.  If a "
    "dataset named ``__train__`` is available, it is used prioritarily for "
    "training instead of ``train``.  If a dataset named ``__valid__`` is "
    "available, it is used for model validation (and automatic check-pointing) "
    "at each epoch.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--optimizer",
    help="A torch.optim.Optimizer that will be used to train the network",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--criterion",
    help="A loss function to compute the FCN error for every sample "
    "respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--scheduler",
    help="A learning rate scheduler that drives changes in the learning "
    "rate depending on the FCN state (see torch.optim.lr_scheduler)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects "
    "memory requirements for the network).  If the number of samples in "
    "the batch is larger than the total number of samples available for "
    "training, this value is truncated.  If this number is smaller, then "
    "batches of the specified size are created and fed to the network "
    "until there are no more new samples to feed (epoch is finished).  "
    "If the total number of training samples is not a multiple of the "
    "batch-size, the last batch will be smaller than the first, unless "
    "--drop-incomplete-batch is set, in which case this batch is not used.",
    required=True,
    show_default=True,
    default=2,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--drop-incomplete-batch/--no-drop-incomplete-batch",
    "-D",
    help="If set, then may drop the last batch in an epoch, in case it is "
    "incomplete.  If you set this option, you should also consider "
    "increasing the total number of epochs of training, as the total number "
    "of training steps may be reduced",
    required=True,
    show_default=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--epochs",
    "-e",
    help="Number of epochs (complete training set passes) to train for. "
    "If continuing from a saved checkpoint, ensure to provide a greater "
    "number of epochs than that saved on the checkpoint to be loaded. ",
    show_default=True,
    required=True,
    default=1000,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--checkpoint-period",
    "-p",
    help="Number of epochs after which a checkpoint is saved. "
    "A value of zero will disable check-pointing. If checkpointing is "
    "enabled and training stops, it is automatically resumed from the "
    "last saved checkpoint if training is restarted with the same "
    "configuration.",
    show_default=True,
    required=True,
    default=0,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--device",
    "-d",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--seed",
    "-s",
    help="Seed to use for the random number generator",
    show_default=True,
    required=False,
    default=42,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--ssl/--no-ssl",
    help="Switch ON/OFF semi-supervised training mode",
    show_default=True,
    required=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--rampup",
    "-r",
    help="Ramp-up length in epochs (for SSL training only)",
    show_default=True,
    required=True,
    default=900,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--multiproc-data-loading",
    "-P",
    help="""Use multiprocessing for data loading: if set to -1 (default),
    disables multiprocessing data loading.  Set to 0 to enable as many data
    loading instances as processing cores as available in the system.  Set to
    >= 1 to enable that many multiprocessing instances for data loading.""",
    type=click.IntRange(min=-1),
    show_default=True,
    required=True,
    default=-1,
    cls=ResourceOption,
)
@verbosity_option(cls=ResourceOption)
def train(
    model,
    optimizer,
    scheduler,
    output_folder,
    epochs,
    batch_size,
    drop_incomplete_batch,
    criterion,
    dataset,
    checkpoint_period,
    device,
    seed,
    ssl,
    rampup,
    multiproc_data_loading,
    verbose,
    **kwargs,
):
    """Trains an FCN to perform binary segmentation

    Training is performed for a configurable number of epochs, and generates at
    least a final_model.pth.  It may also generate a number of intermediate
    checkpoints.  Checkpoints are model files (.pth files) that are stored
    during the training and useful to resume the procedure in case it stops
    abruptly.

    Tip: In case the model has been trained over a number of epochs, it is
    possible to continue training, by simply relaunching the same command, and
    changing the number of epochs to a number greater than the number where
    the original training session stopped (or the last checkpoint was saved).

    """
    # Developer notes live in comments rather than in the docstring because
    # click renders the docstring above as this command's ``--help`` text.
    # Every parameter is documented by the ``help=`` of its matching option.
    # ``verbose`` and ``kwargs`` are accepted but not referenced in this body.

    device = setup_pytorch_device(device)

    set_seeds(seed, all_gpus=False)

    # ``dataset`` is either a single dataset, used as-is for training, or a
    # dictionary of datasets from which the training (and, if present, the
    # validation) subsets are picked.
    use_dataset = dataset
    validation_dataset = None
    if isinstance(dataset, dict):
        if "__train__" in dataset:
            logger.info("Found (dedicated) '__train__' set for training")
            use_dataset = dataset["__train__"]
        else:
            # raises KeyError if neither "__train__" nor "train" is provided
            use_dataset = dataset["train"]

        if "__valid__" in dataset:
            logger.info("Found (dedicated) '__valid__' set for validation")
            logger.info("Will checkpoint lowest loss model on validation set")
            validation_dataset = dataset["__valid__"]

    # Translate the command-line convention into DataLoader's ``num_workers``:
    # negative -> 0 workers, zero -> one worker per CPU core, positive -> that
    # exact number of workers.
    if multiproc_data_loading < 0:
        num_workers = 0
    elif multiproc_data_loading == 0:
        num_workers = multiprocessing.cpu_count()
    else:
        num_workers = multiproc_data_loading

    # keyword arguments shared by the training and validation loaders
    loader_kwargs = {
        "pin_memory": torch.cuda.is_available(),
        "num_workers": num_workers,
    }
    if num_workers > 0 and sys.platform == "darwin":
        # on macOS, worker processes are explicitly started with "spawn"
        loader_kwargs["multiprocessing_context"] = multiprocessing.get_context(
            "spawn"
        )

    data_loader = DataLoader(
        dataset=use_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=drop_incomplete_batch,
        **loader_kwargs,
    )

    valid_loader = None
    if validation_dataset is not None:
        # validation never shuffles nor drops samples
        valid_loader = DataLoader(
            dataset=validation_dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            **loader_kwargs,
        )

    checkpointer = Checkpointer(model, optimizer, scheduler, path=output_folder)

    # Start from epoch 0 unless the loaded checkpoint data says otherwise;
    # ``max_epoch`` always comes from the command line, overriding whatever
    # value the checkpoint data may carry.
    arguments = {"epoch": 0}
    arguments.update(checkpointer.load())
    arguments["max_epoch"] = epochs

    logger.info("Training for %s epochs", arguments["max_epoch"])
    logger.info("Continuing from epoch %s", arguments["epoch"])

    # Both engines take the same leading arguments; the semi-supervised one
    # additionally receives the ramp-up length as its last argument.
    run_args = (
        model,
        data_loader,
        valid_loader,
        optimizer,
        criterion,
        scheduler,
        checkpointer,
        checkpoint_period,
        device,
        arguments,
        output_folder,
    )

    if ssl:
        from ..engine.ssltrainer import run

        run_args += (rampup,)
    else:
        from ..engine.trainer import run

    run(*run_args)