Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1637571489937/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.8/site-packages/bob/med/tb/engine/trainer.py: 93%


124 statements  

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import sys
import csv
import time
import shutil
import datetime
import contextlib
import distutils.version

import torch
from tqdm import tqdm

from ..utils.measure import SmoothedValue
from ..utils.summary import summary
from ..utils.resources import cpu_constants, gpu_constants, cpu_log, gpu_log

import logging

logger = logging.getLogger(__name__)

PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0"


@contextlib.contextmanager
def torch_evaluation(model):
    """Context manager to turn ON/OFF model evaluation

    This context manager will turn evaluation mode ON on entry and turn it
    OFF when exiting the ``with`` statement block.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)


    Yields
    ------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)
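
    Examples
    --------

    A minimal usage sketch (``model`` stands for any
    :py:class:`torch.nn.Module`; pairing it with :py:func:`torch.no_grad`
    mirrors how :py:func:`run` below evaluates the validation set):

    .. code-block:: python

       with torch.no_grad(), torch_evaluation(model):
           predictions = model(images)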

    """

    model.eval()
    yield model
    model.train()


def run(
    model,
    data_loader,
    valid_loader,
    optimizer,
    criterion,
    checkpointer,
    checkpoint_period,
    device,
    arguments,
    output_folder,
    criterion_valid=None,
):
    """
    Fits a CNN model using supervised learning and saves it to disk.

    This method supports periodic checkpointing and the output of a
    CSV-formatted log with the evolution of some figures during training.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. pasa)

    data_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to train the model

    valid_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model and enable automatic checkpointing.
        If set to ``None``, then do not validate it.

    optimizer : :py:mod:`torch.optim`
        Optimizer used to update the model parameters

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    checkpointer : :py:class:`bob.med.tb.utils.checkpointer.Checkpointer`
        checkpointer implementation

    checkpoint_period : int
        save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then
        do not save intermediary checkpoints

    device : str
        device to use (``'cpu'`` or ``'cuda:0'``)

    arguments : dict
        dictionary with the start and maximum epochs (keys ``epoch`` and
        ``max_epoch``)

    output_folder : str
        output path

    criterion_valid : :py:class:`torch.nn.modules.loss._Loss`
        specific loss function for the validation set
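
    Examples
    --------

    A minimal, hypothetical call sketch (``train_loader``, ``ckpt`` and the
    optimizer/criterion choices below are placeholders, not requirements of
    this module):

    .. code-block:: python

       run(
           model=model,
           data_loader=train_loader,
           valid_loader=None,  # disables validation-based checkpointing
           optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
           criterion=torch.nn.BCEWithLogitsLoss(),
           checkpointer=ckpt,
           checkpoint_period=10,
           device="cpu",
           arguments={"epoch": 0, "max_epoch": 100},
           output_folder="results",
       )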

    """

    start_epoch = arguments["epoch"]
    max_epoch = arguments["max_epoch"]

    if device != "cpu":
        # asserts we do have a GPU
        assert bool(gpu_constants()), (
            f"Device set to '{device}', but cannot "
            f"find a GPU (maybe nvidia-smi is not installed?)"
        )

    os.makedirs(output_folder, exist_ok=True)

    # Save model summary
    summary_path = os.path.join(output_folder, "model_summary.txt")
    logger.info(f"Saving model summary at {summary_path}...")
    with open(summary_path, "wt") as f:
        r, n = summary(model)
        logger.info(f"Model has {n} parameters...")
        f.write(r)

    # write static information to a CSV file, moving any previous version
    # of the file out of the way (a "~" backup)
    static_logfile_name = os.path.join(output_folder, "constants.csv")
    if os.path.exists(static_logfile_name):
        backup = static_logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(static_logfile_name, backup)
    with open(static_logfile_name, "w", newline="") as f:
        logdata = cpu_constants()
        if device != "cpu":
            logdata += gpu_constants()
        # ``n`` is the model parameter count computed above
        logdata += (("model_size", n),)
        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
        logwriter.writeheader()
        logwriter.writerow(dict(logdata))

    # Log continuous information to (another) file
    logfile_name = os.path.join(output_folder, "trainlog.csv")

    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
        backup = logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(logfile_name, backup)

    logfile_fields = (
        "epoch",
        "total_time",
        "eta",
        "average_loss",
        "median_loss",
        "learning_rate",
    )
    if valid_loader is not None:
        logfile_fields += ("validation_average_loss", "validation_median_loss")
    logfile_fields += tuple(k[0] for k in cpu_log())
    if device != "cpu":
        logfile_fields += tuple(k[0] for k in gpu_log())

    # the lowest validation loss obtained so far - this value is updated only
    # if a validation set is available
    lowest_validation_loss = sys.float_info.max

    with open(logfile_name, "a+", newline="") as logfile:
        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)

        if arguments["epoch"] == 0:
            logwriter.writeheader()

        model.train()  # set training mode

        model.to(device)  # set/cast parameters to device
        # also moves any optimizer state (e.g. momentum buffers restored
        # from a checkpoint) to the same device
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

        # Total training timer
        start_training_time = time.time()
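
        # main epoch loop: train, optionally validate, checkpoint, and
        # append one row per epoch to the CSV training log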

        for epoch in tqdm(
            range(start_epoch, max_epoch),
            desc="epoch",
            leave=False,
            disable=None,
        ):
            losses = SmoothedValue(len(data_loader))
            epoch = epoch + 1
            arguments["epoch"] = epoch

            # Epoch time
            start_epoch_time = time.time()

            # progress bar only on interactive jobs
            for samples in tqdm(
                data_loader, desc="batch", leave=False, disable=None
            ):
                # data forwarding on the existing network
                images = samples[1].to(
                    device=device, non_blocking=torch.cuda.is_available()
                )
                labels = samples[2].to(
                    device=device, non_blocking=torch.cuda.is_available()
                )

                # Increase labels dimension if too low
                # Allows single and multiclass usage
                if labels.ndim == 1:
                    labels = torch.reshape(labels, (labels.shape[0], 1))

                outputs = model(images)

                # loss evaluation and learning (backward step)
                loss = criterion(outputs, labels.double())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses.update(loss)
                logger.debug(f"batch loss: {loss.item()}")

            # calculates the validation loss if necessary
            valid_losses = None
            if valid_loader is not None:
                with torch.no_grad(), torch_evaluation(model):
                    valid_losses = SmoothedValue(len(valid_loader))
                    for samples in tqdm(
                        valid_loader, desc="valid", leave=False, disable=None
                    ):
                        # data forwarding on the existing network
                        images = samples[1].to(
                            device=device,
                            non_blocking=torch.cuda.is_available(),
                        )
                        labels = samples[2].to(
                            device=device,
                            non_blocking=torch.cuda.is_available(),
                        )

                        # Increase labels dimension if too low
                        # Allows single and multiclass usage
                        if labels.ndim == 1:
                            labels = torch.reshape(
                                labels, (labels.shape[0], 1)
                            )

                        outputs = model(images)

                        # uses the validation-specific criterion, if one was
                        # provided
                        if criterion_valid is not None:
                            loss = criterion_valid(outputs, labels.double())
                        else:
                            loss = criterion(outputs, labels.double())
                        valid_losses.update(loss)
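
            # checkpointing: periodic snapshots, a copy of the best model on
            # the validation set, and the final model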

            if checkpoint_period and (epoch % checkpoint_period == 0):
                checkpointer.save(f"model_{epoch:03d}", **arguments)

            if (
                valid_losses is not None
                and valid_losses.avg < lowest_validation_loss
            ):
                lowest_validation_loss = valid_losses.avg
                logger.info(
                    f"Found new low on validation set:"
                    f" {lowest_validation_loss:.6f}"
                )
                checkpointer.save("model_lowest_valid_loss", **arguments)

            if epoch >= max_epoch:
                checkpointer.save("model_final", **arguments)

            # computes ETA (estimated time-of-arrival; end of training)
            # taking into consideration previous epoch performance
            epoch_time = time.time() - start_epoch_time
            eta_seconds = epoch_time * (max_epoch - epoch)
            current_time = time.time() - start_training_time

            logdata = (
                ("epoch", f"{epoch}"),
                (
                    "total_time",
                    f"{datetime.timedelta(seconds=int(current_time))}",
                ),
                ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
                ("average_loss", f"{losses.avg:.6f}"),
                ("median_loss", f"{losses.median:.6f}"),
                ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
            )
            if valid_losses is not None:
                logdata += (
                    ("validation_average_loss", f"{valid_losses.avg:.6f}"),
                    ("validation_median_loss", f"{valid_losses.median:.6f}"),
                )
            logdata += cpu_log()
            if device != "cpu":
                logdata += gpu_log()

            logwriter.writerow(dict(logdata))
            logfile.flush()
            # echoes a short summary (epoch, total time, ETA and average
            # loss) to the terminal
            tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))

        total_training_time = time.time() - start_training_time
        logger.info(
            f"Total training time: "
            f"{datetime.timedelta(seconds=total_training_time)} "
            f"({total_training_time/max_epoch:.4f}s on average per epoch)"
        )