Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1674079587905/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.10/site-packages/bob/med/tb/engine/trainer.py: 93%

162 statements  

coverage.py v7.0.5, created at 2023-01-18 22:14 +0000

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import sys
import csv
import time
import shutil
import datetime
import contextlib

import numpy
import torch
from tqdm import tqdm

from ..utils.measure import SmoothedValue
from ..utils.summary import summary

# from ..utils.resources import cpu_constants, gpu_constants, cpu_log, gpu_log
from ..utils.resources import (
    ResourceMonitor,
    cpu_constants,
    gpu_constants,
)

import logging

logger = logging.getLogger(__name__)

@contextlib.contextmanager
def torch_evaluation(model):
    """Context manager to turn ON/OFF model evaluation

    This context manager will turn evaluation mode ON on entry and turn it OFF
    when exiting the ``with`` statement block.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network


    Yields
    ------

    model : :py:class:`torch.nn.Module`
        Network

    """

    model.eval()
    yield model
    model.train()
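# Illustrative usage sketch (not part of the original module): shows how the
# ``torch_evaluation`` context manager above toggles evaluation mode on a
# model and restores training mode on exit.  The helper and the toy model are
# hypothetical and only meant as documentation.
def _example_torch_evaluation():
    toy_model = torch.nn.Linear(4, 2)
    assert toy_model.training  # modules start in training mode
    with torch_evaluation(toy_model) as m:
        assert not m.training  # evaluation mode while inside the block
        _ = m(torch.zeros(1, 4))
    assert toy_model.training  # training mode is restored on exit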

def check_gpu(device):
    """
    Check the device type and the availability of a GPU.

    Parameters
    ----------

    device : :py:class:`torch.device`
        device to use

    """
    if device.type == "cuda":
        # asserts we do have a GPU
        assert bool(
            gpu_constants()
        ), f"Device set to '{device}', but nvidia-smi is not installed"

def save_model_summary(output_folder, model):
    """
    Save a brief summary of the model in a text file.

    Parameters
    ----------

    output_folder : str
        output path

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    Returns
    -------
    r : str
        The model summary in a text format.

    n : int
        The number of parameters of the model.

    """
    summary_path = os.path.join(output_folder, "model_summary.txt")
    logger.info(f"Saving model summary at {summary_path}...")
    with open(summary_path, "wt") as f:
        r, n = summary(model)
        logger.info(f"Model has {n} parameters...")
        f.write(r)
    return r, n

def static_information_to_csv(static_logfile_name, device, n):
    """
    Save the static information in a csv file.

    Parameters
    ----------

    static_logfile_name : str
        The static file name which is a join between the output folder and "constants.csv"

    device : :py:class:`torch.device`
        device to use

    n : int
        The number of parameters of the model

    """
    if os.path.exists(static_logfile_name):
        backup = static_logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(static_logfile_name, backup)
    with open(static_logfile_name, "w", newline="") as f:
        logdata = cpu_constants()
        if device.type == "cuda":
            logdata += gpu_constants()
        logdata += (("model_size", n),)
        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
        logwriter.writeheader()
        logwriter.writerow(dict(k for k in logdata))
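# Illustrative sketch (not part of the original module): the CSV layout
# produced by ``static_information_to_csv`` above is one header row plus one
# value row, built from ``(key, value)`` pairs.  The keys and values below are
# hypothetical placeholders, not the real output of ``cpu_constants()`` or
# ``gpu_constants()``.
def _example_constants_csv(path="constants.csv"):
    logdata = (("cpu_count", 8), ("memory_total_GB", 32.0), ("model_size", 123456))
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
        writer.writeheader()      # cpu_count,memory_total_GB,model_size
        writer.writerow(dict(logdata))  # 8,32.0,123456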

def check_exist_logfile(logfile_name, arguments):
    """
    Check the existence of the logfile (trainlog.csv).
    If the logfile exists and the epoch number is still 0, the logfile is
    backed up and will be replaced by a new one.

    Parameters
    ----------

    logfile_name : str
        The logfile_name which is a join between the output_folder and trainlog.csv

    arguments : dict
        start and end epochs

    """
    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
        backup = logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(logfile_name, backup)

def create_logfile_fields(valid_loader, extra_valid_loaders, device):
    """
    Creation of the logfile fields that will appear in the logfile.

    Parameters
    ----------

    valid_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model and enable automatic checkpointing.
        If set to ``None``, then do not validate it.

    extra_valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model, however **does not affect** automatic
        checkpointing. If set to ``None``, or empty, then does not log anything
        else. Otherwise, an extra column with the loss of every dataset in
        this list is kept on the final training log.

    device : :py:class:`torch.device`
        device to use

    Returns
    -------

    logfile_fields: tuple
        The fields that will appear in trainlog.csv


    """
    logfile_fields = (
        "epoch",
        "total_time",
        "eta",
        "loss",
        "learning_rate",
    )
    if valid_loader is not None:
        logfile_fields += ("validation_loss",)
    if extra_valid_loaders:
        logfile_fields += ("extra_validation_losses",)
    logfile_fields += tuple(
        ResourceMonitor.monitored_keys(device.type == "cuda")
    )
    return logfile_fields
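# Illustrative sketch (not part of the original module): expected shape of the
# tuple returned by ``create_logfile_fields`` above.  The resource-monitor
# column names are not spelled out here; they come from
# ``ResourceMonitor.monitored_keys`` and depend on GPU availability.
def _example_logfile_fields():
    fields = create_logfile_fields(
        valid_loader=None, extra_valid_loaders=[], device=torch.device("cpu")
    )
    # e.g. ("epoch", "total_time", "eta", "loss", "learning_rate",
    #       <resource-monitor columns>) -- no validation columns, since no
    #       validation loaders were passed
    return fields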

def train_epoch(loader, model, optimizer, device, criterion, batch_chunk_count):
    """Trains the model for a single epoch (through all batches)

    Parameters
    ----------

    loader : :py:class:`torch.utils.data.DataLoader`
        To be used to train the model

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    optimizer : :py:mod:`torch.optim`

    device : :py:class:`torch.device`
        device to use

    criterion : :py:class:`torch.nn.modules.loss._Loss`

    batch_chunk_count: int
        If this number is different from 1, then each batch will be divided into
        this number of chunks. Gradients will be accumulated to perform each
        mini-batch. This is particularly interesting when one has limited RAM
        on the GPU, but would like to keep training with larger batches. One
        exchanges this for longer processing times. To better understand
        gradient accumulation, read
        https://stackoverflow.com/questions/62067400/understanding-accumulated-gradients-in-pytorch.


    Returns
    -------

    loss : float
        A floating-point value corresponding to the weighted average of this
        epoch's loss

    """

    losses_in_epoch = []
    samples_in_epoch = []
    losses_in_batch = []
    samples_in_batch = []

    # progress bar only on interactive jobs
    for idx, samples in enumerate(
        tqdm(loader, desc="train", leave=False, disable=None)
    ):

        images = samples[1].to(
            device=device, non_blocking=torch.cuda.is_available()
        )
        labels = samples[2].to(
            device=device, non_blocking=torch.cuda.is_available()
        )

        # Increase label dimension if too low
        # Allows single and multiclass usage
        if labels.ndim == 1:
            labels = torch.reshape(labels, (labels.shape[0], 1))

        # Forward pass on the network
        outputs = model(images)

        loss = criterion(outputs, labels.double())

        losses_in_batch.append(loss.item())
        samples_in_batch.append(len(samples))

        # Normalize loss to account for batch accumulation
        loss = loss / batch_chunk_count

        # Accumulate gradients - does not update weights just yet...
        loss.backward()

        # Weight update on the network
        if ((idx + 1) % batch_chunk_count == 0) or (idx + 1 == len(loader)):
            # Advances optimizer to the "next" state and applies weight update
            # over the whole model
            optimizer.step()

            # Zeroes gradients for the next batch
            optimizer.zero_grad()

            # Normalize loss for current batch
            batch_loss = numpy.average(
                losses_in_batch, weights=samples_in_batch
            )
            losses_in_epoch.append(batch_loss.item())
            samples_in_epoch.append(len(samples))

            losses_in_batch.clear()
            samples_in_batch.clear()
            logger.debug(f"batch loss: {batch_loss.item()}")

    return numpy.average(losses_in_epoch, weights=samples_in_epoch)
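# Minimal, self-contained sketch (not part of the original module) of the
# gradient-accumulation pattern used in ``train_epoch`` above: each chunk's
# loss is scaled by ``batch_chunk_count`` and the optimizer only steps every
# ``batch_chunk_count`` batches.  Model, data and loss below are hypothetical
# stand-ins.
def _example_gradient_accumulation(batch_chunk_count=2, n_batches=4):
    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = torch.nn.MSELoss()
    for idx in range(n_batches):
        images = torch.randn(8, 4)
        labels = torch.randn(8, 1)
        loss = criterion(model(images), labels) / batch_chunk_count
        loss.backward()  # accumulates gradients, does not update weights
        if ((idx + 1) % batch_chunk_count == 0) or (idx + 1 == n_batches):
            optimizer.step()       # apply the accumulated update
            optimizer.zero_grad()  # reset gradients for the next chunk group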

def validate_epoch(loader, model, device, criterion, pbar_desc):
    """
    Processes input samples and returns loss (scalar)


    Parameters
    ----------

    loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    device : :py:class:`torch.device`
        device to use

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    pbar_desc : str
        A string for the progress bar descriptor


    Returns
    -------

    loss : float
        A floating-point value corresponding to the weighted average of this
        epoch's loss

    """

    batch_losses = []
    samples_in_batch = []

    with torch.no_grad(), torch_evaluation(model):

        for samples in tqdm(loader, desc=pbar_desc, leave=False, disable=None):
            images = samples[1].to(
                device=device,
                non_blocking=torch.cuda.is_available(),
            )
            labels = samples[2].to(
                device=device,
                non_blocking=torch.cuda.is_available(),
            )

            # Increase label dimension if too low
            # Allows single and multiclass usage
            if labels.ndim == 1:
                labels = torch.reshape(labels, (labels.shape[0], 1))

            # data forwarding on the existing network
            outputs = model(images)
            loss = criterion(outputs, labels.double())

            batch_losses.append(loss.item())
            samples_in_batch.append(len(samples))

    return numpy.average(batch_losses, weights=samples_in_batch)
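# Illustrative sketch (not part of the original module): the loaders consumed
# by ``train_epoch``/``validate_epoch`` above are expected to yield tuples
# whose second and third entries are the image and label tensors
# (``samples[1]`` and ``samples[2]``).  The content of the first entry and the
# shapes below are hypothetical stand-ins.
def _example_sample_layout():
    keys = ["sample-0", "sample-1"]
    images = torch.randn(2, 1, 8, 8)
    labels = torch.tensor([[0.0], [1.0]])
    samples = (keys, images, labels)  # layout assumed by the loops above
    return samples[1].shape, samples[2].shape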

def checkpointer_process(
    checkpointer,
    checkpoint_period,
    valid_loss,
    lowest_validation_loss,
    arguments,
    epoch,
    max_epoch,
):
    """
    Process the checkpointer, save the final model and keep track of the best model.

    Parameters
    ----------

    checkpointer : :py:class:`bob.med.tb.utils.checkpointer.Checkpointer`
        checkpointer implementation

    checkpoint_period : int
        save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
        not save intermediary checkpoints

    valid_loss : float
        Current epoch validation loss

    lowest_validation_loss : float
        Keeps track of the best (lowest) validation loss

    arguments : dict
        start and end epochs

    epoch : int
        Current epoch

    max_epoch : int
        end epoch

    Returns
    -------

    lowest_validation_loss : float
        The lowest validation loss currently observed


    """
    if checkpoint_period and (epoch % checkpoint_period == 0):
        checkpointer.save("model_periodic_save", **arguments)

    if valid_loss is not None and valid_loss < lowest_validation_loss:
        lowest_validation_loss = valid_loss
        logger.info(
            f"Found new low on validation set:" f" {lowest_validation_loss:.6f}"
        )
        checkpointer.save("model_lowest_valid_loss", **arguments)

    if epoch >= max_epoch:
        checkpointer.save("model_final_epoch", **arguments)

    return lowest_validation_loss

def write_log_info(
    epoch,
    current_time,
    eta_seconds,
    loss,
    valid_loss,
    extra_valid_losses,
    optimizer,
    logwriter,
    logfile,
    resource_data,
):
    """
    Write log info in trainlog.csv

    Parameters
    ----------

    epoch : int
        Current epoch

    current_time : float
        Current training time

    eta_seconds : float
        estimated time-of-arrival taking into consideration previous epoch performance

    loss : float
        Current epoch's training loss

    valid_loss : :py:class:`float`, None
        Current epoch's validation loss

    extra_valid_losses : :py:class:`list` of :py:class:`float`
        Validation losses from other validation datasets being currently
        tracked

    optimizer : :py:mod:`torch.optim`

    logwriter : csv.DictWriter
        Dictionary writer that gives the ability to write on the trainlog.csv

    logfile : io.TextIOWrapper

    resource_data : tuple
        Monitored resources at the machine (CPU and GPU)

    """

    logdata = (
        ("epoch", f"{epoch}"),
        (
            "total_time",
            f"{datetime.timedelta(seconds=int(current_time))}",
        ),
        ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
        ("loss", f"{loss:.6f}"),
        ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
    )

    if valid_loss is not None:
        logdata += (("validation_loss", f"{valid_loss:.6f}"),)

    if extra_valid_losses:
        entry = numpy.array_str(
            numpy.array(extra_valid_losses),
            max_line_width=sys.maxsize,
            precision=6,
        )
        logdata += (("extra_validation_losses", entry),)

    logdata += resource_data

    logwriter.writerow(dict(k for k in logdata))
    logfile.flush()
    tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))
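# Illustrative sketch (not part of the original module): shape of the
# ``logdata`` tuple that ``write_log_info`` above serializes as a single
# trainlog.csv row.  All numbers are hypothetical placeholders.
def _example_trainlog_row():
    logdata = (
        ("epoch", "10"),
        ("total_time", str(datetime.timedelta(seconds=312))),   # 0:05:12
        ("eta", str(datetime.timedelta(seconds=2808))),          # 0:46:48
        ("loss", f"{0.412306:.6f}"),
        ("learning_rate", f"{0.001:.6f}"),
        ("validation_loss", f"{0.398127:.6f}"),
    )
    return dict(logdata)  # what a csv.DictWriter row would contain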

def run(
    model,
    data_loader,
    valid_loader,
    extra_valid_loaders,
    optimizer,
    criterion,
    checkpointer,
    checkpoint_period,
    device,
    arguments,
    output_folder,
    monitoring_interval,
    batch_chunk_count,
    criterion_valid,
):
    """
    Fits a CNN model using supervised learning and saves it to disk.

    This method supports periodic checkpointing and the output of a
    CSV-formatted log with the evolution of some figures during training.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    data_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to train the model

    valid_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model and enable automatic checkpointing.
        If ``None``, then do not validate it.

    extra_valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model, however **does not affect** automatic
        checkpointing. If empty, then does not log anything else. Otherwise,
        an extra column with the loss of every dataset in this list is kept on
        the final training log.

    optimizer : :py:mod:`torch.optim`

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    checkpointer : :py:class:`bob.med.tb.utils.checkpointer.Checkpointer`
        checkpointer implementation

    checkpoint_period : int
        save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
        not save intermediary checkpoints

    device : :py:class:`torch.device`
        device to use

    arguments : dict
        start and end epochs

    output_folder : str
        output path

    monitoring_interval : int, float
        interval, in seconds (or fractions), through which we should monitor
        resources during training.

    batch_chunk_count: int
        If this number is different from 1, then each batch will be divided into
        this number of chunks. Gradients will be accumulated to perform each
        mini-batch. This is particularly interesting when one has limited RAM
        on the GPU, but would like to keep training with larger batches. One
        exchanges this for longer processing times.

    criterion_valid : :py:class:`torch.nn.modules.loss._Loss`
        specific loss function for the validation set

    """

    start_epoch = arguments["epoch"]
    max_epoch = arguments["max_epoch"]

    check_gpu(device)

    os.makedirs(output_folder, exist_ok=True)

    # Save model summary
    r, n = save_model_summary(output_folder, model)

    # write static information to a CSV file
    static_logfile_name = os.path.join(output_folder, "constants.csv")

    static_information_to_csv(static_logfile_name, device, n)

    # Log continuous information to (another) file
    logfile_name = os.path.join(output_folder, "trainlog.csv")

    check_exist_logfile(logfile_name, arguments)

    logfile_fields = create_logfile_fields(
        valid_loader, extra_valid_loaders, device
    )

    # the lowest validation loss obtained so far - this value is updated only
    # if a validation set is available
    lowest_validation_loss = sys.float_info.max

    # set a specific validation criterion if the user has set one
    criterion_valid = criterion_valid or criterion

    with open(logfile_name, "a+", newline="") as logfile:
        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)

        if arguments["epoch"] == 0:
            logwriter.writeheader()

        model.train()  # set training mode

        model.to(device)  # set/cast parameters to device
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

        # Total training timer
        start_training_time = time.time()

        for epoch in tqdm(
            range(start_epoch, max_epoch),
            desc="epoch",
            leave=False,
            disable=None,
        ):

            with ResourceMonitor(
                interval=monitoring_interval,
                has_gpu=(device.type == "cuda"),
                main_pid=os.getpid(),
                logging_level=logging.ERROR,
            ) as resource_monitor:
                epoch = epoch + 1
                arguments["epoch"] = epoch

                # Epoch time
                start_epoch_time = time.time()

                train_loss = train_epoch(
                    data_loader,
                    model,
                    optimizer,
                    device,
                    criterion,
                    batch_chunk_count,
                )

                valid_loss = (
                    validate_epoch(
                        valid_loader, model, device, criterion_valid, "valid"
                    )
                    if valid_loader is not None
                    else None
                )

                extra_valid_losses = []
                for pos, extra_valid_loader in enumerate(extra_valid_loaders):
                    loss = validate_epoch(
                        extra_valid_loader,
                        model,
                        device,
                        criterion_valid,
                        f"xval@{pos+1}",
                    )
                    extra_valid_losses.append(loss)

                lowest_validation_loss = checkpointer_process(
                    checkpointer,
                    checkpoint_period,
                    valid_loss,
                    lowest_validation_loss,
                    arguments,
                    epoch,
                    max_epoch,
                )

                # computes ETA (estimated time-of-arrival; end of training) taking
                # into consideration previous epoch performance
                epoch_time = time.time() - start_epoch_time
                eta_seconds = epoch_time * (max_epoch - epoch)
                current_time = time.time() - start_training_time

                write_log_info(
                    epoch,
                    current_time,
                    eta_seconds,
                    train_loss,
                    valid_loss,
                    extra_valid_losses,
                    optimizer,
                    logwriter,
                    logfile,
                    resource_monitor.data,
                )

        total_training_time = time.time() - start_training_time
        logger.info(
            f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s on average per epoch)"
        )
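# Illustrative wiring sketch (not part of the original module): how ``run``
# above might be invoked.  ``model``, the data loaders, ``optimizer``,
# ``criterion`` and ``checkpointer`` are hypothetical stand-ins passed in by
# the caller; in the real package they come from the configured experiment.
def _example_run_call(model, train_loader, valid_loader, optimizer, criterion, checkpointer):
    run(
        model=model,
        data_loader=train_loader,
        valid_loader=valid_loader,
        extra_valid_loaders=[],
        optimizer=optimizer,
        criterion=criterion,
        checkpointer=checkpointer,
        checkpoint_period=0,  # no intermediary checkpoints
        device=torch.device("cpu"),
        arguments={"epoch": 0, "max_epoch": 10},  # start and end epochs
        output_folder="results",
        monitoring_interval=5.0,
        batch_chunk_count=1,  # no gradient accumulation
        criterion_valid=None,  # falls back to ``criterion``
    )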