#!/usr/bin/env python
# coding=utf-8

import logging
import multiprocessing
import sys

import click
import torch

from torch.utils.data import DataLoader

from bob.extension.scripts.click_helper import (
    ConfigCommand,
    ResourceOption,
    verbosity_option,
)

from ..utils.checkpointer import Checkpointer
from .binseg import set_seeds, setup_pytorch_device

logger = logging.getLogger(__name__)


@click.command(
    entry_point_group="bob.ip.binseg.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Trains a U-Net model (VGG-16 backbone) with DRIVE (vessel segmentation),
       on a GPU (``cuda:0``):

       $ bob binseg train -vv unet drive --batch-size=4 --device="cuda:0"

    2. Trains a HED model with HRF on a GPU (``cuda:0``):

       $ bob binseg train -vv hed hrf --batch-size=8 --device="cuda:0"

    3. Trains a M2U-Net model on the COVD-DRIVE dataset on the CPU:

       $ bob binseg train -vv m2unet covd-drive --batch-size=8

    4. Trains a DRIU model with SSL on the COVD-DRIVE dataset on the CPU:

       $ bob binseg train -vv --ssl driu-ssl covd-drive-ssl --batch-size=1

""",
)

@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the generated model (created if it does not exist)",
    required=True,
    type=click.Path(),
    default="results",
    cls=ResourceOption,
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be trained",
    required=True,
    cls=ResourceOption,
)
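# Note: because this command uses ConfigCommand/ResourceOption, values for
# options such as --model, --dataset or --optimizer are typically supplied
# through configuration resources passed as command-line arguments (e.g. the
# ``unet`` and ``drive`` arguments in the epilog above), rather than spelled
# out with flags.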

@click.option(
    "--dataset",
    "-d",
    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
    "to be used for training the model, possibly including all pre-processing "
    "pipelines required or, optionally, a dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances. At least one key "
    "named ``train`` must be available. This dataset will be used for "
    "training the network model. The dataset description must include all "
    "required pre-processing, including eventual data augmentation. If a "
    "dataset named ``__train__`` is available, it is used preferentially for "
    "training, instead of ``train``. If a dataset named ``__valid__`` is "
    "available, it is used for model validation (and automatic check-pointing) "
    "at each epoch.",
    required=True,
    cls=ResourceOption,
)
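# For reference, a dictionary-style dataset resource matching the help text
# above could look like this (a sketch; ``train_set`` and ``valid_set`` stand
# for any torch.utils.data.dataset.Dataset instances with pre-processing and
# augmentation already applied):
#
#   dataset = {
#       "train": train_set,      # required, used for training
#       "__valid__": valid_set,  # optional, used for validation/checkpointing
#   }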

@click.option(
    "--optimizer",
    help="A torch.optim.Optimizer that will be used to train the network",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--criterion",
    help="A loss function to compute the FCN error for every sample "
    "respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--scheduler",
    help="A learning rate scheduler that drives changes in the learning "
    "rate depending on the FCN state (see torch.optim.lr_scheduler)",
    required=True,
    cls=ResourceOption,
)

@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects "
    "memory requirements for the network). If the number of samples in "
    "the batch is larger than the total number of samples available for "
    "training, this value is truncated. If this number is smaller, then "
    "batches of the specified size are created and fed to the network "
    "until there are no more new samples to feed (epoch is finished). "
    "If the total number of training samples is not a multiple of the "
    "batch-size, the last batch will be smaller than the first, unless "
    "--drop-incomplete-batch is set, in which case this batch is not used.",
    required=True,
    show_default=True,
    default=2,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
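# Worked example for the help text above: with 25 training samples and
# --batch-size=4, an epoch has 6 full batches plus a final batch of a single
# sample; with --drop-incomplete-batch, the final batch is discarded and the
# epoch has exactly 6 batches (24 samples seen).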

@click.option(
    "--drop-incomplete-batch/--no-drop-incomplete-batch",
    "-D",
    help="If set, the last batch in an epoch will be dropped if it is "
    "incomplete. If you set this option, you should also consider "
    "increasing the total number of epochs of training, as the total number "
    "of training steps may be reduced.",
    required=True,
    show_default=True,
    default=False,
    cls=ResourceOption,
)

@click.option(
    "--epochs",
    "-e",
    help="Number of epochs (complete training set passes) to train for. "
    "If continuing from a saved checkpoint, make sure to provide a number "
    "of epochs greater than the one saved in the checkpoint to be loaded.",
    show_default=True,
    required=True,
    default=1000,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)

@click.option(
    "--checkpoint-period",
    "-p",
    help="Number of epochs after which a checkpoint is saved. "
    "A value of zero disables checkpointing. If checkpointing is "
    "enabled and training stops, it automatically resumes from the "
    "last saved checkpoint when training is restarted with the same "
    "configuration.",
    show_default=True,
    required=True,
    default=0,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
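# For example, --checkpoint-period=10 writes a checkpoint every 10 epochs;
# if the process is interrupted, relaunching the identical command with the
# same --output-folder resumes from the last checkpoint found there.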

@click.option(
    "--device",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)

@click.option(
    "--seed",
    "-s",
    help="Seed to use for the random number generator",
    show_default=True,
    required=False,
    default=42,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)

@click.option(
    "--ssl/--no-ssl",
    help="Switch ON/OFF semi-supervised training mode",
    show_default=True,
    required=True,
    default=False,
    cls=ResourceOption,
)

@click.option(
    "--rampup",
    "-r",
    help="Ramp-up length in epochs (for SSL training only)",
    show_default=True,
    required=True,
    default=900,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)

@click.option(
    "--multiproc-data-loading",
    "-P",
    help="""Use multiprocessing for data loading: if set to -1 (the default),
    disables multiprocessing data loading. Set to 0 to enable as many data
    loading instances as there are processing cores in the system. Set to
    >= 1 to enable that many multiprocessing instances for data loading.""",
    type=click.IntRange(min=-1),
    show_default=True,
    required=True,
    default=-1,
    cls=ResourceOption,
)
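# Summary of --multiproc-data-loading values (see the function body below):
#   -1 -> num_workers=0 (data loaded in the main process; the default)
#    0 -> num_workers=multiprocessing.cpu_count()
#    N -> num_workers=N, for any N >= 1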

@verbosity_option(cls=ResourceOption)
def train(
    model,
    optimizer,
    scheduler,
    output_folder,
    epochs,
    batch_size,
    drop_incomplete_batch,
    criterion,
    dataset,
    checkpoint_period,
    device,
    seed,
    ssl,
    rampup,
    multiproc_data_loading,
    verbose,
    **kwargs,
):
    """Trains an FCN to perform binary segmentation

    Training is performed for a configurable number of epochs, and generates
    at least a final_model.pth. It may also generate a number of intermediate
    checkpoints. Checkpoints are model files (.pth files) that are stored
    during training and are useful for resuming the procedure in case it
    stops abruptly.

    Tip: In case the model has been trained over a number of epochs, it is
    possible to continue training by simply relaunching the same command, and
    changing the number of epochs to a number greater than the one where the
    original training session stopped (or where the last checkpoint was
    saved).

    """
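    # For example, if a first session was run for the default 1000 epochs,
    # training can be extended by 500 epochs by relaunching with
    # --epochs=1500 (a usage sketch of the tip above, assuming intermediate
    # checkpointing was enabled via --checkpoint-period so there is a
    # checkpoint to load):
    #   $ bob binseg train -vv unet drive --batch-size=4 --epochs=1500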

    device = setup_pytorch_device(device)

    set_seeds(seed, all_gpus=False)

    use_dataset = dataset
    validation_dataset = None
    if isinstance(dataset, dict):
        if "__train__" in dataset:
            logger.info("Found (dedicated) '__train__' set for training")
            use_dataset = dataset["__train__"]
        else:
            use_dataset = dataset["train"]

        if "__valid__" in dataset:
            logger.info("Found (dedicated) '__valid__' set for validation")
            logger.info("Will checkpoint lowest loss model on validation set")
            validation_dataset = dataset["__valid__"]

    # PyTorch dataloader
    multiproc_kwargs = dict()
    if multiproc_data_loading < 0:
        multiproc_kwargs["num_workers"] = 0
    elif multiproc_data_loading == 0:
        multiproc_kwargs["num_workers"] = multiprocessing.cpu_count()
    else:
        multiproc_kwargs["num_workers"] = multiproc_data_loading

    if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin":
        # on macOS, fork()-based worker processes are unreliable with
        # multi-threaded libraries, so spawn fresh worker processes instead
        multiproc_kwargs[
            "multiprocessing_context"
        ] = multiprocessing.get_context("spawn")

    data_loader = DataLoader(
        dataset=use_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=drop_incomplete_batch,
        pin_memory=torch.cuda.is_available(),
        **multiproc_kwargs,
    )

    valid_loader = None
    if validation_dataset is not None:
        valid_loader = DataLoader(
            dataset=validation_dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            pin_memory=torch.cuda.is_available(),
            **multiproc_kwargs,
        )
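    # the validation loader keeps shuffle=False and drop_last=False so that
    # every epoch evaluates the exact same set of samples, keeping validation
    # losses comparable across epochs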

    checkpointer = Checkpointer(model, optimizer, scheduler, path=output_folder)

    arguments = {}
    arguments["epoch"] = 0
    # if a checkpoint exists, this also restores the model/optimizer/scheduler
    # state and recovers the epoch at which training previously stopped
    extra_checkpoint_data = checkpointer.load()
    arguments.update(extra_checkpoint_data)
    arguments["max_epoch"] = epochs

    logger.info("Training for {} epochs".format(arguments["max_epoch"]))
    logger.info("Continuing from epoch {}".format(arguments["epoch"]))

    if not ssl:
        from ..engine.trainer import run

        run(
            model,
            data_loader,
            valid_loader,
            optimizer,
            criterion,
            scheduler,
            checkpointer,
            checkpoint_period,
            device,
            arguments,
            output_folder,
        )

    else:
        from ..engine.ssltrainer import run

        run(
            model,
            data_loader,
            valid_loader,
            optimizer,
            criterion,
            scheduler,
            checkpointer,
            checkpoint_period,
            device,
            arguments,
            output_folder,
            rampup,
        )