Coverage for src/deepdraw/script/train.py: 96%

74 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> 

2# 

3# SPDX-License-Identifier: GPL-3.0-or-later 

4 

5import multiprocessing 

6import sys 

7 

8import click 

9import torch 

10 

11from clapper.click import ConfigCommand, ResourceOption, verbosity_option 

12from clapper.logging import setup 

13from torch.utils.data import DataLoader 

14 

15logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") 

16 

17from ..engine.trainer import run 

18from ..utils.checkpointer import Checkpointer 

19from .common import set_seeds, setup_pytorch_device 

20 

21 

# Command-line entry point for training.
#
# NOTE: click renders the function docstring as the ``--help`` text of the
# command, so it is user-facing output; developer-oriented documentation is
# therefore kept in comments rather than in the docstring.
#
# Every option is declared with ``cls=ResourceOption`` and the command with
# ``cls=ConfigCommand`` (both from ``clapper.click``), using the
# ``deepdraw.config`` entry-point group.
@click.command(
    entry_point_group="deepdraw.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Trains a U-Net model (VGG-16 backbone) with DRIVE (vessel segmentation),
       on a GPU (``cuda:0``):

       .. code:: sh

          $ deepdraw train -vv unet drive --batch-size=4 --device="cuda:0"


\b
    2. Trains a HED model with HRF on a GPU (``cuda:0``):

       .. code:: sh

          $ deepdraw train -vv hed hrf --batch-size=8 --device="cuda:0"


\b
    3. Trains a M2U-Net model on the COVD-DRIVE dataset on the CPU:

       .. code:: sh

          $ deepdraw train -vv m2unet covd-drive --batch-size=8
""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the generated model (created if does not exist)",
    required=True,
    type=click.Path(),
    default="results",
    cls=ResourceOption,
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be trained",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--dataset",
    # No short flag here: "-d" is also declared by --device below and one
    # short flag cannot address two options.  NOTE(review): click appears to
    # resolve a duplicated short flag to the option registered last
    # (--device), so "-d" never selected the dataset -- confirm against the
    # click version in use.
    help="A dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances implementing datasets "
    "to be used for training and validating the model, possibly including all "
    "pre-processing pipelines required or, optionally, a dictionary mapping "
    "string keys to torch.utils.data.dataset.Dataset instances.  At least "
    "one key named ``train`` must be available.  This dataset will be used for "
    "training the network model.  The dataset description must include all "
    "required pre-processing, including eventual data augmentation.  If a "
    "dataset named ``__train__`` is available, it is used prioritarily for "
    "training instead of ``train``.  If a dataset named ``__valid__`` is "
    "available, it is used for model validation (and automatic "
    "check-pointing) at each epoch.  If a dataset list named "
    "``__extra_valid__`` is available, then it will be tracked during the "
    "validation process and its loss output at the training log as well, "
    "in the format of an array occupying a single column.  All other keys "
    "are considered test datasets and are ignored during training",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--optimizer",
    help="A torch.optim.Optimizer that will be used to train the network",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--criterion",
    help="A loss function to compute the FCN error for every sample "
    "respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--scheduler",
    help="A learning rate scheduler that drives changes in the learning "
    "rate depending on the FCN state (see torch.optim.lr_scheduler)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects "
    "memory requirements for the network).  If the number of samples in "
    "the batch is larger than the total number of samples available for "
    "training, this value is truncated.  If this number is smaller, then "
    "batches of the specified size are created and fed to the network "
    "until there are no more new samples to feed (epoch is finished).  "
    "If the total number of training samples is not a multiple of the "
    "batch-size, the last batch will be smaller than the first, unless "
    "--drop-incomplete-batch is set, in which case this batch is not used.",
    required=True,
    show_default=True,
    default=2,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--batch-chunk-count",
    "-c",
    help="Number of chunks in every batch (this parameter affects "
    "memory requirements for the network). The number of samples "
    "loaded for every iteration will be batch-size/batch-chunk-count. "
    "batch-size needs to be divisible by batch-chunk-count, otherwise an "
    "error will be raised. This parameter is used to reduce number of "
    "samples loaded in each iteration, in order to reduce the memory usage "
    "in exchange for processing time (more iterations).  This is specially "
    "interesting when one is running with GPUs with limited RAM. The "
    "default of 1 forces the whole batch to be processed at once.  Otherwise "
    "the batch is broken into batch-chunk-count pieces, and gradients are "
    "accumulated to complete each batch.",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--drop-incomplete-batch/--no-drop-incomplete-batch",
    "-D",
    help="If set, then may drop the last batch in an epoch, in case it is "
    "incomplete.  If you set this option, you should also consider "
    "increasing the total number of epochs of training, as the total number "
    "of training steps may be reduced",
    required=True,
    show_default=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--epochs",
    "-e",
    help="Number of epochs (complete training set passes) to train for. "
    "If continuing from a saved checkpoint, ensure to provide a greater "
    "number of epochs than that saved on the checkpoint to be loaded. ",
    show_default=True,
    required=True,
    default=1000,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--checkpoint-period",
    "-p",
    help="Number of epochs after which a checkpoint is saved. "
    "A value of zero will disable check-pointing. If checkpointing is "
    "enabled and training stops, it is automatically resumed from the "
    "last saved checkpoint if training is restarted with the same "
    "configuration.",
    show_default=True,
    required=True,
    default=0,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--device",
    "-d",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--seed",
    "-s",
    help="Seed to use for the random number generator",
    show_default=True,
    required=False,
    default=42,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--parallel",
    "-P",
    help="""Use multiprocessing for data loading: if set to -1 (default),
    disables multiprocessing data loading.  Set to 0 to enable as many data
    loading instances as processing cores as available in the system.  Set to
    >= 1 to enable that many multiprocessing instances for data loading.""",
    type=click.IntRange(min=-1),
    show_default=True,
    required=True,
    default=-1,
    cls=ResourceOption,
)
@click.option(
    "--monitoring-interval",
    "-I",
    help="""Time between checks for the use of resources during each training
    epoch.  An interval of 5 seconds, for example, will lead to CPU and GPU
    resources being probed every 5 seconds during each training epoch.
    Values registered in the training logs correspond to averages (or maxima)
    observed through possibly many probes in each epoch.  Notice that setting a
    very small value may cause the probing process to become extremely busy,
    potentially biasing the overall perception of resource usage.""",
    type=click.FloatRange(min=0.1),
    show_default=True,
    required=True,
    default=5.0,
    cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption)
@click.pass_context
def train(
    ctx,
    model,
    optimizer,
    scheduler,
    output_folder,
    epochs,
    batch_size,
    batch_chunk_count,
    drop_incomplete_batch,
    criterion,
    dataset,
    checkpoint_period,
    device,
    seed,
    parallel,
    monitoring_interval,
    verbose,
    **kwargs,
):
    """Trains an FCN to perform binary segmentation.

    Training is performed for a configurable number of epochs, and generates at
    least a final_model.pth.  It may also generate a number of intermediate
    checkpoints.  Checkpoints are model files (.pth files) that are stored
    during the training and useful to resume the procedure in case it stops
    abruptly.

    Tip: In case the model has been trained over a number of epochs, it is
    possible to continue training, by simply relaunching the same command, and
    changing the number of epochs to a number greater than the number where
    the original training session stopped (or the last checkpoint was saved).
    """
    # ``ctx``, ``verbose`` and ``**kwargs`` are not read below; they are
    # accepted because the decorators above inject them into the call.

    device = setup_pytorch_device(device)

    set_seeds(seed, all_gpus=False)

    # Select the training / validation subsets.  A non-dict ``dataset`` is
    # used directly as the training set, without any validation set.
    use_dataset = dataset
    validation_dataset = None
    extra_validation_datasets = []
    if isinstance(dataset, dict):
        if "__train__" in dataset:
            logger.info("Found (dedicated) '__train__' set for training")
            use_dataset = dataset["__train__"]
        else:
            # Raises KeyError when neither "__train__" nor "train" exists.
            use_dataset = dataset["train"]

        if "__valid__" in dataset:
            logger.info("Found (dedicated) '__valid__' set for validation")
            logger.info("Will checkpoint lowest loss model on validation set")
            validation_dataset = dataset["__valid__"]

        if "__extra_valid__" in dataset:
            if not isinstance(dataset["__extra_valid__"], list):
                raise RuntimeError(
                    f"If present, dataset['__extra_valid__'] must be a list, "
                    f"but you passed a {type(dataset['__extra_valid__'])}, "
                    f"which is invalid."
                )
            logger.info(
                f"Found {len(dataset['__extra_valid__'])} extra validation "
                f"set(s) to be tracked during training"
            )
            logger.info(
                "Extra validation sets are NOT used for model checkpointing!"
            )
            extra_validation_datasets = dataset["__extra_valid__"]

    # Translate --parallel into DataLoader keyword arguments:
    #   < 0  -> no worker processes (num_workers=0)
    #   == 0 -> one worker per core reported by multiprocessing.cpu_count()
    #   > 0  -> exactly that many workers
    multiproc_kwargs = {}
    if parallel < 0:
        multiproc_kwargs["num_workers"] = 0
    else:
        multiproc_kwargs["num_workers"] = (
            parallel or multiprocessing.cpu_count()
        )

    # On macOS ("darwin"), worker processes are explicitly created through
    # the "spawn" multiprocessing context.
    if multiproc_kwargs["num_workers"] > 0 and sys.platform.startswith(
        "darwin"
    ):
        multiproc_kwargs["multiprocessing_context"] = (
            multiprocessing.get_context("spawn")
        )

    # Each iteration loads batch_size / batch_chunk_count samples, so the
    # division must be exact.
    if batch_size % batch_chunk_count != 0:
        raise RuntimeError(
            f"--batch-size ({batch_size}) must be divisible by "
            f"--batch-chunk-count ({batch_chunk_count})."
        )
    batch_chunk_size = batch_size // batch_chunk_count

    # Evaluated once and shared by every loader built below.
    pin_memory = torch.cuda.is_available()

    def _make_loader(data, *, shuffle, drop_last):
        """Builds a DataLoader over ``data`` with the shared settings."""
        return DataLoader(
            dataset=data,
            batch_size=batch_chunk_size,
            shuffle=shuffle,
            drop_last=drop_last,
            pin_memory=pin_memory,
            **multiproc_kwargs,
        )

    # Only the training loader shuffles and honours --drop-incomplete-batch;
    # validation loaders iterate in order and keep every sample.
    data_loader = _make_loader(
        use_dataset, shuffle=True, drop_last=drop_incomplete_batch
    )

    valid_loader = None
    if validation_dataset is not None:
        valid_loader = _make_loader(
            validation_dataset, shuffle=False, drop_last=False
        )

    extra_valid_loaders = [
        _make_loader(k, shuffle=False, drop_last=False)
        for k in extra_validation_datasets
    ]

    checkpointer = Checkpointer(model, optimizer, scheduler, path=output_folder)

    # Start from epoch 0 unless the data returned by ``checkpointer.load()``
    # overrides it; ``max_epoch`` always comes from the command line.
    arguments = {"epoch": 0}
    extra_checkpoint_data = checkpointer.load()
    arguments.update(extra_checkpoint_data)
    arguments["max_epoch"] = epochs

    logger.info("Training for %s epochs", arguments["max_epoch"])
    logger.info("Continuing from epoch %s", arguments["epoch"])

    run(
        model,
        data_loader,
        valid_loader,
        extra_valid_loaders,
        optimizer,
        criterion,
        scheduler,
        checkpointer,
        checkpoint_period,
        device,
        arguments,
        output_folder,
        monitoring_interval,
        batch_chunk_count,
    )