1#!/usr/bin/env python
2# coding=utf-8
3
4import os
5import sys
6import random
7import multiprocessing
8
9import click
10import numpy
11import torch
12from torch.nn import BCEWithLogitsLoss
13from torch.utils.data import DataLoader, WeightedRandomSampler
14
15from ..configs.datasets import get_samples_weights, get_positive_weights
16
17
18from bob.extension.scripts.click_helper import (
19 verbosity_option,
20 ConfigCommand,
21 ResourceOption,
22)
23
24from ..utils.checkpointer import Checkpointer
25from ..engine.trainer import run
26from .tb import download_to_tempfile
27
28import logging
29
# Module-level logger, named after this module; used by every function below
# to report progress, warnings and configuration choices.
logger = logging.getLogger(__name__)
31
32
def setup_pytorch_device(name):
    """Sets-up the pytorch device to use


    Parameters
    ----------

    name : str
        The device name (``cpu``, ``cuda:0``, ``cuda:1``, and so on).  If you
        set a specific cuda device such as ``cuda:1``, then we'll make sure it
        is currently set.


    Returns
    -------

    device : :py:class:`torch.device`
        The pytorch device to use, pre-configured (and checked)


    Raises
    ------

    RuntimeError
        If a specific CUDA device was requested (``cuda:N``) but CUDA is not
        available, or if the default CUDA device was requested (``cuda``) but
        the environment variable ``CUDA_VISIBLE_DEVICES`` is not set.

    """

    if name.startswith("cuda:"):
        # In case one has multiple devices, we must first set the one
        # we would like to use so pytorch can find it.
        logger.info(f"User set device to '{name}' - trying to force device...")
        os.environ["CUDA_VISIBLE_DEVICES"] = name.split(":", 1)[1]
        if not torch.cuda.is_available():
            raise RuntimeError(
                f"CUDA is not currently available, but "
                f"you set device to '{name}'"
            )
        # Let pytorch auto-select from environment variable
        return torch.device("cuda")

    if name.startswith("cuda"):  # use default device
        logger.info(f"User set device to '{name}' - using default CUDA device")
        # Explicit check instead of ``assert``: assertions are stripped when
        # python runs with ``-O``, which would silently skip this validation.
        if os.environ.get("CUDA_VISIBLE_DEVICES") is None:
            raise RuntimeError(
                f"You set device to '{name}', but the environment variable "
                f"CUDA_VISIBLE_DEVICES is not set"
            )

    # cuda or cpu
    return torch.device(name)
73
74
def set_seeds(value, all_gpus=False):
    """Sets up all relevant random seeds (numpy, python, cuda)

    If running with multiple GPUs **at the same time**, set ``all_gpus`` to
    ``True`` to force all GPU seeds to be initialized.

    Reference: `PyTorch page for reproducibility
    <https://pytorch.org/docs/stable/notes/randomness.html>`_.


    Parameters
    ----------

    value : int
        The random seed value to use

    all_gpus : :py:class:`bool`, Optional
        If set, then reset the seed on all GPUs available at once.  This is
        normally **not** what you want if running on a single GPU.  Defaults
        to ``False``.

    """

    random.seed(value)  # python's own generator
    numpy.random.seed(value)  # numpy's global generator
    torch.manual_seed(value)  # pytorch generator
    torch.cuda.manual_seed(value)  # noop if cuda not available

    # set seeds for all gpus
    if all_gpus:
        torch.cuda.manual_seed_all(value)  # noop if cuda not available
105
106
def set_reproducible_cuda():
    """Disables the cuDNN settings that may hinder reproducibility

    To get fully reproducible runs, you should also avoid multiple (parallel)
    data loaders, i.e., setup ``num_workers=0``.

    Reference: `PyTorch page for reproducibility
    <https://pytorch.org/docs/stable/notes/randomness.html>`_.

    """

    cudnn = torch.backends.cudnn

    # restrict cuDNN to algorithms known to be deterministic (not random)
    cudnn.deterministic = True

    # switch off any optimization tricks
    cudnn.benchmark = False
125
126
@click.command(
    entry_point_group="bob.med.tb.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Trains PASA model with Montgomery dataset,
       on a GPU (``cuda:0``):

       $ bob tb train -vv pasa montgomery --batch-size=4 --device="cuda:0"

""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the generated model (created if does not exist)",
    required=True,
    type=click.Path(),
    default="results",
    cls=ResourceOption,
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be trained",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--dataset",
    "-d",
    help="A dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances implementing datasets "
    "to be used for training and validating the model, possibly including all "
    "pre-processing pipelines required or, optionally, a dictionary mapping "
    "string keys to torch.utils.data.dataset.Dataset instances.  At least "
    "one key named ``train`` must be available.  This dataset will be used for "
    "training the network model.  The dataset description must include all "
    "required pre-processing, including eventual data augmentation.  If a "
    "dataset named ``__train__`` is available, it is used prioritarily for "
    "training instead of ``train``.  If a dataset named ``__valid__`` is "
    "available, it is used for model validation (and automatic "
    "check-pointing) at each epoch.  If a dataset list named "
    "``__extra_valid__`` is available, then it will be tracked during the "
    "validation process and its loss output at the training log as well, "
    "in the format of an array occupying a single column.  All other keys "
    "are considered test datasets and are ignored during training",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--optimizer",
    help="A torch.optim.Optimizer that will be used to train the network",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--criterion",
    help="A loss function to compute the CNN error for every sample "
    "respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--criterion-valid",
    # each fragment ends with a space so the concatenated help text does not
    # glue words together
    help="A specific loss function for the validation set to compute the CNN "
    "error for every sample respecting the PyTorch API for loss functions "
    "(see torch.nn.modules.loss)",
    required=False,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects "
    "memory requirements for the network).  If the number of samples in "
    "the batch is larger than the total number of samples available for "
    "training, this value is truncated.  If this number is smaller, then "
    "batches of the specified size are created and fed to the network "
    "until there are no more new samples to feed (epoch is finished).  "
    "If the total number of training samples is not a multiple of the "
    "batch-size, the last batch will be smaller than the first, unless "
    "--drop-incomplete-batch is set, in which case this batch is not used.",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--batch-chunk-count",
    "-c",
    help="Number of chunks in every batch (this parameter affects "
    "memory requirements for the network). The number of samples "
    "loaded for every iteration will be batch-size/batch-chunk-count. "
    "batch-size needs to be divisible by batch-chunk-count, otherwise an "
    "error will be raised. This parameter is used to reduce number of "
    "samples loaded in each iteration, in order to reduce the memory usage "
    "in exchange for processing time (more iterations).  This is specially "
    "interesting whe one is running with GPUs with limited RAM. The "
    "default of 1 forces the whole batch to be processed at once.  Otherwise "
    "the batch is broken into batch-chunk-count pieces, and gradients are "
    "accumulated to complete each batch.",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--drop-incomplete-batch/--no-drop-incomplete-batch",
    "-D",
    help="If set, then may drop the last batch in an epoch, in case it is "
    "incomplete.  If you set this option, you should also consider "
    "increasing the total number of epochs of training, as the total number "
    "of training steps may be reduced",
    required=True,
    show_default=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--epochs",
    "-e",
    help="Number of epochs (complete training set passes) to train for. "
    "If continuing from a saved checkpoint, ensure to provide a greater "
    "number of epochs than that saved on the checkpoint to be loaded. ",
    show_default=True,
    required=True,
    default=1000,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--checkpoint-period",
    "-p",
    help="Number of epochs after which a checkpoint is saved. "
    "A value of zero will disable check-pointing. If checkpointing is "
    "enabled and training stops, it is automatically resumed from the "
    "last saved checkpoint if training is restarted with the same "
    "configuration.",
    show_default=True,
    required=True,
    default=0,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
# NOTE(review): the short flag ``-d`` below is also declared by ``--dataset``
# above.  It is kept untouched for CLI backward-compatibility - confirm which
# option click actually resolves ``-d`` to, and consider retiring one of them.
@click.option(
    "--device",
    "-d",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--seed",
    "-s",
    help="Seed to use for the random number generator",
    show_default=True,
    required=False,
    default=42,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--parallel",
    "-P",
    help="""Use multiprocessing for data loading: if set to -1 (default),
    disables multiprocessing data loading.  Set to 0 to enable as many data
    loading instances as processing cores as available in the system.  Set to
    >= 1 to enable that many multiprocessing instances for data loading.""",
    type=click.IntRange(min=-1),
    show_default=True,
    required=True,
    default=-1,
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
    help="Path or URL to pretrained model file (.pth extension)",
    required=False,
    cls=ResourceOption,
)
@click.option(
    "--normalization",
    "-n",
    help="Z-Normalization of input images: 'imagenet' for ImageNet parameters,"
    " 'current' for parameters of the current trainset, "
    "'none' for no normalization.",
    required=False,
    default="none",
    cls=ResourceOption,
)
@click.option(
    "--monitoring-interval",
    "-I",
    help="""Time between checks for the use of resources during each training
    epoch.  An interval of 5 seconds, for example, will lead to CPU and GPU
    resources being probed every 5 seconds during each training epoch.
    Values registered in the training logs correspond to averages (or maxima)
    observed through possibly many probes in each epoch.  Notice that setting a
    very small value may cause the probing process to become extremely busy,
    potentially biasing the overall perception of resource usage.""",
    type=click.FloatRange(min=0.1),
    show_default=True,
    required=True,
    default=5.0,
    cls=ResourceOption,
)
@verbosity_option(cls=ResourceOption)
def train(
    model,
    optimizer,
    output_folder,
    epochs,
    batch_size,
    batch_chunk_count,
    drop_incomplete_batch,
    criterion,
    criterion_valid,
    dataset,
    checkpoint_period,
    device,
    seed,
    parallel,
    weight,
    normalization,
    monitoring_interval,
    verbose,
    **kwargs,
):
    """Trains a CNN to perform tuberculosis detection

    Training is performed for a configurable number of epochs, and generates at
    least a final_model.pth.  It may also generate a number of intermediate
    checkpoints.  Checkpoints are model files (.pth files) that are stored
    during the training and useful to resume the procedure in case it stops
    abruptly.
    """
    # NOTE: the docstring above is what click displays as the command help, so
    # it is kept short on purpose; every parameter is documented through the
    # ``help`` text of its corresponding option.

    device = setup_pytorch_device(device)

    set_seeds(seed, all_gpus=False)

    # ``dataset`` may be a single dataset or a dictionary of datasets; in the
    # latter case, pick the training/validation subsets using reserved keys
    use_dataset = dataset
    validation_dataset = None
    extra_validation_datasets = []

    if isinstance(dataset, dict):
        if "__train__" in dataset:
            logger.info("Found (dedicated) '__train__' set for training")
            use_dataset = dataset["__train__"]
        else:
            use_dataset = dataset["train"]

        if "__valid__" in dataset:
            logger.info("Found (dedicated) '__valid__' set for validation")
            logger.info("Will checkpoint lowest loss model on validation set")
            validation_dataset = dataset["__valid__"]

        if "__extra_valid__" in dataset:
            if not isinstance(dataset["__extra_valid__"], list):
                raise RuntimeError(
                    f"If present, dataset['__extra_valid__'] must be a list, "
                    f"but you passed a {type(dataset['__extra_valid__'])}, "
                    f"which is invalid."
                )
            logger.info(
                f"Found {len(dataset['__extra_valid__'])} extra validation "
                f"set(s) to be tracked during training"
            )
            logger.info(
                "Extra validation sets are NOT used for model checkpointing!"
            )
            extra_validation_datasets = dataset["__extra_valid__"]

    # Multiprocessing setup shared by all data loaders: a negative value
    # disables worker processes, zero means "one worker per CPU core"
    multiproc_kwargs = dict()
    if parallel < 0:
        multiproc_kwargs["num_workers"] = 0
    else:
        multiproc_kwargs["num_workers"] = (
            parallel or multiprocessing.cpu_count()
        )

    # on macOS, worker processes are explicitly created with "spawn"
    if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin":
        multiproc_kwargs[
            "multiprocessing_context"
        ] = multiprocessing.get_context("spawn")

    # batch_size must be divisible by batch_chunk_count.
    if batch_size % batch_chunk_count != 0:
        raise RuntimeError(
            f"--batch-size ({batch_size}) must be divisible by "
            f"--batch-chunk-count ({batch_chunk_count})."
        )
    # number of samples actually loaded at each iteration
    batch_chunk_size = batch_size // batch_chunk_count

    # Create weighted random sampler
    train_samples_weights = get_samples_weights(use_dataset)
    train_samples_weights = train_samples_weights.to(
        device=device, non_blocking=torch.cuda.is_available()
    )
    train_sampler = WeightedRandomSampler(
        train_samples_weights, len(train_samples_weights), replacement=True
    )

    # Redefine a weighted criterion if possible
    if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
        positive_weights = get_positive_weights(use_dataset)
        positive_weights = positive_weights.to(
            device=device, non_blocking=torch.cuda.is_available()
        )
        criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
    else:
        logger.warning("Weighted criterion not supported")

    # PyTorch dataloader (training): ordering is driven by the weighted
    # sampler created above
    data_loader = DataLoader(
        dataset=use_dataset,
        batch_size=batch_chunk_size,
        drop_last=drop_incomplete_batch,
        pin_memory=torch.cuda.is_available(),
        sampler=train_sampler,
        **multiproc_kwargs,
    )

    valid_loader = None
    if validation_dataset is not None:

        # Redefine a weighted valid criterion if possible (also used as the
        # default when no validation criterion was provided)
        if (
            isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss)
            or criterion_valid is None
        ):
            positive_weights = get_positive_weights(validation_dataset)
            positive_weights = positive_weights.to(
                device=device, non_blocking=torch.cuda.is_available()
            )
            criterion_valid = BCEWithLogitsLoss(pos_weight=positive_weights)
        else:
            logger.warning("Weighted valid criterion not supported")

        valid_loader = DataLoader(
            dataset=validation_dataset,
            batch_size=batch_chunk_size,
            shuffle=False,
            drop_last=False,
            pin_memory=torch.cuda.is_available(),
            **multiproc_kwargs,
        )

    # one (unshuffled) loader per extra validation set
    extra_valid_loaders = [
        DataLoader(
            dataset=k,
            batch_size=batch_chunk_size,
            shuffle=False,
            drop_last=False,
            pin_memory=torch.cuda.is_available(),
            **multiproc_kwargs,
        )
        for k in extra_validation_datasets
    ]

    # Create z-normalization model layer if needed
    if normalization == "imagenet":
        model.normalizer.set_mean_std(
            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        )
        logger.info("Z-normalization with ImageNet mean and std")
    elif normalization == "current":
        # Compute mean/std of current train subset: the whole training set is
        # loaded as a single batch for this purpose
        temp_dl = DataLoader(dataset=use_dataset, batch_size=len(use_dataset))

        data = next(iter(temp_dl))
        # statistics are reduced over dimensions 0, 2 and 3 of ``data[1]``
        mean = data[1].mean(dim=[0, 2, 3])
        std = data[1].std(dim=[0, 2, 3])

        model.normalizer.set_mean_std(mean, std)

        # Format mean and std for logging
        mean = str(
            [
                round(x, 3)
                for x in ((mean * 10**3).round() / (10**3)).tolist()
            ]
        )
        std = str(
            [
                round(x, 3)
                for x in ((std * 10**3).round() / (10**3)).tolist()
            ]
        )
        logger.info("Z-normalization with mean {} and std {}".format(mean, std))
    elif normalization not in (None, "none"):
        # an unrecognized value used to be silently ignored - at least tell
        # the user no normalization is going to be applied
        logger.warning(
            f"Unknown normalization '{normalization}' - "
            f"no z-normalization will be applied"
        )

    # Checkpointer
    checkpointer = Checkpointer(model, optimizer, path=output_folder)

    # Load pretrained weights if needed
    if weight is not None:
        if weight.startswith("http"):
            logger.info(f"Temporarily downloading '{weight}'...")
            # keep a reference to the downloaded file object while loading
            f = download_to_tempfile(weight, progress=True)
            weight_fullpath = os.path.abspath(f.name)
        else:
            weight_fullpath = os.path.abspath(weight)
        checkpointer.load(weight_fullpath, strict=False)

    arguments = {}
    arguments["epoch"] = 0
    arguments["max_epoch"] = epochs

    logger.info("Training for {} epochs".format(arguments["max_epoch"]))
    logger.info("Continuing from epoch {}".format(arguments["epoch"]))

    run(
        model=model,
        data_loader=data_loader,
        valid_loader=valid_loader,
        extra_valid_loaders=extra_valid_loaders,
        optimizer=optimizer,
        criterion=criterion,
        checkpointer=checkpointer,
        checkpoint_period=checkpoint_period,
        device=device,
        arguments=arguments,
        output_folder=output_folder,
        monitoring_interval=monitoring_interval,
        batch_chunk_count=batch_chunk_count,
        criterion_valid=criterion_valid,
    )