Coverage for bob/med/tb/scripts/train.py: 93% (86 statements)

#!/usr/bin/env python
# coding=utf-8

import os

import click
import torch
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import DataLoader, WeightedRandomSampler
from ..configs.datasets import get_samples_weights, get_positive_weights

from bob.extension.scripts.click_helper import (
    verbosity_option,
    ConfigCommand,
    ResourceOption,
)

from ..utils.checkpointer import Checkpointer
from ..engine.trainer import run
from .tb import download_to_tempfile
from ..models.normalizer import TorchVisionNormalizer

import logging
logger = logging.getLogger(__name__)

@click.command(
    entry_point_group="bob.med.tb.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Trains PASA model with Montgomery dataset,
       on a GPU (``cuda:0``):

       $ bob tb train -vv pasa montgomery --batch-size=4 --device="cuda:0"

""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the generated model (created if it does not exist)",
    required=True,
    type=click.Path(),
    default="results",
    cls=ResourceOption,
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be trained",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--dataset",
    "-d",
    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
    "to be used for training the model, possibly including all required "
    "pre-processing pipelines, or, optionally, a dictionary mapping string "
    "keys to torch.utils.data.dataset.Dataset instances. At least one key "
    "named ``train`` must be available. This dataset will be used for "
    "training the network model. The dataset description must include all "
    "required pre-processing, including any data augmentation. If a "
    "dataset named ``__train__`` is available, it is used preferentially for "
    "training instead of ``train``. If a dataset named ``__valid__`` is "
    "available, it is used for model validation (and automatic check-pointing) "
    "at each epoch.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--optimizer",
    help="A torch.optim.Optimizer that will be used to train the network",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--criterion",
    help="A loss function to compute the CNN error for every sample "
    "respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
    required=True,
    cls=ResourceOption,
)

@click.option(
    "--criterion_valid",
    help="A specific loss function for the validation set to compute the CNN "
    "error for every sample respecting the PyTorch API for loss functions "
    "(see torch.nn.modules.loss)",
    required=False,
    cls=ResourceOption,
)

@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects "
    "memory requirements for the network). If the number of samples in "
    "the batch is larger than the total number of samples available for "
    "training, this value is truncated. If this number is smaller, then "
    "batches of the specified size are created and fed to the network "
    "until there are no more new samples to feed (epoch is finished). "
    "If the total number of training samples is not a multiple of the "
    "batch-size, the last batch will be smaller than the first, unless "
    "--drop-incomplete-batch is set, in which case this batch is not used.",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)

@click.option(
    "--drop-incomplete-batch/--no-drop-incomplete-batch",
    "-D",
    help="If set, the last batch of an epoch may be dropped when it is "
    "incomplete. If you set this option, you should also consider "
    "increasing the total number of training epochs, as the total number "
    "of training steps may be reduced",
    required=True,
    show_default=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--epochs",
    "-e",
    help="Number of epochs (complete training set passes) to train for",
    show_default=True,
    required=True,
    default=1000,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--checkpoint-period",
    "-p",
    help="Number of epochs after which a checkpoint is saved. "
    "A value of zero will disable check-pointing. If checkpointing is "
    "enabled and training stops, it is automatically resumed from the "
    "last saved checkpoint if training is restarted with the same "
    "configuration.",
    show_default=True,
    required=True,
    default=0,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--device",
    "-d",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--seed",
    "-s",
    help="Seed to use for the random number generator",
    show_default=True,
    required=False,
    default=42,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)

@click.option(
    "--num_workers",
    "-ns",
    help="Number of parallel data-loading workers to use",
    show_default=True,
    required=False,
    default=0,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)

@click.option(
    "--weight",
    "-w",
    help="Path or URL to pretrained model file (.pth extension)",
    required=False,
    cls=ResourceOption,
)
@click.option(
    "--normalization",
    "-n",
    help="Z-normalization of input images: 'imagenet' for ImageNet parameters, "
    "'current' for parameters of the current training set, "
    "'none' for no normalization.",
    required=False,
    default="none",
    cls=ResourceOption,
)

@verbosity_option(cls=ResourceOption)
def train(
    model,
    optimizer,
    output_folder,
    epochs,
    batch_size,
    drop_incomplete_batch,
    criterion,
    criterion_valid,
    dataset,
    checkpoint_period,
    device,
    seed,
    num_workers,
    weight,
    normalization,
    verbose,
    **kwargs,
):
    """Trains a CNN to perform tuberculosis detection

    Training is performed for a configurable number of epochs, and generates
    at least a final_model.pth. It may also generate a number of intermediate
    checkpoints. Checkpoints are model files (.pth files) that are stored
    during training and are useful to resume the procedure in case it stops
    abruptly.
    """

    torch.manual_seed(seed)
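
    # Note: torch.manual_seed() seeds PyTorch's global RNG, from which the
    # WeightedRandomSampler below draws, so sample ordering is reproducible
    # across runs with the same configuration.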

    use_dataset = dataset
    validation_dataset = None
    if isinstance(dataset, dict):
        if "__train__" in dataset:
            logger.info("Found (dedicated) '__train__' set for training")
            use_dataset = dataset["__train__"]
        else:
            use_dataset = dataset["train"]

        if "__valid__" in dataset:
            logger.info("Found (dedicated) '__valid__' set for validation")
            logger.info("Will checkpoint lowest loss model on validation set")
            validation_dataset = dataset["__valid__"]

    # Create weighted random sampler
    train_samples_weights = get_samples_weights(use_dataset)
    train_samples_weights = train_samples_weights.to(
        device=device, non_blocking=torch.cuda.is_available()
    )
    train_sampler = WeightedRandomSampler(
        train_samples_weights, len(train_samples_weights), replacement=True
    )
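
    # How the sampler balances classes (a sketch, assuming
    # get_samples_weights() returns one weight per sample, inversely
    # proportional to its class frequency): with 90 negatives and 10
    # positives, each negative would weigh 1/90 and each positive 1/10, so
    # both classes are drawn with roughly equal probability:
    #
    #   weights = torch.tensor([1 / 90] * 90 + [1 / 10] * 10)
    #   sampler = WeightedRandomSampler(weights, len(weights), replacement=True)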

    # Redefine a weighted criterion if possible
    if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
        positive_weights = get_positive_weights(use_dataset)
        positive_weights = positive_weights.to(
            device=device, non_blocking=torch.cuda.is_available()
        )
        criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
    else:
        logger.warning("Weighted criterion not supported")

    # PyTorch dataloader
    data_loader = DataLoader(
        dataset=use_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=drop_incomplete_batch,
        pin_memory=torch.cuda.is_available(),
        sampler=train_sampler,
    )
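
    # pin_memory places batches in page-locked host memory when CUDA is
    # available, speeding up host-to-device copies; drop_last implements the
    # --drop-incomplete-batch flag documented above.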

    valid_loader = None
    if validation_dataset is not None:

        # Redefine a weighted valid criterion if possible
        if (
            isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss)
            or criterion_valid is None
        ):
            positive_weights = get_positive_weights(validation_dataset)
            positive_weights = positive_weights.to(
                device=device, non_blocking=torch.cuda.is_available()
            )
            criterion_valid = BCEWithLogitsLoss(pos_weight=positive_weights)
        else:
            logger.warning("Weighted valid criterion not supported")

        valid_loader = DataLoader(
            dataset=validation_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=False,
            drop_last=False,
            pin_memory=torch.cuda.is_available(),
        )
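
    # The validation loader iterates in dataset order (shuffle=False) and
    # keeps incomplete batches (drop_last=False), so validation statistics
    # cover every available sample.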

    # Create z-normalization model layer if needed
    if normalization == "imagenet":
        model.normalizer.set_mean_std(
            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        )
        logger.info("Z-normalization with ImageNet mean and std")
    elif normalization == "current":
        # Compute mean/std of current train subset
        temp_dl = DataLoader(dataset=use_dataset, batch_size=len(use_dataset))

        data = next(iter(temp_dl))
        mean = data[1].mean(dim=[0, 2, 3])
        std = data[1].std(dim=[0, 2, 3])

        model.normalizer.set_mean_std(mean, std)

        # Format mean and std for logging (3 decimal places)
        mean = str([round(x, 3) for x in mean.tolist()])
        std = str([round(x, 3) for x in std.tolist()])
        logger.info("Z-normalization with mean {} and std {}".format(mean, std))

    # Checkpointer
    checkpointer = Checkpointer(model, optimizer, path=output_folder)

    # Load pretrained weights if needed
    if weight is not None:
        if weight.startswith("http"):
            logger.info(f"Temporarily downloading '{weight}'...")
            f = download_to_tempfile(weight, progress=True)
            weight_fullpath = os.path.abspath(f.name)
        else:
            weight_fullpath = os.path.abspath(weight)
        checkpointer.load(weight_fullpath, strict=False)
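
    # strict=False presumably mirrors torch.nn.Module.load_state_dict()
    # semantics: parameters whose names match the current model are loaded,
    # while missing or unexpected keys are tolerated, allowing partial
    # initialization from a related pretrained model.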

    arguments = {}
    arguments["epoch"] = 0
    arguments["max_epoch"] = epochs

    logger.info("Training for {} epochs".format(arguments["max_epoch"]))
    logger.info("Continuing from epoch {}".format(arguments["epoch"]))

    run(
        model,
        data_loader,
        valid_loader,
        optimizer,
        criterion,
        checkpointer,
        checkpoint_period,
        device,
        arguments,
        output_folder,
        criterion_valid,
    )