
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import contextlib
import csv
import datetime
import distutils.version
import logging
import os
import shutil
import sys
import time

import torch

from tqdm import tqdm

from ..utils.measure import SmoothedValue
from ..utils.resources import cpu_constants, cpu_log, gpu_constants, gpu_log
from ..utils.summary import summary

logger = logging.getLogger(__name__)

PYTORCH_GE_110 = distutils.version.LooseVersion(torch.__version__) >= "1.1.0"
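# PyTorch changed the recommended call order of ``scheduler.step()`` in
# version 1.1 (it should now be called *after* ``optimizer.step()``);
# ``run()`` below uses this flag to pick the matching call site.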

@contextlib.contextmanager
def torch_evaluation(model):
    """Context manager to switch a model to and from evaluation mode

    This context manager turns evaluation mode ON on entry and turns it
    back OFF when exiting the ``with`` statement block.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)


    Yields
    ------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    """

    model.eval()
    yield model
    model.train()
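# A minimal usage sketch (illustrative, not part of the original module):
# wrapping a no-grad block with ``torch_evaluation`` runs inference with
# eval-mode semantics (e.g. frozen batch-norm statistics) and restores
# training mode afterwards.  ``model`` and ``batch`` are placeholders.
#
#     with torch.no_grad(), torch_evaluation(model):
#         predictions = model(batch)
#     # here the model is back in training mode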

def check_gpu(device):
    """
    Check the device type and, for CUDA devices, that a GPU is available.

    Parameters
    ----------

    device : :py:class:`torch.device`
        device to use

    """
    if device.type == "cuda":
        # asserts we do have a GPU
        assert bool(
            gpu_constants()
        ), f"Device set to '{device}', but nvidia-smi is not installed"

def save_model_summary(output_folder, model):
    """
    Save a brief summary of the model in a text file.

    Parameters
    ----------

    output_folder : str
        output path

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    Returns
    -------

    r : str
        The model summary in a text format.

    n : int
        The number of parameters of the model.

    """
    summary_path = os.path.join(output_folder, "model_summary.txt")
    logger.info(f"Saving model summary at {summary_path}...")
    with open(summary_path, "wt") as f:
        r, n = summary(model)
        logger.info(f"Model has {n} parameters...")
        f.write(r)
    return r, n

def static_information_to_csv(static_logfile_name, device, n):
    """
    Save the static information in a CSV file.

    Parameters
    ----------

    static_logfile_name : str
        The static file name, which is a join between the output folder
        and "constants.csv"

    device : :py:class:`torch.device`
        device to use

    n : int
        The number of parameters of the model

    """
    if os.path.exists(static_logfile_name):
        backup = static_logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(static_logfile_name, backup)
    with open(static_logfile_name, "w", newline="") as f:
        logdata = cpu_constants()
        if device.type == "cuda":
            logdata += gpu_constants()
        logdata += (("model_size", n),)
        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
        logwriter.writeheader()
        logwriter.writerow(dict(k for k in logdata))
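# Sketch of the resulting ``constants.csv``: ``logdata`` is a tuple of
# (name, value) pairs, written as one header row plus one data row.  The
# actual field names come from ``cpu_constants()``/``gpu_constants()`` in
# ``..utils.resources``; the ones shown here are hypothetical:
#
#     cpu_count,...,model_size
#     8,...,25800000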

def check_exist_logfile(logfile_name, arguments):
    """
    Check the existence of the logfile (trainlog.csv).
    If the logfile exists and the current epoch number is still 0, the
    logfile is moved to a backup and will be replaced.

    Parameters
    ----------

    logfile_name : str
        The logfile name, which is a join between the output_folder and
        trainlog.csv

    arguments : dict
        start and end epochs

    """
    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
        backup = logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(logfile_name, backup)

def create_logfile_fields(valid_loader, device):
    """
    Create the fields that will appear in the logfile (trainlog.csv).

    Parameters
    ----------

    valid_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model and enable automatic checkpointing.
        If set to ``None``, then do not validate it.

    device : :py:class:`torch.device`
        device to use

    Returns
    -------

    logfile_fields : tuple
        The fields that will appear in trainlog.csv

    """
    logfile_fields = (
        "epoch",
        "total_time",
        "eta",
        "average_loss",
        "median_loss",
        "learning_rate",
    )
    if valid_loader is not None:
        logfile_fields += ("validation_average_loss", "validation_median_loss")
    logfile_fields += tuple([k[0] for k in cpu_log()])
    if device.type == "cuda":
        logfile_fields += tuple([k[0] for k in gpu_log()])
    return logfile_fields
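# For a CUDA device with a validation loader, the returned tuple looks
# like the following (the trailing fields come from ``cpu_log()`` and
# ``gpu_log()`` in ``..utils.resources``, so their exact names are not
# shown here):
#
#     ("epoch", "total_time", "eta", "average_loss", "median_loss",
#      "learning_rate", "validation_average_loss",
#      "validation_median_loss", <cpu_log fields...>, <gpu_log fields...>)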

def train_sample_process(samples, model, optimizer, losses, device, criterion):
    """
    Process a batch of training inputs (images, ground truths, masks),
    apply backpropagation and update the running training losses.

    Parameters
    ----------

    samples : list
        A batch of samples: ``samples[1]`` contains the images,
        ``samples[2]`` the ground truths and, optionally, ``samples[3]``
        the masks

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    optimizer : :py:mod:`torch.optim`

    losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`

    device : :py:class:`torch.device`
        device to use

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    Returns
    -------

    losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`

    optimizer : :py:mod:`torch.optim`

    """
    images = samples[1].to(
        device=device, non_blocking=torch.cuda.is_available()
    )
    ground_truths = samples[2].to(
        device=device, non_blocking=torch.cuda.is_available()
    )
    # if no mask is provided, consider the whole image
    masks = (
        torch.ones_like(ground_truths)
        if len(samples) < 4
        else samples[3].to(
            device=device, non_blocking=torch.cuda.is_available()
        )
    )
    outputs = model(images)
    loss = criterion(outputs, ground_truths, masks)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.update(loss)
    logger.debug(f"batch loss: {loss.item()}")
    return losses, optimizer
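# Sketch of a single training step (illustrative; ``model``, ``optimizer``,
# ``criterion`` and ``batch`` are hypothetical placeholders -- in practice
# this call happens inside the batch loop of ``run()`` below):
#
#     losses = SmoothedValue(len(data_loader))
#     losses, optimizer = train_sample_process(
#         batch, model, optimizer, losses, device, criterion
#     )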

def valid_sample_process(samples, model, valid_losses, device, criterion):
    """
    Process a batch of validation inputs (images, ground truths, masks)
    and update the running validation losses.

    Parameters
    ----------

    samples : list
        A batch of samples: ``samples[1]`` contains the images,
        ``samples[2]`` the ground truths and, optionally, ``samples[3]``
        the masks

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    valid_losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`

    device : :py:class:`torch.device`
        device to use

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    Returns
    -------

    valid_losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`

    """
    images = samples[1].to(
        device=device,
        non_blocking=torch.cuda.is_available(),
    )
    ground_truths = samples[2].to(
        device=device,
        non_blocking=torch.cuda.is_available(),
    )
    # if no mask is provided, consider the whole image
    masks = (
        torch.ones_like(ground_truths)
        if len(samples) < 4
        else samples[3].to(
            device=device,
            non_blocking=torch.cuda.is_available(),
        )
    )

    outputs = model(images)
    loss = criterion(outputs, ground_truths, masks)
    valid_losses.update(loss)
    return valid_losses

def checkpointer_process(
    checkpointer,
    checkpoint_period,
    valid_losses,
    lowest_validation_loss,
    arguments,
    epoch,
    max_epoch,
):
    """
    Process the checkpointer, save the final model and keep track of the
    best model.

    Parameters
    ----------

    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.Checkpointer`
        checkpointer implementation

    checkpoint_period : int
        save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
        not save intermediary checkpoints

    valid_losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`

    lowest_validation_loss : float
        Keep track of the best (lowest) validation loss

    arguments : dict
        start and end epochs

    epoch : int
        current epoch

    max_epoch : int
        end epoch

    Returns
    -------

    lowest_validation_loss : float
        The lowest validation loss seen so far

    """
    if checkpoint_period and (epoch % checkpoint_period == 0):
        checkpointer.save(f"model_{epoch:03d}", **arguments)

    if valid_losses is not None and valid_losses.avg < lowest_validation_loss:
        lowest_validation_loss = valid_losses.avg
        logger.info(
            f"Found new low on validation set: {lowest_validation_loss:.6f}"
        )
        checkpointer.save("model_lowest_valid_loss", **arguments)

    if epoch >= max_epoch:
        checkpointer.save("model_final", **arguments)

    return lowest_validation_loss
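# Checkpoints produced over a training run (names as saved above):
# ``model_NNN`` every ``checkpoint_period`` epochs,
# ``model_lowest_valid_loss`` whenever the validation loss improves, and
# ``model_final`` at the last epoch.  The caller must keep the returned
# value for the comparison to work across epochs, e.g.:
#
#     lowest_validation_loss = checkpointer_process(
#         checkpointer, checkpoint_period, valid_losses,
#         lowest_validation_loss, arguments, epoch, max_epoch,
#     )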

def write_log_info(
    epoch,
    current_time,
    eta_seconds,
    losses,
    valid_losses,
    optimizer,
    logwriter,
    logfile,
    device,
):
    """
    Write log info in trainlog.csv.

    Parameters
    ----------

    epoch : int
        Current epoch

    current_time : float
        Current training time

    eta_seconds : float
        estimated time-of-arrival taking into consideration previous epoch
        performance

    losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`

    valid_losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`

    optimizer : :py:mod:`torch.optim`

    logwriter : csv.DictWriter
        Dictionary writer used to write rows to trainlog.csv

    logfile : io.TextIOWrapper

    device : :py:class:`torch.device`
        device to use

    """
    logdata = (
        ("epoch", f"{epoch}"),
        (
            "total_time",
            f"{datetime.timedelta(seconds=int(current_time))}",
        ),
        ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
        ("average_loss", f"{losses.avg:.6f}"),
        ("median_loss", f"{losses.median:.6f}"),
        ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
    )
    if valid_losses is not None:
        logdata += (
            ("validation_average_loss", f"{valid_losses.avg:.6f}"),
            ("validation_median_loss", f"{valid_losses.median:.6f}"),
        )
    logdata += cpu_log()
    if device.type == "cuda":
        logdata += gpu_log()

    logwriter.writerow(dict(k for k in logdata))
    logfile.flush()
    tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))
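# Sketch of one ``trainlog.csv`` row written above (values are
# hypothetical; resource columns from ``cpu_log()``/``gpu_log()`` omitted):
#
#     epoch,total_time,eta,average_loss,median_loss,learning_rate,...
#     12,0:05:43,1:02:11,0.412305,0.398127,0.001000,...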

def run(
    model,
    data_loader,
    valid_loader,
    optimizer,
    criterion,
    scheduler,
    checkpointer,
    checkpoint_period,
    device,
    arguments,
    output_folder,
):
    """
    Fit an FCN model using supervised learning and save it to disk.

    This method supports periodic checkpointing and the output of a
    CSV-formatted log with the evolution of some figures during training.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    data_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to train the model

    valid_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model and enable automatic checkpointing.
        If set to ``None``, then do not validate it.

    optimizer : :py:mod:`torch.optim`

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    scheduler : :py:mod:`torch.optim`
        learning rate scheduler

    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.Checkpointer`
        checkpointer implementation

    checkpoint_period : int
        save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
        not save intermediary checkpoints

    device : :py:class:`torch.device`
        device to use

    arguments : dict
        start and end epochs

    output_folder : str
        output path
    """

    start_epoch = arguments["epoch"]
    max_epoch = arguments["max_epoch"]

    check_gpu(device)

    os.makedirs(output_folder, exist_ok=True)

    # Save model summary
    r, n = save_model_summary(output_folder, model)

    # write static information to a CSV file
    static_logfile_name = os.path.join(output_folder, "constants.csv")

    static_information_to_csv(static_logfile_name, device, n)

    # Log continuous information to (another) file
    logfile_name = os.path.join(output_folder, "trainlog.csv")

    check_exist_logfile(logfile_name, arguments)

    logfile_fields = create_logfile_fields(valid_loader, device)

    # the lowest validation loss obtained so far - this value is updated only
    # if a validation set is available
    lowest_validation_loss = sys.float_info.max

    with open(logfile_name, "a+", newline="") as logfile:
        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)

        if arguments["epoch"] == 0:
            logwriter.writeheader()

        model.train()  # set training mode

        model.to(device)  # set/cast parameters to device
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

        # Total training timer
        start_training_time = time.time()

        for epoch in tqdm(
            range(start_epoch, max_epoch),
            desc="epoch",
            leave=False,
            disable=None,
        ):
            if not PYTORCH_GE_110:
                scheduler.step()
            losses = SmoothedValue(len(data_loader))
            epoch = epoch + 1
            arguments["epoch"] = epoch

            # Epoch time
            start_epoch_time = time.time()

            # progress bar only on interactive jobs
            for samples in tqdm(
                data_loader, desc="batch", leave=False, disable=None
            ):
                # data forwarding on the existing network
                losses, optimizer = train_sample_process(
                    samples, model, optimizer, losses, device, criterion
                )

            if PYTORCH_GE_110:
                scheduler.step()

            # calculates the validation loss if necessary
            valid_losses = None
            if valid_loader is not None:

                with torch.no_grad(), torch_evaluation(model):

                    valid_losses = SmoothedValue(len(valid_loader))
                    for samples in tqdm(
                        valid_loader, desc="valid", leave=False, disable=None
                    ):
                        # data forwarding on the existing network
                        valid_losses = valid_sample_process(
                            samples, model, valid_losses, device, criterion
                        )

            # keep the returned value so the best-model tracking survives
            # across epochs
            lowest_validation_loss = checkpointer_process(
                checkpointer,
                checkpoint_period,
                valid_losses,
                lowest_validation_loss,
                arguments,
                epoch,
                max_epoch,
            )

            # computes ETA (estimated time-of-arrival; end of training) taking
            # into consideration previous epoch performance
            epoch_time = time.time() - start_epoch_time
            eta_seconds = epoch_time * (max_epoch - epoch)
            current_time = time.time() - start_training_time

            write_log_info(
                epoch,
                current_time,
                eta_seconds,
                losses,
                valid_losses,
                optimizer,
                logwriter,
                logfile,
                device,
            )

        total_training_time = time.time() - start_training_time
        logger.info(
            f"Total training time: "
            f"{datetime.timedelta(seconds=total_training_time)} "
            f"({(total_training_time/max_epoch):.4f}s in average per epoch)"
        )
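
# Hedged sketch of how ``run()`` might be driven (illustrative only; the
# model, loaders, optimizer, criterion, scheduler and checkpointer below
# are hypothetical placeholders supplied by the surrounding package):
#
#     arguments = {"epoch": 0, "max_epoch": 100}
#     run(
#         model, data_loader, valid_loader, optimizer, criterion,
#         scheduler, checkpointer, checkpoint_period=10,
#         device=torch.device("cuda"), arguments=arguments,
#         output_folder="results",
#     )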