Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4import logging 

5import os 

6import shutil 

7 

8import click 

9 

10from bob.extension.scripts.click_helper import ( 

11 ConfigCommand, 

12 ResourceOption, 

13 verbosity_option, 

14) 

15 

16from .binseg import save_sh_command 

17 

18logger = logging.getLogger(__name__) 

19 

20 

@click.command(
    entry_point_group="bob.ip.binseg.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Trains an M2U-Net model (VGG-16 backbone) with DRIVE (vessel
       segmentation), on the CPU, for only two epochs, then runs inference and
       evaluation on stock datasets, report performance as a table and a figure:

       $ bob binseg experiment -vv m2unet drive --epochs=2

""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store experiment outputs (created if does not exist)",
    required=True,
    type=click.Path(),
    default="results",
    cls=ResourceOption,
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be trained, and then evaluated",
    required=True,
    cls=ResourceOption,
)
# NOTE: "--dataset" intentionally has no short flag.  It used to declare
# "-d", which is also declared by "--device" below; a short flag can only be
# bound to a single option, so the duplicate only produced a misleading
# ``--help`` listing.
@click.option(
    "--dataset",
    help="A dictionary mapping string keys to "
    "bob.ip.binseg.data.utils.SampleList2TorchDataset's.  At least one key "
    "named 'train' must be available.  This dataset will be used for training "
    "the network model.  All other datasets will be used for prediction and "
    "evaluation. Dataset descriptions include all required pre-processing, "
    "including eventual data augmentation, which may be eventually excluded "
    "for prediction and evaluation purposes",
    required=True,
    cls=ResourceOption,
)
# NOTE: "--second-annotator" intentionally has no short flag.  It used to
# declare "-S", which is also declared by "--steps" below (same problem as
# "-d" above).
@click.option(
    "--second-annotator",
    help="A dataset or dictionary, like in --dataset, with the same "
    "sample keys, but with annotations from a different annotator that is "
    "going to be compared to the one in --dataset",
    required=False,
    default=None,
    cls=ResourceOption,
    show_default=True,
)
@click.option(
    "--optimizer",
    help="A torch.optim.Optimizer that will be used to train the network",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--criterion",
    help="A loss function to compute the FCN error for every sample "
    "respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--scheduler",
    help="A learning rate scheduler that drives changes in the learning "
    "rate depending on the FCN state (see torch.optim.lr_scheduler)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects "
    "memory requirements for the network).  If the number of samples in "
    "the batch is larger than the total number of samples available for "
    "training, this value is truncated.  If this number is smaller, then "
    "batches of the specified size are created and fed to the network "
    "until there are no more new samples to feed (epoch is finished).  "
    "If the total number of training samples is not a multiple of the "
    "batch-size, the last batch will be smaller than the first, unless "
    "--drop-incomplete-batch is set, in which case this batch is not used.",
    required=True,
    show_default=True,
    default=2,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--drop-incomplete-batch/--no-drop-incomplete-batch",
    "-D",
    help="If set, then may drop the last batch in an epoch, in case it is "
    "incomplete.  If you set this option, you should also consider "
    "increasing the total number of epochs of training, as the total number "
    "of training steps may be reduced",
    required=True,
    show_default=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--epochs",
    "-e",
    help="Number of epochs (complete training set passes) to train for. "
    "If continuing from a saved checkpoint, ensure to provide a greater "
    "number of epochs than that saved on the checkpoint to be loaded. ",
    show_default=True,
    required=True,
    default=1000,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--checkpoint-period",
    "-p",
    help="Number of epochs after which a checkpoint is saved. "
    "A value of zero will disable check-pointing. If checkpointing is "
    "enabled and training stops, it is automatically resumed from the "
    "last saved checkpoint if training is restarted with the same "
    "configuration.",
    show_default=True,
    required=True,
    default=0,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--device",
    "-d",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--seed",
    "-s",
    help="Seed to use for the random number generator",
    show_default=True,
    required=False,
    default=42,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--ssl/--no-ssl",
    help="Switch ON/OFF semi-supervised training mode",
    show_default=True,
    required=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--rampup",
    "-r",
    help="Ramp-up length in epochs (for SSL training only)",
    show_default=True,
    required=True,
    default=900,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--multiproc-data-loading",
    "-P",
    help="""Use multiprocessing for data loading: if set to -1 (default),
    disables multiprocessing data loading.  Set to 0 to enable as many data
    loading instances as processing cores as available in the system.  Set to
    >= 1 to enable that many multiprocessing instances for data loading.""",
    type=click.IntRange(min=-1),
    show_default=True,
    required=True,
    default=-1,
    cls=ResourceOption,
)
@click.option(
    "--overlayed/--no-overlayed",
    "-O",
    help="Creates overlayed representations of the output probability maps, "
    "similar to --overlayed in prediction-mode, except it includes "
    "distinctive colours for true and false positives and false negatives.  "
    "If not set, or empty then do **NOT** output overlayed images.",
    show_default=True,
    default=False,
    required=False,
    cls=ResourceOption,
)
@click.option(
    "--steps",
    "-S",
    help="This number is used to define the number of threshold steps to "
    "consider when evaluating the highest possible F1-score on test data.",
    default=1000,
    show_default=True,
    required=True,
    cls=ResourceOption,
)
@verbosity_option(cls=ResourceOption)
@click.pass_context
def experiment(
    ctx,
    model,
    optimizer,
    scheduler,
    output_folder,
    epochs,
    batch_size,
    drop_incomplete_batch,
    criterion,
    dataset,
    second_annotator,
    checkpoint_period,
    device,
    seed,
    ssl,
    rampup,
    multiproc_data_loading,
    overlayed,
    steps,
    verbose,
    **kwargs,
):
    """Runs a complete experiment, from training, to prediction and evaluation

    This script is just a wrapper around the individual scripts for training,
    running prediction, evaluating and comparing FCN model performance.  It
    organises the output in a preset way::

    \b
       └─ <output-folder>/
          ├── model/  #the generated model will be here
          ├── predictions/  #the prediction outputs for the train/test set
          ├── overlayed/  #the overlayed outputs for the train/test set
             ├── predictions/  #predictions overlayed on the input images
             ├── analysis/  #predictions overlayed on the input images
             ├              #including analysis of false positives, negatives
             ├              #and true positives
             └── second-annotator/  #if set, store overlayed images for the
                                    #second annotator here
          └── analysis/  #the outputs of the analysis of both train/test sets
                         #includes second-annotator "measures" as well, if
                         # configured

    Training is performed for a configurable number of epochs, and generates at
    least a model_final.pth.  It may also generate a number of intermediate
    checkpoints.  Checkpoints are model files (.pth files) that are stored
    during the training and useful to resume the procedure in case it stops
    abruptly.

    N.B.: The tool is designed to prevent analysis bias and allows one to
    provide separate subsets for training and evaluation.  Instead of using
    simple datasets, datasets for full experiment running should be
    dictionaries with specific subset names:

    * ``__train__``: dataset used for training, prioritarily.  It is typically
      the dataset containing data augmentation pipelines.
    * ``__valid__``: dataset used for validation.  It is typically disjoint
      from the training and test sets.  In such a case, we checkpoint the model
      with the lowest loss on the validation set as well, throughout all the
      training, besides the model at the end of training.
    * ``train`` (optional): a copy of the ``__train__`` dataset, without data
      augmentation, that will be evaluated alongside other sets available
    * ``*``: any other name, not starting with an underscore character (``_``),
      will be considered a test set for evaluation.

    N.B.2: The threshold used for calculating the F1-score on the test set, or
    overlay analysis (false positives, negatives and true positives overprinted
    on the original image) also follows the logic above.
    """

    # Record the command line of this run in "<output-folder>/command.sh".
    # A script left over from a previous run is kept as a single "~" backup
    # (any older backup is discarded first).
    command_sh = os.path.join(output_folder, "command.sh")
    if os.path.exists(command_sh):
        backup = command_sh + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(command_sh, backup)
    save_sh_command(command_sh)

    # training
    logger.info("Started training")

    # sub-commands are imported only when the experiment actually runs
    from .train import train

    train_output_folder = os.path.join(output_folder, "model")

    ctx.invoke(
        train,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        output_folder=train_output_folder,
        epochs=epochs,
        batch_size=batch_size,
        drop_incomplete_batch=drop_incomplete_batch,
        criterion=criterion,
        dataset=dataset,
        checkpoint_period=checkpoint_period,
        device=device,
        seed=seed,
        ssl=ssl,
        rampup=rampup,
        multiproc_data_loading=multiproc_data_loading,
        verbose=verbose,
    )
    logger.info("Ended training")

    from .analyze import analyze

    # preferably, we use the best model on the validation set
    # otherwise, we get the last saved model
    model_file = os.path.join(
        train_output_folder, "model_lowest_valid_loss.pth"
    )
    if not os.path.exists(model_file):
        model_file = os.path.join(train_output_folder, "model_final.pth")

    # prediction, evaluation and comparison are delegated to "analyze", which
    # receives the top-level output folder and the weights selected above
    ctx.invoke(
        analyze,
        model=model,
        output_folder=output_folder,
        batch_size=batch_size,
        dataset=dataset,
        second_annotator=second_annotator,
        device=device,
        overlayed=overlayed,
        weight=model_file,
        steps=steps,
        verbose=verbose,
    )