#!/usr/bin/env python
# -*- coding: utf-8 -*-

import csv
import datetime
import logging
import os
import shutil
import sys
import time

import numpy
import torch

from tqdm import tqdm

from ..utils.measure import SmoothedValue
from ..utils.resources import cpu_constants, cpu_log, gpu_constants, gpu_log
from ..utils.summary import summary
from .trainer import PYTORCH_GE_110, torch_evaluation

logger = logging.getLogger(__name__)


def sharpen(x, T):
    """Sharpens a probability distribution by raising it to ``1/T`` and
    re-normalising it along ``dim=1`` (a lower ``T`` yields a more confident
    distribution)."""
    temp = x ** (1 / T)
    return temp / temp.sum(dim=1, keepdim=True)
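
# Illustrative only (not part of the original module): for a two-class
# probability vector, sharpening with T < 1 pushes mass towards the dominant
# class, e.g.:
#
#   >>> import torch
#   >>> p = torch.tensor([[0.6, 0.4]])
#   >>> sharpen(p, T=0.5)  # p ** 2, renormalised -> [0.36, 0.16] / 0.52
#   tensor([[0.6923, 0.3077]])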

def mix_up(alpha, input, target, unlabelled_input, unlabelled_target):
    """Applies mix up as described in [MIXMATCH_19].

    Parameters
    ----------

    alpha : float
        parameter of the Beta distribution from which the mixing coefficient
        is drawn

    input : :py:class:`torch.Tensor`
        batch of labelled images

    target : :py:class:`torch.Tensor`
        targets of the labelled images

    unlabelled_input : :py:class:`torch.Tensor`
        batch of unlabelled images

    unlabelled_target : :py:class:`torch.Tensor`
        guessed targets of the unlabelled images


    Returns
    -------

    tuple
        mixed-up versions of the four input tensors, in the same order

    """

    with torch.no_grad():
        l = numpy.random.beta(alpha, alpha)  # Eq (8)
        l = max(l, 1 - l)  # Eq (9)
        # Shuffle and concat. Alg. 1 Line: 12
        w_inputs = torch.cat([input, unlabelled_input], 0)
        w_targets = torch.cat([target, unlabelled_target], 0)
        idx = torch.randperm(w_inputs.size(0))  # get random index

        # Apply MixUp to labelled data and entries from W. Alg. 1 Line: 13
        input_mixedup = l * input + (1 - l) * w_inputs[idx[len(input) :]]
        target_mixedup = l * target + (1 - l) * w_targets[idx[len(target) :]]

        # Apply MixUp to unlabelled data and entries from W. Alg. 1 Line: 14
        unlabelled_input_mixedup = (
            l * unlabelled_input
            + (1 - l) * w_inputs[idx[: len(unlabelled_input)]]
        )
        unlabelled_target_mixedup = (
            l * unlabelled_target
            + (1 - l) * w_targets[idx[: len(unlabelled_target)]]
        )
        return (
            input_mixedup,
            target_mixedup,
            unlabelled_input_mixedup,
            unlabelled_target_mixedup,
        )

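# Illustrative only (not part of the original module): a rough sketch of how
# ``mix_up`` could be driven, assuming matching labelled/unlabelled batches of
# shape ``[n, c, h, w]`` (all tensors below are random placeholders):
#
#   >>> import torch
#   >>> x = torch.rand(4, 1, 32, 32)                      # labelled images
#   >>> y = torch.randint(0, 2, (4, 1, 32, 32)).float()   # labelled targets
#   >>> u = torch.rand(4, 1, 32, 32)                      # unlabelled images
#   >>> q = torch.rand(4, 1, 32, 32)                      # guessed targets
#   >>> xm, ym, um, qm = mix_up(0.75, x, y, u, q)
#   >>> xm.shape
#   torch.Size([4, 1, 32, 32])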

def square_rampup(current, rampup_length=16):
    """Slowly ramps up ``lambda_u`` following a squared schedule.

    Parameters
    ----------

    current : int
        current epoch

    rampup_length : :obj:`int`, optional
        how long to ramp up, by default 16

    Returns
    -------

    factor : float
        ramp up factor
    """

    if rampup_length == 0:
        return 1.0
    else:
        current = numpy.clip((current / float(rampup_length)) ** 2, 0.0, 1.0)
        return float(current)

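# Illustrative only (not part of the original module): with the default
# ``rampup_length=16`` the squared schedule grows slowly at first, e.g.:
#
#   >>> square_rampup(4)    # (4/16) ** 2
#   0.0625
#   >>> square_rampup(8)    # (8/16) ** 2
#   0.25
#   >>> square_rampup(20)   # clipped at 1.0 once past ``rampup_length``
#   1.0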

def linear_rampup(current, rampup_length=16):
    """Slowly ramps up ``lambda_u`` following a linear schedule.

    Parameters
    ----------

    current : int
        current epoch

    rampup_length : :obj:`int`, optional
        how long to ramp up, by default 16

    Returns
    -------

    factor : float
        ramp up factor

    """
    if rampup_length == 0:
        return 1.0
    else:
        current = numpy.clip(current / rampup_length, 0.0, 1.0)
        return float(current)

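# Illustrative only (not part of the original module): the linear schedule
# reaches the same end point but rises faster early on, e.g.:
#
#   >>> linear_rampup(4)    # 4/16
#   0.25
#   >>> linear_rampup(8)    # 8/16
#   0.5
#   >>> linear_rampup(20)   # clipped at 1.0 once past ``rampup_length``
#   1.0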

def guess_labels(unlabelled_images, model):
    """Calculates the average prediction over the original unlabelled images
    and two augmentations (horizontal and vertical flips).

    Parameters
    ----------

    unlabelled_images : :py:class:`torch.Tensor`
        batch of unlabelled images, with shape ``[n,c,h,w]``

    model : :py:class:`torch.nn.Module`
        network used to produce the predictions

    Returns
    -------

    avg_guess : :py:class:`torch.Tensor`
        guessed (averaged) probabilities, with shape ``[n,c,h,w]``

    """
    with torch.no_grad():
        guess1 = torch.sigmoid(model(unlabelled_images)).unsqueeze(0)
        # Flip along one spatial axis, predict, then flip the prediction back
        # (the flip dimension increases by 1 after the ``unsqueeze``)
        hflip = torch.sigmoid(model(unlabelled_images.flip(2))).unsqueeze(0)
        guess2 = hflip.flip(3)
        # Same procedure along the other spatial axis
        vflip = torch.sigmoid(model(unlabelled_images.flip(3))).unsqueeze(0)
        guess3 = vflip.flip(4)
        # Concatenate the three guesses and average them
        concat = torch.cat([guess1, guess2, guess3], 0)
        avg_guess = torch.mean(concat, 0)
        return avg_guess

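# Illustrative only (not part of the original module): a minimal sketch of
# label guessing with a stand-in "model" (here, a plain convolution):
#
#   >>> import torch
#   >>> model = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)
#   >>> u = torch.rand(4, 1, 32, 32)       # batch of unlabelled images
#   >>> guesses = guess_labels(u, model)
#   >>> guesses.shape                      # same layout as the input batch
#   torch.Size([4, 1, 32, 32])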

def run(
    model,
    data_loader,
    valid_loader,
    optimizer,
    criterion,
    scheduler,
    checkpointer,
    checkpoint_period,
    device,
    arguments,
    output_folder,
    rampup_length,
):
    """
    Fits an FCN model using semi-supervised learning and saves it to disk.

    This method supports periodic checkpointing and the output of a
    CSV-formatted log with the evolution of some figures during training.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    data_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to train the model

    valid_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model and enable automatic checkpointing.
        If set to ``None``, then do not validate it.

    optimizer : :py:mod:`torch.optim`

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    scheduler : :py:mod:`torch.optim`
        learning rate scheduler

    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.Checkpointer`
        checkpointer implementation

    checkpoint_period : int
        save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
        not save intermediary checkpoints

    device : str
        device to use (``'cpu'`` or ``'cuda:0'``)

    arguments : dict
        start and end epochs

    output_folder : str
        output path

    rampup_length : int
        rampup epochs

    """

    start_epoch = arguments["epoch"]
    max_epoch = arguments["max_epoch"]

    if device != "cpu":
        # asserts we do have a GPU
        assert bool(gpu_constants()), (
            f"Device set to '{device}', but cannot "
            "find a GPU (maybe nvidia-smi is not installed?)"
        )

    os.makedirs(output_folder, exist_ok=True)

    # Save model summary
    summary_path = os.path.join(output_folder, "model_summary.txt")
    logger.info(f"Saving model summary at {summary_path}...")
    with open(summary_path, "wt") as f:
        r, n = summary(model)
        logger.info(f"Model has {n} parameters...")
        f.write(r)

    # write static information to a CSV file
    static_logfile_name = os.path.join(output_folder, "constants.csv")
    if os.path.exists(static_logfile_name):
        backup = static_logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(static_logfile_name, backup)
    with open(static_logfile_name, "w", newline="") as f:
        logdata = cpu_constants()
        if device != "cpu":
            logdata += gpu_constants()
        logdata += (("model_size", n),)
        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
        logwriter.writeheader()
        logwriter.writerow(dict(k for k in logdata))

    # Log continuous information to (another) file
    logfile_name = os.path.join(output_folder, "trainlog.csv")

    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
        backup = logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(logfile_name, backup)

    logfile_fields = (
        "epoch",
        "total_time",
        "eta",
        "average_loss",
        "median_loss",
        "labelled_median_loss",
        "unlabelled_median_loss",
        "learning_rate",
    )
    if valid_loader is not None:
        logfile_fields += ("validation_average_loss", "validation_median_loss")
    logfile_fields += tuple([k[0] for k in cpu_log()])
    if device != "cpu":
        logfile_fields += tuple([k[0] for k in gpu_log()])

    # the lowest validation loss obtained so far - this value is updated only
    # if a validation set is available
    lowest_validation_loss = sys.float_info.max

    with open(logfile_name, "a+", newline="") as logfile:
        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)

        if arguments["epoch"] == 0:
            logwriter.writeheader()

        model.train()  # set training mode

        model.to(device)  # set/cast parameters to device
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

        # Total training timer
        start_training_time = time.time()

        for epoch in tqdm(
            range(start_epoch, max_epoch),
            desc="epoch",
            leave=False,
            disable=None,
        ):
            if not PYTORCH_GE_110:
                scheduler.step()
            losses = SmoothedValue(len(data_loader))
            labelled_loss = SmoothedValue(len(data_loader))
            unlabelled_loss = SmoothedValue(len(data_loader))
            epoch = epoch + 1
            arguments["epoch"] = epoch

            # Epoch time
            start_epoch_time = time.time()

            # progress bar only on interactive jobs
            for samples in tqdm(
                data_loader, desc="batch", leave=False, disable=None
            ):

                # data forwarding on the existing network

                # labelled
                images = samples[1].to(
                    device=device, non_blocking=torch.cuda.is_available()
                )
                ground_truths = samples[2].to(
                    device=device, non_blocking=torch.cuda.is_available()
                )
                unlabelled_images = samples[4].to(
                    device=device, non_blocking=torch.cuda.is_available()
                )
                # labelled outputs
                outputs = model(images)
                unlabelled_outputs = model(unlabelled_images)
                # guessed unlabelled outputs
                unlabelled_ground_truths = guess_labels(
                    unlabelled_images, model
                )

                # loss evaluation and learning (backward step)
                ramp_up_factor = square_rampup(
                    epoch, rampup_length=rampup_length
                )

                # note: no support for masks...
                loss, ll, ul = criterion(
                    outputs,
                    ground_truths,
                    unlabelled_outputs,
                    unlabelled_ground_truths,
                    ramp_up_factor,
                )
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses.update(loss)
                labelled_loss.update(ll)
                unlabelled_loss.update(ul)
                logger.debug(f"batch loss: {loss.item()}")

            if PYTORCH_GE_110:
                scheduler.step()

            # calculates the validation loss if necessary
            # note: validation does not comprise "unlabelled" losses
            valid_losses = None
            if valid_loader is not None:

                with torch.no_grad(), torch_evaluation(model):

                    valid_losses = SmoothedValue(len(valid_loader))
                    for samples in tqdm(
                        valid_loader, desc="valid", leave=False, disable=None
                    ):
                        # data forwarding on the existing network
                        images = samples[1].to(
                            device=device,
                            non_blocking=torch.cuda.is_available(),
                        )
                        ground_truths = samples[2].to(
                            device=device,
                            non_blocking=torch.cuda.is_available(),
                        )
                        masks = (
                            torch.ones_like(ground_truths)
                            if len(samples) < 4
                            else samples[3].to(
                                device=device,
                                non_blocking=torch.cuda.is_available(),
                            )
                        )

                        outputs = model(images)
                        loss = criterion(outputs, ground_truths, masks)
                        valid_losses.update(loss)

            if checkpoint_period and (epoch % checkpoint_period == 0):
                checkpointer.save(f"model_{epoch:03d}", **arguments)

            if (
                valid_losses is not None
                and valid_losses.avg < lowest_validation_loss
            ):
                lowest_validation_loss = valid_losses.avg
                logger.info(
                    "Found new low on validation set:"
                    f" {lowest_validation_loss:.6f}"
                )
                checkpointer.save("model_lowest_valid_loss", **arguments)

            if epoch >= max_epoch:
                checkpointer.save("model_final", **arguments)

            # computes ETA (estimated time-of-arrival; end of training) taking
            # into consideration previous epoch performance
            epoch_time = time.time() - start_epoch_time
            eta_seconds = epoch_time * (max_epoch - epoch)
            current_time = time.time() - start_training_time

            logdata = (
                ("epoch", f"{epoch}"),
                (
                    "total_time",
                    f"{datetime.timedelta(seconds=int(current_time))}",
                ),
                ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
                ("average_loss", f"{losses.avg:.6f}"),
                ("median_loss", f"{losses.median:.6f}"),
                ("labelled_median_loss", f"{labelled_loss.median:.6f}"),
                ("unlabelled_median_loss", f"{unlabelled_loss.median:.6f}"),
                ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
            )
            if valid_losses is not None:
                logdata += (
                    ("validation_average_loss", f"{valid_losses.avg:.6f}"),
                    ("validation_median_loss", f"{valid_losses.median:.6f}"),
                )
            logdata += cpu_log()
            if device != "cpu":
                logdata += gpu_log()

            logwriter.writerow(dict(k for k in logdata))
            logfile.flush()
            tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))

        total_training_time = time.time() - start_training_time
        logger.info(
            f"Total training time: {datetime.timedelta(seconds=total_training_time)} "
            f"({(total_training_time/max_epoch):.4f}s on average per epoch)"
        )