
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import os
import shutil

import click

from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.logging import setup

logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")

from .common import save_sh_command


@click.command(
    entry_point_group="deepdraw.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Trains an M2U-Net model (VGG-16 backbone) with DRIVE (vessel
       segmentation), on the CPU, for only two epochs, then runs inference
       and evaluation on stock datasets, reporting performance as a table
       and a figure:

       .. code:: sh

          $ deepdraw experiment -vv m2unet drive --epochs=2
""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store experiment outputs (created if it does not exist)",
    required=True,
    type=click.Path(),
    default="results",
    cls=ResourceOption,
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be trained, and then evaluated",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--dataset",
    "-d",
    help="A dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances implementing the datasets "
    "to be used for training, validating and testing the model, including "
    "all required pre-processing pipelines and eventual data augmentation. "
    "At least one key named ``train`` must be available; this dataset is "
    "used for training the network model. If a dataset named ``__train__`` "
    "is available, it is used for training instead of ``train``. If a "
    "dataset named ``__valid__`` is available, it is used for model "
    "validation (and automatic check-pointing) at each epoch. If a dataset "
    "list named ``__valid_extra__`` is available, it is also tracked during "
    "validation and its loss is output to the training log as well, in the "
    "format of an array occupying a single column. All other keys are "
    "considered test datasets and are only used during analysis, to report "
    "the final system performance",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--second-annotator",
    "-S",
    help="A dataset or dictionary, like in --dataset, with the same "
    "sample keys, but with annotations from a different annotator, to be "
    "compared against the ones in --dataset",
    required=False,
    default=None,
    cls=ResourceOption,
    show_default=True,
)
@click.option(
    "--optimizer",
    help="A torch.optim.Optimizer that will be used to train the network",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--criterion",
    help="A loss function to compute the FCN error for every sample "
    "respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--scheduler",
    help="A learning rate scheduler that drives changes in the learning "
    "rate depending on the FCN state (see torch.optim.lr_scheduler)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects "
    "memory requirements for the network). If the number of samples in "
    "the batch is larger than the total number of samples available for "
    "training, this value is truncated. If this number is smaller, then "
    "batches of the specified size are created and fed to the network "
    "until there are no more new samples to feed (epoch is finished). "
    "If the total number of training samples is not a multiple of the "
    "batch-size, the last batch will be smaller than the first, unless "
    "--drop-incomplete-batch is set, in which case this batch is not used.",
    required=True,
    show_default=True,
    default=2,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--batch-chunk-count",
    "-c",
    help="Number of chunks in every batch (this parameter affects "
    "memory requirements for the network). The number of samples "
    "loaded for every iteration will be batch-size/batch-chunk-count. "
    "batch-size needs to be divisible by batch-chunk-count, otherwise an "
    "error will be raised. This parameter is used to reduce the number of "
    "samples loaded in each iteration, in order to reduce the memory usage "
    "in exchange for processing time (more iterations). This is especially "
    "interesting when one is running on GPUs with limited RAM. The "
    "default of 1 forces the whole batch to be processed at once. Otherwise "
    "the batch is broken into batch-chunk-count pieces, and gradients are "
    "accumulated to complete each batch.",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
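# illustrative example (hypothetical values): --batch-size=8 with
# --batch-chunk-count=4 loads 8/4 = 2 samples per iteration, and gradients
# are accumulated over the 4 chunks before each optimizer step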
@click.option(
    "--drop-incomplete-batch/--no-drop-incomplete-batch",
    "-D",
    help="If set, the last batch in an epoch may be dropped, in case it is "
    "incomplete. If you set this option, you should also consider "
    "increasing the total number of training epochs, as the total number "
    "of training steps may be reduced.",
    required=True,
    show_default=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--epochs",
    "-e",
    help="Number of epochs (complete training set passes) to train for. "
    "If continuing from a saved checkpoint, make sure to provide a number "
    "of epochs greater than the one saved in the checkpoint to be loaded.",
    show_default=True,
    required=True,
    default=1000,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--checkpoint-period",
    "-p",
    help="Number of epochs after which a checkpoint is saved. "
    "A value of zero will disable check-pointing. If checkpointing is "
    "enabled and training stops, it is automatically resumed from the "
    "last saved checkpoint if training is restarted with the same "
    "configuration.",
    show_default=True,
    required=True,
    default=0,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--device",
    "-d",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--seed",
    "-s",
    help="Seed to use for the random number generator",
    show_default=True,
    required=False,
    default=42,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--parallel",
    "-P",
    help="""Use multiprocessing for data loading and processing: if set to -1
    (default), disables multiprocessing altogether. Set to 0 to enable as
    many data loading instances as there are processing cores available in
    the system. Set to >= 1 to enable that many multiprocessing instances
    for data processing.""",
    type=click.IntRange(min=-1),
    show_default=True,
    required=True,
    default=-1,
    cls=ResourceOption,
)
@click.option(
    "--monitoring-interval",
    "-I",
    help="""Time between checks for the use of resources during each training
    epoch. An interval of 5 seconds, for example, will lead to CPU and GPU
    resources being probed every 5 seconds during each training epoch.
    Values registered in the training logs correspond to averages (or maxima)
    observed through possibly many probes in each epoch. Notice that setting a
    very small value may cause the probing process to become extremely busy,
    potentially biasing the overall perception of resource usage.""",
    type=click.FloatRange(min=0.1),
    show_default=True,
    required=True,
    default=5.0,
    cls=ResourceOption,
)
@click.option(
    "--overlayed/--no-overlayed",
    "-O",
    help="Creates overlayed representations of the output probability maps, "
    "similar to --overlayed in prediction-mode, except it includes "
    "distinctive colours for true and false positives and false negatives. "
    "If not set, overlayed images are **NOT** output.",
    show_default=True,
    default=False,
    required=False,
    cls=ResourceOption,
)
@click.option(
    "--steps",
    "-S",
    help="This number is used to define the number of threshold steps to "
    "consider when evaluating the highest possible F1-score on test data.",
    default=1000,
    show_default=True,
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--plot-limits",
    "-L",
    help="""If set, this option affects the performance comparison plots. It
    must be a 4-tuple containing the bounds of the plot for the x and y axes
    respectively (format: [x_low, x_high, y_low, y_high]). If not set, the
    normal bounds ([0, 1, 0, 1]) are used for the performance curve.""",
    default=[0.0, 1.0, 0.0, 1.0],
    show_default=True,
    nargs=4,
    type=float,
    cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption)
@click.pass_context
def experiment(
    ctx,
    model,
    optimizer,
    scheduler,
    output_folder,
    epochs,
    batch_size,
    batch_chunk_count,
    drop_incomplete_batch,
    criterion,
    dataset,
    second_annotator,
    checkpoint_period,
    device,
    seed,
    parallel,
    monitoring_interval,
    overlayed,
    steps,
    plot_limits,
    verbose,
    **kwargs,
):
    """Runs a complete experiment, from training, to prediction and evaluation.

    This script is just a wrapper around the individual scripts for training,
    running prediction, evaluating and comparing FCN model performance. It
    organises the output in a preset way::

        \b
        └─ <output-folder>/
           ├── model/  #the generated model will be here
           ├── predictions/  #the prediction outputs for the train/test set
           ├── overlayed/  #the overlayed outputs for the train/test set
           │   ├── predictions/  #predictions overlayed on the input images
           │   ├── analysis/  #predictions overlayed on the input images,
           │   │              #including analysis of false positives, false
           │   │              #negatives and true positives
           │   └── second-annotator/  #if set, store overlayed images for
           │                          #the second annotator here
           └── analysis/  #the outputs of the analysis of both train/test
                          #sets; includes second-annotator "measures" as
                          #well, if configured

    Training is performed for a configurable number of epochs, and generates
    at least a final_model.pth. It may also generate a number of intermediate
    checkpoints. Checkpoints are model files (.pth files) that are stored
    during the training and are useful to resume the procedure in case it
    stops abruptly.

    N.B.: The tool is designed to prevent analysis bias and allows one to
    provide (potentially multiple) separate subsets for training, validation,
    and evaluation. Instead of using simple datasets, datasets for full
    experiment running should be dictionaries with specific subset names (a
    hypothetical example follows the list below):

    * ``__train__``: dataset used preferentially for training. It is
      typically the dataset containing the data-augmentation pipelines.
    * ``__valid__``: dataset used for validation. It is typically disjoint
      from the training and test sets. In such a case, we checkpoint the
      model with the lowest loss on the validation set as well, throughout
      the training, besides the model at the end of training.
    * ``train`` (optional): a copy of the ``__train__`` dataset, without
      data augmentation, that will be evaluated alongside the other
      available sets.
    * ``__valid_extra__``: a list of datasets that are tracked during
      validation, but do not affect checkpointing. If present, an extra
      column with an array containing the loss of each set is kept on the
      training log.
    * ``*``: any other name, not starting with an underscore character
      (``_``), will be considered a test set for evaluation.
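
    A hypothetical example (variable names are illustrative only), where
    each value is a torch.utils.data.dataset.Dataset instance:

    .. code:: py

       dataset = {
           "__train__": augmented_train,  # training copy, with augmentation
           "train": plain_train,  # same samples, no augmentation
           "__valid__": validation,  # drives check-pointing
           "test": test,  # used only during the final analysis
       }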

    N.B.2: The threshold used for calculating the F1-score on the test set,
    or for the overlay analysis (false positives, false negatives and true
    positives overprinted on the original image), also follows the logic
    above.
    """
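
    # save the command-line used for this run, for later reproduction; a
    # pre-existing command.sh is preserved as a single backup (command.sh~)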
    command_sh = os.path.join(output_folder, "command.sh")
    if os.path.exists(command_sh):
        backup = command_sh + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(command_sh, backup)
    save_sh_command(command_sh)

    # training
    logger.info("Started training")

    from .train import train

    train_output_folder = os.path.join(output_folder, "model")
    ctx.invoke(
        train,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        output_folder=train_output_folder,
        epochs=epochs,
        batch_size=batch_size,
        batch_chunk_count=batch_chunk_count,
        drop_incomplete_batch=drop_incomplete_batch,
        criterion=criterion,
        dataset=dataset,
        checkpoint_period=checkpoint_period,
        device=device,
        seed=seed,
        parallel=parallel,
        monitoring_interval=monitoring_interval,
        verbose=verbose,
    )
    logger.info("Ended training")
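
    # plot the evolution of loss and resource usage from the training logs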
    from .train_analysis import train_analysis

    ctx.invoke(
        train_analysis,
        log=os.path.join(train_output_folder, "trainlog.csv"),
        constants=os.path.join(train_output_folder, "constants.csv"),
        output_pdf=os.path.join(train_output_folder, "trainlog.pdf"),
        verbose=verbose,
    )
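
    # analysis: runs prediction and evaluation on all configured datasets,
    # comparing their performance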
    from .analyze import analyze

    # preferably, we use the best model on the validation set
    # otherwise, we get the last saved model
    model_file = os.path.join(
        train_output_folder, "model_lowest_valid_loss.pth"
    )
    if not os.path.exists(model_file):
        model_file = os.path.join(train_output_folder, "model_final_epoch.pth")

    ctx.invoke(
        analyze,
        model=model,
        output_folder=output_folder,
        batch_size=batch_size,
        dataset=dataset,
        second_annotator=second_annotator,
        device=device,
        overlayed=overlayed,
        weight=model_file,
        steps=steps,
        parallel=parallel,
        plot_limits=plot_limits,
        verbose=verbose,
    )