#!/usr/bin/env python
# coding=utf-8

import os
import sys
import random
import multiprocessing

import click
import numpy
import torch
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import DataLoader, WeightedRandomSampler

from ..configs.datasets import get_samples_weights, get_positive_weights


from bob.extension.scripts.click_helper import (
    verbosity_option,
    ConfigCommand,
    ResourceOption,
)

from ..utils.checkpointer import Checkpointer
from ..engine.trainer import run
from .tb import download_to_tempfile

import logging

logger = logging.getLogger(__name__)


def setup_pytorch_device(name):
34 """Sets-up the pytorch device to use 

35 

36 

37 Parameters 

38 ---------- 

39 

40 name : str 

41 The device name (``cpu``, ``cuda:0``, ``cuda:1``, and so on). If you 

42 set a specific cuda device such as ``cuda:1``, then we'll make sure it 

43 is currently set. 

44 

45 

46 Returns 

47 ------- 

48 

49 device : :py:class:`torch.device` 

50 The pytorch device to use, pre-configured (and checked) 

51 

52 """ 

53 

54 if name.startswith("cuda:"): 

55 # In case one has multiple devices, we must first set the one 

56 # we would like to use so pytorch can find it. 

57 logger.info(f"User set device to '{name}' - trying to force device...") 

58 os.environ["CUDA_VISIBLE_DEVICES"] = name.split(":", 1)[1] 

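        # e.g. for name="cuda:1", CUDA_VISIBLE_DEVICES is set to "1", so the
        # generic "cuda" device returned below maps to that physical GPU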
        if not torch.cuda.is_available():
            raise RuntimeError(
                f"CUDA is not currently available, but "
                f"you set device to '{name}'"
            )
        # Let pytorch auto-select from environment variable
        return torch.device("cuda")

    elif name.startswith("cuda"):  # use default device
        logger.info(f"User set device to '{name}' - using default CUDA device")
        assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None

    # cuda or cpu
    return torch.device(name)


def set_seeds(value, all_gpus):
76 """Sets up all relevant random seeds (numpy, python, cuda) 

77 

78 If running with multiple GPUs **at the same time**, set ``all_gpus`` to 

79 ``True`` to force all GPU seeds to be initialized. 

80 

81 Reference: `PyTorch page for reproducibility 

82 <https://pytorch.org/docs/stable/notes/randomness.html>`_. 

83 

84 

85 Parameters 

86 ---------- 

87 

88 value : int 

89 The random seed value to use 

90 

91 all_gpus : :py:class:`bool`, Optional 

92 If set, then reset the seed on all GPUs available at once. This is 

93 normally **not** what you want if running on a single GPU 

94 

95 """ 

96 

97 random.seed(value) 

98 numpy.random.seed(value) 

99 torch.manual_seed(value) 

100 torch.cuda.manual_seed(value) # noop if cuda not available 

101 

102 # set seeds for all gpus 

103 if all_gpus: 

104 torch.cuda.manual_seed_all(value) # noop if cuda not available 

105 

106 

107def set_reproducible_cuda(): 

108 """Turns-off all CUDA optimizations that would affect reproducibility 

109 

110 For full reproducibility, also ensure not to use multiple (parallel) data 

111 lowers. That is setup ``num_workers=0``. 

112 

113 Reference: `PyTorch page for reproducibility 

114 <https://pytorch.org/docs/stable/notes/randomness.html>`_. 

115 

116 

117 """ 

118 

119 # ensure to use only optimization algos for cuda that are known to have 

120 # a deterministic effect (not random) 

121 torch.backends.cudnn.deterministic = True 

122 

123 # turns off any optimization tricks 

124 torch.backends.cudnn.benchmark = False 

125 

126 

127@click.command( 

    entry_point_group="bob.med.tb.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Trains PASA model with Montgomery dataset,
       on a GPU (``cuda:0``):

       $ bob tb train -vv pasa montgomery --batch-size=4 --device="cuda:0"

""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the generated model (created if does not exist)",
    required=True,
    type=click.Path(),
    default="results",
    cls=ResourceOption,
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be trained",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--dataset",
    "-d",
    help="A dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances implementing datasets "
    "to be used for training and validating the model, possibly including all "
    "pre-processing pipelines required or, optionally, a dictionary mapping "
    "string keys to torch.utils.data.dataset.Dataset instances. At least "
    "one key named ``train`` must be available. This dataset will be used for "
    "training the network model. The dataset description must include all "
    "required pre-processing, including eventual data augmentation. If a "
    "dataset named ``__train__`` is available, it is used preferentially for "
    "training instead of ``train``. If a dataset named ``__valid__`` is "
    "available, it is used for model validation (and automatic "
    "check-pointing) at each epoch. If a dataset list named "
    "``__extra_valid__`` is available, then it will be tracked during the "
    "validation process and its loss will be output to the training log as "
    "well, in the format of an array occupying a single column. All other "
    "keys are considered test datasets and are ignored during training",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--optimizer",
    help="A torch.optim.Optimizer that will be used to train the network",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--criterion",
    help="A loss function to compute the CNN error for every sample "
    "respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--criterion-valid",
    help="A specific loss function for the validation set to compute the CNN "
    "error for every sample respecting the PyTorch API for loss functions "
    "(see torch.nn.modules.loss)",
    required=False,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects "
    "memory requirements for the network). If the number of samples in "
    "the batch is larger than the total number of samples available for "
    "training, this value is truncated. If this number is smaller, then "
    "batches of the specified size are created and fed to the network "
    "until there are no more new samples to feed (epoch is finished). "
    "If the total number of training samples is not a multiple of the "
    "batch-size, the last batch will be smaller than the first, unless "
    "--drop-incomplete-batch is set, in which case this batch is not used.",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--batch-chunk-count",
    "-c",
    help="Number of chunks in every batch (this parameter affects "
    "memory requirements for the network). The number of samples "
    "loaded for every iteration will be batch-size/batch-chunk-count. "
    "batch-size needs to be divisible by batch-chunk-count, otherwise an "
    "error will be raised. This parameter is used to reduce the number of "
    "samples loaded in each iteration, in order to reduce the memory usage "
    "in exchange for processing time (more iterations). This is especially "
    "interesting when one is running with GPUs with limited RAM. The "
    "default of 1 forces the whole batch to be processed at once. Otherwise "
    "the batch is broken into batch-chunk-count pieces, and gradients are "
    "accumulated to complete each batch.",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--drop-incomplete-batch/--no-drop-incomplete-batch",
    "-D",
    help="If set, the last batch in an epoch may be dropped in case it is "
    "incomplete. If you set this option, you should also consider "
    "increasing the total number of epochs of training, as the total number "
    "of training steps may be reduced",
    required=True,
    show_default=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--epochs",
    "-e",
    help="Number of epochs (complete training set passes) to train for. "
    "If continuing from a saved checkpoint, ensure to provide a greater "
    "number of epochs than that saved on the checkpoint to be loaded.",
    show_default=True,
    required=True,
    default=1000,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--checkpoint-period",
    "-p",
    help="Number of epochs after which a checkpoint is saved. "
    "A value of zero will disable check-pointing. If checkpointing is "
    "enabled and training stops, it is automatically resumed from the "
    "last saved checkpoint if training is restarted with the same "
    "configuration.",
    show_default=True,
    required=True,
    default=0,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--device",
    "-d",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--seed",
    "-s",
    help="Seed to use for the random number generator",
    show_default=True,
    required=False,
    default=42,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--parallel",
    "-P",
    help="""Use multiprocessing for data loading: if set to -1 (default),
    disables multiprocessing data loading. Set to 0 to enable as many data
    loading instances as there are processing cores available in the system.
    Set to >= 1 to enable that many multiprocessing instances for data
    loading.""",
    type=click.IntRange(min=-1),
    show_default=True,
    required=True,
    default=-1,
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
    help="Path or URL to pretrained model file (.pth extension)",
    required=False,
    cls=ResourceOption,
)
@click.option(
    "--normalization",
    "-n",
    help="Z-Normalization of input images: 'imagenet' for ImageNet parameters,"
    " 'current' for parameters of the current trainset, "
    "'none' for no normalization.",
    required=False,
    default="none",
    cls=ResourceOption,
)
@click.option(
    "--monitoring-interval",
    "-I",
    help="""Time between checks for the use of resources during each training
    epoch. An interval of 5 seconds, for example, will lead to CPU and GPU
    resources being probed every 5 seconds during each training epoch.
    Values registered in the training logs correspond to averages (or maxima)
    observed through possibly many probes in each epoch. Notice that setting a
    very small value may cause the probing process to become extremely busy,
    potentially biasing the overall perception of resource usage.""",
    type=click.FloatRange(min=0.1),
    show_default=True,
    required=True,
    default=5.0,
    cls=ResourceOption,
)
@verbosity_option(cls=ResourceOption)
def train(
    model,
    optimizer,
    output_folder,
    epochs,
    batch_size,
    batch_chunk_count,
    drop_incomplete_batch,
    criterion,
    criterion_valid,
    dataset,
    checkpoint_period,
    device,
    seed,
    parallel,
    weight,
    normalization,
    monitoring_interval,
    verbose,
    **kwargs,
):
    """Trains a CNN to perform tuberculosis detection

    Training is performed for a configurable number of epochs, and generates at
    least a final_model.pth. It may also generate a number of intermediate
    checkpoints. Checkpoints are model files (.pth files) that are stored
    during the training and useful to resume the procedure in case it stops
    abruptly.
    """

    device = setup_pytorch_device(device)

    set_seeds(seed, all_gpus=False)

    use_dataset = dataset
    validation_dataset = None
    extra_validation_datasets = []

    if isinstance(dataset, dict):
        if "__train__" in dataset:
            logger.info("Found (dedicated) '__train__' set for training")
            use_dataset = dataset["__train__"]
        else:
            use_dataset = dataset["train"]

        if "__valid__" in dataset:
            logger.info("Found (dedicated) '__valid__' set for validation")
            logger.info("Will checkpoint lowest loss model on validation set")
            validation_dataset = dataset["__valid__"]

        if "__extra_valid__" in dataset:
            if not isinstance(dataset["__extra_valid__"], list):
                raise RuntimeError(
                    f"If present, dataset['__extra_valid__'] must be a list, "
                    f"but you passed a {type(dataset['__extra_valid__'])}, "
                    f"which is invalid."
                )
            logger.info(
                f"Found {len(dataset['__extra_valid__'])} extra validation "
                f"set(s) to be tracked during training"
            )
            logger.info(
                "Extra validation sets are NOT used for model checkpointing!"
            )
            extra_validation_datasets = dataset["__extra_valid__"]

    # PyTorch dataloader
    multiproc_kwargs = dict()
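    # Summary of --parallel: -1 disables multiprocessing data loading
    # (num_workers=0, data is loaded in the main process); 0 uses one worker
    # per available CPU core; any value >= 1 is used verbatim as the number
    # of workers.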
    if parallel < 0:
        multiproc_kwargs["num_workers"] = 0
    else:
        multiproc_kwargs["num_workers"] = (
            parallel or multiprocessing.cpu_count()
        )

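    # On macOS, fork-based DataLoader worker processes are known to be
    # problematic, so a "spawn" multiprocessing context is used there instead.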
    if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin":
        multiproc_kwargs[
            "multiprocessing_context"
        ] = multiprocessing.get_context("spawn")

    batch_chunk_size = batch_size
    if batch_size % batch_chunk_count != 0:
        # batch_size must be divisible by batch_chunk_count.
        raise RuntimeError(
            f"--batch-size ({batch_size}) must be divisible by "
            f"--batch-chunk-count ({batch_chunk_count})."
        )
    else:
        batch_chunk_size = batch_size // batch_chunk_count
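    # Illustrative example: with --batch-size=16 and --batch-chunk-count=4,
    # each iteration loads 16 // 4 = 4 samples, and the trainer accumulates
    # gradients over the 4 chunks so the effective batch size remains 16.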

    # Create weighted random sampler
    train_samples_weights = get_samples_weights(use_dataset)
    train_samples_weights = train_samples_weights.to(
        device=device, non_blocking=torch.cuda.is_available()
    )
    train_sampler = WeightedRandomSampler(
        train_samples_weights, len(train_samples_weights), replacement=True
    )
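    # Note: WeightedRandomSampler draws indices with probability proportional
    # to the per-sample weights (with replacement); the weights returned by
    # get_samples_weights() are presumably chosen to counter class imbalance
    # in the training set.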

    # Redefine a weighted criterion if possible
    if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
        positive_weights = get_positive_weights(use_dataset)
        positive_weights = positive_weights.to(
            device=device, non_blocking=torch.cuda.is_available()
        )
        criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
    else:
        logger.warning("Weighted criterion not supported")
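    # Note: BCEWithLogitsLoss(pos_weight=w) scales the loss contribution of
    # positive samples by w, compensating for class imbalance (the weights
    # come from get_positive_weights()).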

    # PyTorch dataloader

    data_loader = DataLoader(
        dataset=use_dataset,
        batch_size=batch_chunk_size,
        drop_last=drop_incomplete_batch,
        pin_memory=torch.cuda.is_available(),
        sampler=train_sampler,
        **multiproc_kwargs,
    )
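    # Note: no shuffle flag is passed here since shuffling is already handled
    # by the weighted random sampler above (DataLoader does not allow
    # shuffle=True together with a custom sampler).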

    valid_loader = None
    if validation_dataset is not None:

        # Redefine a weighted valid criterion if possible
        if (
            isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss)
            or criterion_valid is None
        ):
            positive_weights = get_positive_weights(validation_dataset)
            positive_weights = positive_weights.to(
                device=device, non_blocking=torch.cuda.is_available()
            )
            criterion_valid = BCEWithLogitsLoss(pos_weight=positive_weights)
        else:
            logger.warning("Weighted valid criterion not supported")

        valid_loader = DataLoader(
            dataset=validation_dataset,
            batch_size=batch_chunk_size,
            shuffle=False,
            drop_last=False,
            pin_memory=torch.cuda.is_available(),
            **multiproc_kwargs,
        )

    extra_valid_loaders = [
        DataLoader(
            dataset=k,
            batch_size=batch_chunk_size,
            shuffle=False,
            drop_last=False,
            pin_memory=torch.cuda.is_available(),
            **multiproc_kwargs,
        )
        for k in extra_validation_datasets
    ]
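    # Note: extra validation loaders mirror the main validation loader
    # configuration; their losses are only tracked in the training log and
    # are never used for checkpointing (see the --dataset help above).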

    # Create z-normalization model layer if needed
    if normalization == "imagenet":
        model.normalizer.set_mean_std(
            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        )
        logger.info("Z-normalization with ImageNet mean and std")
    elif normalization == "current":
        # Compute mean/std of current train subset
        temp_dl = DataLoader(dataset=use_dataset, batch_size=len(use_dataset))

        data = next(iter(temp_dl))
        mean = data[1].mean(dim=[0, 2, 3])
        std = data[1].std(dim=[0, 2, 3])
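        # Note: averaging over dims [0, 2, 3] (batch, height, width) yields
        # one mean/std value per image channel; data[1] is assumed here to be
        # the stacked image tensor of the whole training subset.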

        model.normalizer.set_mean_std(mean, std)

        # Format mean and std for logging
        mean = str(
            [
                round(x, 3)
                for x in ((mean * 10**3).round() / (10**3)).tolist()
            ]
        )
        std = str(
            [
                round(x, 3)
                for x in ((std * 10**3).round() / (10**3)).tolist()
            ]
        )
        logger.info("Z-normalization with mean {} and std {}".format(mean, std))

    # Checkpointer
    checkpointer = Checkpointer(model, optimizer, path=output_folder)

    # Load pretrained weights if needed
    if weight is not None:
        if weight.startswith("http"):
            logger.info(f"Temporarily downloading '{weight}'...")
            f = download_to_tempfile(weight, progress=True)
            weight_fullpath = os.path.abspath(f.name)
        else:
            weight_fullpath = os.path.abspath(weight)
        checkpointer.load(weight_fullpath, strict=False)
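        # Note: strict=False presumably allows loading checkpoints that only
        # partially match the model (e.g. a pretrained backbone), ignoring
        # missing or unexpected keys instead of raising an error.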

    arguments = {}
    arguments["epoch"] = 0
    arguments["max_epoch"] = epochs

    logger.info("Training for {} epochs".format(arguments["max_epoch"]))
    logger.info("Continuing from epoch {}".format(arguments["epoch"]))

    run(
        model=model,
        data_loader=data_loader,
        valid_loader=valid_loader,
        extra_valid_loaders=extra_valid_loaders,
        optimizer=optimizer,
        criterion=criterion,
        checkpointer=checkpointer,
        checkpoint_period=checkpoint_period,
        device=device,
        arguments=arguments,
        output_folder=output_folder,
        monitoring_interval=monitoring_interval,
        batch_chunk_count=batch_chunk_count,
        criterion_valid=criterion_valid,
    )