Coverage for src/deepdraw/script/experiment.py: 90%
49 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5import os
6import shutil
8import click
10from clapper.click import ConfigCommand, ResourceOption, verbosity_option
11from clapper.logging import setup
# One logger for the whole package: derive the root package name from this
# module's dotted name so sibling scripts share the same logging handler.
logger = setup(
    __name__.partition(".")[0], format="%(levelname)s: %(message)s"
)
15from .common import save_sh_command
@click.command(
    entry_point_group="deepdraw.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Trains an M2U-Net model (VGG-16 backbone) with DRIVE (vessel
       segmentation), on the CPU, for only two epochs, then runs inference and
       evaluation on stock datasets, report performance as a table and a figure:

       .. code:: sh

          $ deepdraw experiment -vv m2unet drive --epochs=2
""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store experiment outputs (created if does not exist)",
    required=True,
    type=click.Path(),
    default="results",
    cls=ResourceOption,
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be trained, and then evaluated",
    required=True,
    cls=ResourceOption,
)
# NOTE: no short flag here on purpose -- "-d" is taken by --device below, and
# Click's parser keeps only the last registration of a short name, so a "-d"
# on --dataset could never be reached from the command line.
@click.option(
    "--dataset",
    help="A dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances implementing datasets "
    "to be used for training and validating the model, possibly including all "
    "pre-processing pipelines required or, optionally, a dictionary mapping "
    "string keys to torch.utils.data.dataset.Dataset instances. At least "
    "one key named ``train`` must be available. This dataset will be used for "
    "training the network model. The dataset description must include all "
    "required pre-processing, including eventual data augmentation. If a "
    "dataset named ``__train__`` is available, it is used prioritarily for "
    "training instead of ``train``. If a dataset named ``__valid__`` is "
    "available, it is used for model validation (and automatic "
    "check-pointing) at each epoch. If a dataset list named "
    "``__valid_extra__`` is available, then it will be tracked during the "
    "validation process and its loss output at the training log as well, "
    "in the format of an array occupying a single column. All other keys "
    "are considered test datasets and only used during analysis, to report "
    "the final system performance",
    required=True,
    cls=ResourceOption,
)
# NOTE: no short flag here on purpose -- "-S" is taken by --steps below (same
# shadowing problem as --dataset/--device).
@click.option(
    "--second-annotator",
    help="A dataset or dictionary, like in --dataset, with the same "
    "sample keys, but with annotations from a different annotator that is "
    "going to be compared to the one in --dataset",
    required=False,
    default=None,
    cls=ResourceOption,
    show_default=True,
)
@click.option(
    "--optimizer",
    help="A torch.optim.Optimizer that will be used to train the network",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--criterion",
    help="A loss function to compute the FCN error for every sample "
    "respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--scheduler",
    help="A learning rate scheduler that drives changes in the learning "
    "rate depending on the FCN state (see torch.optim.lr_scheduler)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects "
    "memory requirements for the network). If the number of samples in "
    "the batch is larger than the total number of samples available for "
    "training, this value is truncated. If this number is smaller, then "
    "batches of the specified size are created and fed to the network "
    "until there are no more new samples to feed (epoch is finished). "
    "If the total number of training samples is not a multiple of the "
    "batch-size, the last batch will be smaller than the first, unless "
    "--drop-incomplete-batch is set, in which case this batch is not used.",
    required=True,
    show_default=True,
    default=2,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--batch-chunk-count",
    "-c",
    help="Number of chunks in every batch (this parameter affects "
    "memory requirements for the network). The number of samples "
    "loaded for every iteration will be batch-size/batch-chunk-count. "
    "batch-size needs to be divisible by batch-chunk-count, otherwise an "
    "error will be raised. This parameter is used to reduce number of "
    "samples loaded in each iteration, in order to reduce the memory usage "
    "in exchange for processing time (more iterations). This is especially "
    "useful when one is running with GPUs with limited RAM. The "
    "default of 1 forces the whole batch to be processed at once. Otherwise "
    "the batch is broken into batch-chunk-count pieces, and gradients are "
    "accumulated to complete each batch.",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--drop-incomplete-batch/--no-drop-incomplete-batch",
    "-D",
    help="If set, then may drop the last batch in an epoch, in case it is "
    "incomplete. If you set this option, you should also consider "
    "increasing the total number of epochs of training, as the total number "
    "of training steps may be reduced",
    required=True,
    show_default=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--epochs",
    "-e",
    help="Number of epochs (complete training set passes) to train for. "
    "If continuing from a saved checkpoint, ensure to provide a greater "
    "number of epochs than that saved on the checkpoint to be loaded. ",
    show_default=True,
    required=True,
    default=1000,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--checkpoint-period",
    "-p",
    help="Number of epochs after which a checkpoint is saved. "
    "A value of zero will disable check-pointing. If checkpointing is "
    "enabled and training stops, it is automatically resumed from the "
    "last saved checkpoint if training is restarted with the same "
    "configuration.",
    show_default=True,
    required=True,
    default=0,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--device",
    "-d",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--seed",
    "-s",
    help="Seed to use for the random number generator",
    show_default=True,
    required=False,
    default=42,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--parallel",
    "-P",
    help="""Use multiprocessing for data loading and processing: if set to -1
    (default), disables multiprocessing altogether.  Set to 0 to enable as many
    data loading instances as processing cores as available in the system.  Set
    to >= 1 to enable that many multiprocessing instances for data
    processing.""",
    type=click.IntRange(min=-1),
    show_default=True,
    required=True,
    default=-1,
    cls=ResourceOption,
)
@click.option(
    "--monitoring-interval",
    "-I",
    help="""Time between checks for the use of resources during each training
    epoch.  An interval of 5 seconds, for example, will lead to CPU and GPU
    resources being probed every 5 seconds during each training epoch.
    Values registered in the training logs correspond to averages (or maxima)
    observed through possibly many probes in each epoch.  Notice that setting a
    very small value may cause the probing process to become extremely busy,
    potentially biasing the overall perception of resource usage.""",
    type=click.FloatRange(min=0.1),
    show_default=True,
    required=True,
    default=5.0,
    cls=ResourceOption,
)
@click.option(
    "--overlayed/--no-overlayed",
    "-O",
    help="Creates overlayed representations of the output probability maps, "
    "similar to --overlayed in prediction-mode, except it includes "
    "distinctive colours for true and false positives and false negatives.  "
    "If not set, or empty then do **NOT** output overlayed images.",
    show_default=True,
    default=False,
    required=False,
    cls=ResourceOption,
)
@click.option(
    "--steps",
    "-S",
    help="This number is used to define the number of threshold steps to "
    "consider when evaluating the highest possible F1-score on test data.",
    default=1000,
    show_default=True,
    required=True,
    # explicit type (instead of relying on inference from the default) so a
    # non-positive number of threshold steps is rejected at parse time
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--plot-limits",
    "-L",
    help="""If set, this option affects the performance comparison plots.  It
    must be a 4-tuple containing the bounds of the plot for the x and y axis
    respectively (format: [x_low, x_high, y_low, y_high]).  If not set, use
    normal bounds ([0, 1, 0, 1]) for the performance curve.""",
    # immutable default: Click hands a tuple to the callback either way
    default=(0.0, 1.0, 0.0, 1.0),
    show_default=True,
    nargs=4,
    type=float,
    cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption)
@click.pass_context
def experiment(
    ctx,
    model,
    optimizer,
    scheduler,
    output_folder,
    epochs,
    batch_size,
    batch_chunk_count,
    drop_incomplete_batch,
    criterion,
    dataset,
    second_annotator,
    checkpoint_period,
    device,
    seed,
    parallel,
    monitoring_interval,
    overlayed,
    steps,
    plot_limits,
    verbose,
    **kwargs,
):
    """Runs a complete experiment, from training, to prediction and evaluation.

    This script is just a wrapper around the individual scripts for training,
    running prediction, evaluating and comparing FCN model performance.  It
    organises the output in a preset way::

    \b
       └─ <output-folder>/
          ├── command.sh       #command-line used to run this experiment
          ├── model/           #the generated model will be here
          ├── predictions/     #the prediction outputs for the train/test set
          ├── overlayed/       #the overlayed outputs for the train/test set
          │  ├── predictions/      #predictions overlayed on the input images
          │  ├── analysis/         #predictions overlayed on the input images
          │  │                     #including analysis of false positives,
          │  │                     #negatives and true positives
          │  └── second-annotator/ #if set, store overlayed images for the
          │                        #second annotator here
          └── analysis/        #the outputs of the analysis of both train/test
                               #sets; includes second-annotator "measures" as
                               #well, if configured

    If ``<output-folder>/command.sh`` already exists, it is first moved to
    ``command.sh~`` (replacing any older backup) before the current
    command-line is recorded.

    Training is performed for a configurable number of epochs, and generates at
    least a ``model_final_epoch.pth``.  It may also generate a number of
    intermediate checkpoints.  Checkpoints are model files (.pth files) that are
    stored during the training and useful to resume the procedure in case it
    stops abruptly.  Analysis uses ``model_lowest_valid_loss.pth`` when it
    exists and falls back to ``model_final_epoch.pth`` otherwise.

    N.B.: The tool is designed to prevent analysis bias and allows one to
    provide (potentially multiple) separate subsets for training,
    validation, and evaluation.  Instead of using simple datasets, datasets
    for full experiment running should be dictionaries with specific subset
    names:

    * ``__train__``: dataset used for training, prioritarily.  It is typically
      the dataset containing data augmentation pipelines.
    * ``__valid__``: dataset used for validation.  It is typically disjoint
      from the training and test sets.  In such a case, we checkpoint the model
      with the lowest loss on the validation set as well, throughout all the
      training, besides the model at the end of training.
    * ``train`` (optional): a copy of the ``__train__`` dataset, without data
      augmentation, that will be evaluated alongside other sets available
    * ``__valid_extra__``: a list of datasets that are tracked during
      validation, but do not affect checkpointing.  If present, an extra
      column with an array containing the loss of each set is kept on the
      training log.
    * ``*``: any other name, not starting with an underscore character (``_``),
      will be considered a test set for evaluation.

    N.B.2: The threshold used for calculating the F1-score on the test set, or
    overlay analysis (false positives, negatives and true positives overprinted
    on the original image) also follows the logic above.
    """

    # Keep exactly one previous copy of the reproduction script around, so a
    # re-run does not silently lose the command-line of the earlier run.
    command_sh = os.path.join(output_folder, "command.sh")
    if os.path.exists(command_sh):
        backup = command_sh + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(command_sh, backup)
    save_sh_command(command_sh)

    # training
    logger.info("Started training")

    # sub-commands are imported lazily so that importing this module (e.g. to
    # build the CLI) does not pull in the heavier training/analysis modules
    from .train import train

    train_output_folder = os.path.join(output_folder, "model")
    ctx.invoke(
        train,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        output_folder=train_output_folder,
        epochs=epochs,
        batch_size=batch_size,
        batch_chunk_count=batch_chunk_count,
        drop_incomplete_batch=drop_incomplete_batch,
        criterion=criterion,
        dataset=dataset,
        checkpoint_period=checkpoint_period,
        device=device,
        seed=seed,
        parallel=parallel,
        monitoring_interval=monitoring_interval,
        verbose=verbose,
    )
    logger.info("Ended training")

    from .train_analysis import train_analysis

    # summarise the training log (written by `train` above) into a PDF
    ctx.invoke(
        train_analysis,
        log=os.path.join(train_output_folder, "trainlog.csv"),
        constants=os.path.join(train_output_folder, "constants.csv"),
        output_pdf=os.path.join(train_output_folder, "trainlog.pdf"),
        verbose=verbose,
    )

    from .analyze import analyze

    # preferably, we use the best model on the validation set
    # otherwise, we get the last saved model
    model_file = os.path.join(
        train_output_folder, "model_lowest_valid_loss.pth"
    )
    if not os.path.exists(model_file):
        model_file = os.path.join(train_output_folder, "model_final_epoch.pth")

    ctx.invoke(
        analyze,
        model=model,
        output_folder=output_folder,
        batch_size=batch_size,
        dataset=dataset,
        second_annotator=second_annotator,
        device=device,
        overlayed=overlayed,
        weight=model_file,
        steps=steps,
        parallel=parallel,
        plot_limits=plot_limits,
        verbose=verbose,
    )