Coverage for src/deepdraw/engine/trainer.py: 91%

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import contextlib
import csv
import datetime
import logging
import os
import shutil
import sys
import time

import numpy
import torch

from tqdm import tqdm

from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
from ..utils.summary import summary

logger = logging.getLogger(__name__)


@contextlib.contextmanager
def torch_evaluation(model):
    """Context manager to turn ON/OFF model evaluation.

    This context manager will turn evaluation mode ON on entry and turn it OFF
    when exiting the ``with`` statement block.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)


    Yields
    ------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)
    """

    model.eval()
    yield model
    model.train()
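
# Typical use, as done in ``validate_epoch`` below: evaluation mode is enabled
# only for the duration of the ``with`` block, and training mode is restored on
# exit:
#
#     with torch.no_grad(), torch_evaluation(model):
#         outputs = model(images)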

def check_gpu(device):
    """Check the device type and the availability of a GPU.

    Parameters
    ----------

    device : :py:class:`torch.device`
        device to use
    """
    if device.type == "cuda":
        # asserts we do have a GPU
        assert bool(
            gpu_constants()
        ), f"Device set to '{device}', but nvidia-smi is not installed"


def save_model_summary(output_folder, model):
    """Save a little summary of the model in a txt file.

    Parameters
    ----------

    output_folder : str
        output path

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    Returns
    -------
    r : str
        The model summary in a text format.

    n : int
        The number of parameters of the model.
    """
    summary_path = os.path.join(output_folder, "model_summary.txt")
    logger.info(f"Saving model summary at {summary_path}...")
    with open(summary_path, "w") as f:
        r, n = summary(model)
        logger.info(f"Model has {n} parameters...")
        f.write(r)
    return r, n


def static_information_to_csv(static_logfile_name, device, n):
    """Save the static information in a csv file.

    Parameters
    ----------

    static_logfile_name : str
        The static file name which is a join between the output folder and "constants.csv"

    device : :py:class:`torch.device`
        device to use

    n : int
        The number of parameters of the model (stored in the ``model_size`` column)
    """
    if os.path.exists(static_logfile_name):
        backup = static_logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(static_logfile_name, backup)
    with open(static_logfile_name, "w", newline="") as f:
        logdata = cpu_constants()
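        # ``cpu_constants()`` (and ``gpu_constants()`` below) yield sequences of
        # (name, value) pairs; the names are used as CSV column headers.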

        if device.type == "cuda":
            logdata += gpu_constants()
        logdata += (("model_size", n),)
        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
        logwriter.writeheader()
        logwriter.writerow(dict(k for k in logdata))


def check_exist_logfile(logfile_name, arguments):
    """Check the existence of the logfile (trainlog.csv).

    If the logfile exists and the starting epoch is still 0, the existing
    logfile is moved to a backup so that it can be replaced.

    Parameters
    ----------

    logfile_name : str
        The logfile_name which is a join between the output_folder and trainlog.csv

    arguments : dict
        start and end epochs
    """
    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
        backup = logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(logfile_name, backup)


def create_logfile_fields(valid_loader, extra_valid_loaders, device):
    """Creation of the logfile fields that will appear in the logfile.

    Parameters
    ----------

    valid_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model and enable automatic checkpointing.
        If set to ``None``, then do not validate it.

    extra_valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model, however **does not affect** automatic
        checkpointing. If set to ``None``, or empty, then does not log anything
        else. Otherwise, an extra column with the loss of every dataset in
        this list is kept on the final training log.

    device : :py:class:`torch.device`
        device to use

    Returns
    -------

    logfile_fields: tuple
        The fields that will appear in trainlog.csv
    """
    logfile_fields = (
        "epoch",
        "total_time",
        "eta",
        "loss",
        "learning_rate",
    )
    if valid_loader is not None:
        logfile_fields += ("validation_loss",)
    if extra_valid_loaders:
        logfile_fields += ("extra_validation_losses",)
    logfile_fields += tuple(
        ResourceMonitor.monitored_keys(device.type == "cuda")
    )
    return logfile_fields


def train_epoch(loader, model, optimizer, device, criterion, batch_chunk_count):
    """Trains the model for a single epoch (through all batches).

    Parameters
    ----------

    loader : :py:class:`torch.utils.data.DataLoader`
        To be used to train the model

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    optimizer : :py:mod:`torch.optim`

    device : :py:class:`torch.device`
        device to use

    criterion : :py:class:`torch.nn.modules.loss._Loss`

    batch_chunk_count: int
        If this number is different from 1, then each batch will be divided into
        this number of chunks. Gradients will be accumulated to perform each
        mini-batch. This is particularly interesting when one has limited RAM
        on the GPU, but would like to keep training with larger batches. The
        trade-off is longer processing times. To better understand gradient
        accumulation, read
        https://stackoverflow.com/questions/62067400/understanding-accumulated-gradients-in-pytorch.


    Returns
    -------

    loss : float
        A floating-point value corresponding to the weighted average of this
        epoch's loss
    """
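
    # Note: with e.g. ``batch_chunk_count == 2``, two consecutive loader batches
    # are treated as the two halves of one logical batch: each loss is scaled by
    # 1/2, gradients are accumulated across both, and a single optimizer step is
    # taken, approximating one update over the full (un-chunked) batch.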

    losses_in_epoch = []
    samples_in_epoch = []
    losses_in_batch = []
    samples_in_batch = []

    # progress bar only on interactive jobs
    for idx, samples in enumerate(
        tqdm(loader, desc="train", leave=False, disable=None)
    ):
        images = samples[1].to(
            device=device, non_blocking=torch.cuda.is_available()
        )
        ground_truths = samples[2].to(
            device=device, non_blocking=torch.cuda.is_available()
        )
        masks = (
            torch.ones_like(ground_truths)
            if len(samples) < 4
            else samples[3].to(
                device=device, non_blocking=torch.cuda.is_available()
            )
        )

        # Forward pass on the network
        outputs = model(images)
        loss = criterion(outputs, ground_truths, masks)

        losses_in_batch.append(loss.item())
        samples_in_batch.append(len(samples))

        # Normalize loss to account for batch accumulation
        loss = loss / batch_chunk_count

        # Accumulate gradients - does not update weights just yet...
        loss.backward()

        # Weight update on the network
        if ((idx + 1) % batch_chunk_count == 0) or (idx + 1 == len(loader)):
            # Advances optimizer to the "next" state and applies weight update
            # over the whole model
            optimizer.step()

            # Zeroes gradients for the next batch
            optimizer.zero_grad()

            # Normalize loss for current batch
            batch_loss = numpy.average(
                losses_in_batch, weights=samples_in_batch
            )
            losses_in_epoch.append(batch_loss.item())
            samples_in_epoch.append(len(samples))

            losses_in_batch.clear()
            samples_in_batch.clear()
            logger.debug(f"batch loss: {batch_loss.item()}")

    return numpy.average(losses_in_epoch, weights=samples_in_epoch)


def validate_epoch(loader, model, device, criterion, pbar_desc):
    """Processes input samples and returns loss (scalar).

    Parameters
    ----------

    loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    device : :py:class:`torch.device`
        device to use

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    pbar_desc : str
        A string for the progress bar descriptor


    Returns
    -------

    loss : float
        A floating-point value corresponding to the weighted average of this
        epoch's loss
    """

    batch_losses = []
    samples_in_batch = []
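
    # no gradient bookkeeping and evaluation mode (affects e.g. dropout and
    # batch-normalization layers) for the whole validation pass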

    with torch.no_grad(), torch_evaluation(model):
        for samples in tqdm(loader, desc=pbar_desc, leave=False, disable=None):
            images = samples[1].to(
                device=device,
                non_blocking=torch.cuda.is_available(),
            )
            ground_truths = samples[2].to(
                device=device,
                non_blocking=torch.cuda.is_available(),
            )
            masks = (
                torch.ones_like(ground_truths)
                if len(samples) < 4
                else samples[3].to(
                    device=device,
                    non_blocking=torch.cuda.is_available(),
                )
            )

            # data forwarding on the existing network
            outputs = model(images)
            loss = criterion(outputs, ground_truths, masks)

            batch_losses.append(loss.item())
            samples_in_batch.append(len(samples))

    return numpy.average(batch_losses, weights=samples_in_batch)


def checkpointer_process(
    checkpointer,
    checkpoint_period,
    valid_loss,
    lowest_validation_loss,
    arguments,
    epoch,
    max_epoch,
):
    """Process the checkpointer, save the final model and keep track of the
    best model.

    Parameters
    ----------

    checkpointer : :py:class:`deepdraw.utils.checkpointer.Checkpointer`
        checkpointer implementation

    checkpoint_period : int
        save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
        not save intermediary checkpoints

    valid_loss : float
        Current epoch validation loss

    lowest_validation_loss : float
        Keeps track of the best (lowest) validation loss

    arguments : dict
        start and end epochs

    epoch : int
        Current epoch

    max_epoch : int
        End (last) epoch of training

    Returns
    -------

    lowest_validation_loss : float
        The lowest validation loss currently observed
    """
    if checkpoint_period and (epoch % checkpoint_period == 0):
        checkpointer.save("model_periodic_save", **arguments)

    if valid_loss is not None and valid_loss < lowest_validation_loss:
        lowest_validation_loss = valid_loss
        logger.info(
            f"Found new low on validation set: {lowest_validation_loss:.6f}"
        )
        checkpointer.save("model_lowest_valid_loss", **arguments)

    if epoch >= max_epoch:
        checkpointer.save("model_final_epoch", **arguments)

    return lowest_validation_loss


def write_log_info(
    epoch,
    current_time,
    eta_seconds,
    loss,
    valid_loss,
    extra_valid_losses,
    optimizer,
    logwriter,
    logfile,
    resource_data,
):
    """Write log info in trainlog.csv.

    Parameters
    ----------

    epoch : int
        Current epoch

    current_time : float
        Current training time

    eta_seconds : float
        estimated time-of-arrival taking into consideration previous epoch performance

    loss : float
        Current epoch's training loss

    valid_loss : :py:class:`float`, None
        Current epoch's validation loss

    extra_valid_losses : :py:class:`list` of :py:class:`float`
        Validation losses from other validation datasets being currently
        tracked

    optimizer : :py:mod:`torch.optim`

    logwriter : csv.DictWriter
        Dictionary writer used to write rows to trainlog.csv

    logfile : io.TextIOWrapper

    resource_data : tuple
        Monitored resources at the machine (CPU and GPU)
    """

    logdata = (
        ("epoch", f"{epoch}"),
        (
            "total_time",
            f"{datetime.timedelta(seconds=int(current_time))}",
        ),
        ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
        ("loss", f"{loss:.6f}"),
        ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
    )

    if valid_loss is not None:
        logdata += (("validation_loss", f"{valid_loss:.6f}"),)

    if extra_valid_losses:
        entry = numpy.array_str(
            numpy.array(extra_valid_losses),
            max_line_width=sys.maxsize,
            precision=6,
        )
        logdata += (("extra_validation_losses", entry),)

    logdata += resource_data

    logwriter.writerow(dict(k for k in logdata))
    logfile.flush()
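    # echo the first four columns (epoch, total_time, eta, loss) on the console
    # without disturbing any active tqdm progress bars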

    tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))


def run(
    model,
    data_loader,
    valid_loader,
    extra_valid_loaders,
    optimizer,
    criterion,
    scheduler,
    checkpointer,
    checkpoint_period,
    device,
    arguments,
    output_folder,
    monitoring_interval,
    batch_chunk_count,
):
    """Fits an FCN model using supervised learning and saves it to disk.

    This method supports periodic checkpointing and the output of a
    CSV-formatted log with the evolution of some figures during training.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    data_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to train the model

    valid_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model and enable automatic checkpointing.
        If ``None``, then do not validate it.

    extra_valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model, however **does not affect** automatic
        checkpointing. If empty, then does not log anything else. Otherwise,
        an extra column with the loss of every dataset in this list is kept on
        the final training log.

    optimizer : :py:mod:`torch.optim`

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    scheduler : :py:mod:`torch.optim`
        learning rate scheduler

    checkpointer : :py:class:`deepdraw.utils.checkpointer.Checkpointer`
        checkpointer implementation

    checkpoint_period : int
        save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
        not save intermediary checkpoints

    device : :py:class:`torch.device`
        device to use

    arguments : dict
        start and end epochs

    output_folder : str
        output path

    monitoring_interval : int, float
        interval, in seconds (or fractions), through which we should monitor
        resources during training.

    batch_chunk_count: int
        If this number is different from 1, then each batch will be divided into
        this number of chunks. Gradients will be accumulated to perform each
        mini-batch. This is particularly interesting when one has limited RAM
        on the GPU, but would like to keep training with larger batches. The
        trade-off is longer processing times.
    """

    start_epoch = arguments["epoch"]
    max_epoch = arguments["max_epoch"]

    check_gpu(device)

    os.makedirs(output_folder, exist_ok=True)

    # Save model summary
    r, n = save_model_summary(output_folder, model)

    # write static information to a CSV file
    static_logfile_name = os.path.join(output_folder, "constants.csv")

    static_information_to_csv(static_logfile_name, device, n)

    # Log continuous information to (another) file
    logfile_name = os.path.join(output_folder, "trainlog.csv")

    check_exist_logfile(logfile_name, arguments)

    logfile_fields = create_logfile_fields(
        valid_loader, extra_valid_loaders, device
    )

    # the lowest validation loss obtained so far - this value is updated only
    # if a validation set is available
    lowest_validation_loss = sys.float_info.max
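
    # the log is opened in append mode so that a run resumed from a checkpoint
    # keeps writing to the same CSV; the header is only written for fresh runs
    # (starting epoch == 0)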

    with open(logfile_name, "a+", newline="") as logfile:
        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)

        if arguments["epoch"] == 0:
            logwriter.writeheader()

        model.train()  # set training mode

        model.to(device)  # set/cast parameters to device
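        # optimizer state (e.g. momentum/Adam buffers, typically present after
        # loading a checkpoint) must live on the same device as the parameters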

        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

        # Total training timer
        start_training_time = time.time()

        for epoch in tqdm(
            range(start_epoch, max_epoch),
            desc="epoch",
            leave=False,
            disable=None,
        ):
            with ResourceMonitor(
                interval=monitoring_interval,
                has_gpu=(device.type == "cuda"),
                main_pid=os.getpid(),
                logging_level=logging.ERROR,
            ) as resource_monitor:
                epoch = epoch + 1
                arguments["epoch"] = epoch

                # Epoch time
                start_epoch_time = time.time()

                train_loss = train_epoch(
                    data_loader,
                    model,
                    optimizer,
                    device,
                    criterion,
                    batch_chunk_count,
                )

                scheduler.step()

                valid_loss = (
                    validate_epoch(
                        valid_loader, model, device, criterion, "valid"
                    )
                    if valid_loader is not None
                    else None
                )

                extra_valid_losses = []
                for pos, extra_valid_loader in enumerate(extra_valid_loaders):
                    loss = validate_epoch(
                        extra_valid_loader,
                        model,
                        device,
                        criterion,
                        f"xval@{pos+1}",
                    )
                    extra_valid_losses.append(loss)

                lowest_validation_loss = checkpointer_process(
                    checkpointer,
                    checkpoint_period,
                    valid_loss,
                    lowest_validation_loss,
                    arguments,
                    epoch,
                    max_epoch,
                )

                # computes ETA (estimated time-of-arrival; end of training) taking
                # into consideration previous epoch performance
                epoch_time = time.time() - start_epoch_time
                eta_seconds = epoch_time * (max_epoch - epoch)
                current_time = time.time() - start_training_time

                write_log_info(
                    epoch,
                    current_time,
                    eta_seconds,
                    train_loss,
                    valid_loss,
                    extra_valid_losses,
                    optimizer,
                    logwriter,
                    logfile,
                    resource_monitor.data,
                )

        total_training_time = time.time() - start_training_time
        logger.info(
            f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)"
        )
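

# Hypothetical usage sketch (illustrative only; how the model, loaders,
# criterion and checkpointer are built depends on the rest of the deepdraw
# package):
#
#     import torch
#
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
#     run(
#         model=model,
#         data_loader=train_loader,
#         valid_loader=valid_loader,
#         extra_valid_loaders=[],
#         optimizer=optimizer,
#         criterion=criterion,
#         scheduler=scheduler,
#         checkpointer=checkpointer,
#         checkpoint_period=0,
#         device=torch.device("cpu"),
#         arguments={"epoch": 0, "max_epoch": 10},
#         output_folder="results",
#         monitoring_interval=5.0,
#         batch_chunk_count=1,
#     )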