1#!/usr/bin/env python
2# coding=utf-8
3
4import os
5
6import click
7import torch
8from torch.nn import BCEWithLogitsLoss
9from torch.utils.data import DataLoader, WeightedRandomSampler
10from ..configs.datasets import get_samples_weights, get_positive_weights
11
12from bob.extension.scripts.click_helper import (
13 verbosity_option,
14 ConfigCommand,
15 ResourceOption,
16)
17
18from ..utils.checkpointer import Checkpointer
19from ..engine.trainer import run
20from .tb import download_to_tempfile
21from ..models.normalizer import TorchVisionNormalizer
22
23import logging
24logger = logging.getLogger(__name__)
25
26
def _weighted_bce_loss(dataset, device):
    """Builds a ``BCEWithLogitsLoss`` using positive weights of ``dataset``

    The weights are whatever ``get_positive_weights(dataset)`` returns; they
    are moved to ``device`` before being handed to the loss as ``pos_weight``.
    """
    positive_weights = get_positive_weights(dataset).to(
        device=device, non_blocking=torch.cuda.is_available()
    )
    return BCEWithLogitsLoss(pos_weight=positive_weights)


@click.command(
    entry_point_group="bob.med.tb.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Trains PASA model with Montgomery dataset,
       on a GPU (``cuda:0``):

       $ bob tb train -vv pasa montgomery --batch-size=4 --device="cuda:0"

""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the generated model (created if does not exist)",
    required=True,
    type=click.Path(),
    default="results",
    cls=ResourceOption,
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be trained",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--dataset",
    "-d",
    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
    "to be used for training the model, possibly including all pre-processing "
    "pipelines required or, optionally, a dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances.  At least one key "
    "named ``train`` must be available.  This dataset will be used for "
    "training the network model.  The dataset description must include all "
    "required pre-processing, including eventual data augmentation.  If a "
    "dataset named ``__train__`` is available, it is used prioritarily for "
    "training instead of ``train``.  If a dataset named ``__valid__`` is "
    "available, it is used for model validation (and automatic check-pointing) "
    "at each epoch.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--optimizer",
    help="A torch.optim.Optimizer that will be used to train the network",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--criterion",
    help="A loss function to compute the CNN error for every sample "
    "respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--criterion_valid",
    # Adjacent literals are concatenated verbatim: each fragment needs its own
    # trailing space, otherwise words get glued together in ``--help``.
    help="A specific loss function for the validation set to compute the CNN "
    "error for every sample respecting the PyTorch API for loss functions "
    "(see torch.nn.modules.loss)",
    required=False,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects "
    "memory requirements for the network).  If the number of samples in "
    "the batch is larger than the total number of samples available for "
    "training, this value is truncated.  If this number is smaller, then "
    "batches of the specified size are created and fed to the network "
    "until there are no more new samples to feed (epoch is finished).  "
    "If the total number of training samples is not a multiple of the "
    "batch-size, the last batch will be smaller than the first, unless "
    "--drop-incomplete-batch is set, in which case this batch is not used.",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--drop-incomplete-batch/--no-drop-incomplete-batch",
    "-D",
    help="If set, then may drop the last batch in an epoch, in case it is "
    "incomplete.  If you set this option, you should also consider "
    "increasing the total number of epochs of training, as the total number "
    "of training steps may be reduced",
    required=True,
    show_default=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--epochs",
    "-e",
    help="Number of epochs (complete training set passes) to train for",
    show_default=True,
    required=True,
    default=1000,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--checkpoint-period",
    "-p",
    help="Number of epochs after which a checkpoint is saved. "
    "A value of zero will disable check-pointing. If checkpointing is "
    "enabled and training stops, it is automatically resumed from the "
    "last saved checkpoint if training is restarted with the same "
    "configuration.",
    show_default=True,
    required=True,
    default=0,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
# NOTE(review): the short flag ``-d`` is also declared by ``--dataset`` above.
# It is left untouched here to avoid changing the command-line interface, but
# only one of the two options can actually own ``-d`` -- confirm which one
# users rely on and drop the other.
@click.option(
    "--device",
    "-d",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--seed",
    "-s",
    help="Seed to use for the random number generator",
    show_default=True,
    required=False,
    default=42,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--num_workers",
    "-ns",
    help="Number of parallel threads to use",
    show_default=True,
    required=False,
    default=0,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
    help="Path or URL to pretrained model file (.pth extension)",
    required=False,
    cls=ResourceOption,
)
@click.option(
    "--normalization",
    "-n",
    help="Z-Normalization of input images: 'imagenet' for ImageNet parameters,"
    " 'current' for parameters of the current trainset, "
    "'none' for no normalization.",
    required=False,
    default="none",
    cls=ResourceOption,
)
@verbosity_option(cls=ResourceOption)
def train(
    model,
    optimizer,
    output_folder,
    epochs,
    batch_size,
    drop_incomplete_batch,
    criterion,
    criterion_valid,
    dataset,
    checkpoint_period,
    device,
    seed,
    num_workers,
    weight,
    normalization,
    verbose,
    **kwargs,
):
    """Trains a CNN to perform tuberculosis detection

    Training is performed for a configurable number of epochs, and generates at
    least a final_model.pth.  It may also generate a number of intermediate
    checkpoints.  Checkpoints are model files (.pth files) that are stored
    during the training and useful to resume the procedure in case it stops
    abruptly.
    """
    # The docstring above doubles as the command-line help text, so the
    # implementation notes live in comments instead.  Steps performed:
    #
    #   1. seed torch's random number generator;
    #   2. pick the training (and optional validation) dataset;
    #   3. build a weighted random sampler and, when the criterion is a
    #      ``BCEWithLogitsLoss``, a positively-weighted replacement for it;
    #   4. build the data loaders;
    #   5. configure the model's normalizer;
    #   6. optionally load pretrained weights;
    #   7. delegate the actual training loop to ``run``.

    torch.manual_seed(seed)

    # ``dataset`` is either a single dataset or a dictionary of datasets in
    # which ``__train__`` takes precedence over ``train`` and ``__valid__``
    # (when present) enables validation.
    use_dataset = dataset
    validation_dataset = None
    if isinstance(dataset, dict):
        if "__train__" in dataset:
            logger.info("Found (dedicated) '__train__' set for training")
            use_dataset = dataset["__train__"]
        else:
            use_dataset = dataset["train"]

        if "__valid__" in dataset:
            logger.info("Found (dedicated) '__valid__' set for validation")
            logger.info("Will checkpoint lowest loss model on validation set")
            validation_dataset = dataset["__valid__"]

    # Create weighted random sampler: one draw (with replacement) per entry in
    # the weights returned by ``get_samples_weights``
    train_samples_weights = get_samples_weights(use_dataset).to(
        device=device, non_blocking=torch.cuda.is_available()
    )
    train_sampler = WeightedRandomSampler(
        train_samples_weights, len(train_samples_weights), replacement=True
    )

    # Redefine a weighted criterion if possible
    if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
        criterion = _weighted_bce_loss(use_dataset, device)
    else:
        logger.warning("Weighted criterion not supported")

    # PyTorch dataloader (``sampler`` and ``shuffle`` are mutually exclusive,
    # so shuffling is entirely driven by the weighted sampler)
    data_loader = DataLoader(
        dataset=use_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=drop_incomplete_batch,
        pin_memory=torch.cuda.is_available(),
        sampler=train_sampler,
    )

    valid_loader = None
    if validation_dataset is not None:

        # Redefine a weighted valid criterion if possible.  A missing
        # validation criterion also falls back to the weighted BCE loss.
        if criterion_valid is None or isinstance(
            criterion_valid, torch.nn.BCEWithLogitsLoss
        ):
            criterion_valid = _weighted_bce_loss(validation_dataset, device)
        else:
            logger.warning("Weighted valid criterion not supported")

        valid_loader = DataLoader(
            dataset=validation_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=False,
            drop_last=False,
            pin_memory=torch.cuda.is_available(),
        )

    # Create z-normalization model layer if needed
    if normalization == "imagenet":
        model.normalizer.set_mean_std(
            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        )
        logger.info("Z-normalization with ImageNet mean and std")
    elif normalization == "current":
        # Compute mean/std of current train subset, loading it as one single
        # batch (the whole subset must therefore fit in memory)
        temp_dl = DataLoader(dataset=use_dataset, batch_size=len(use_dataset))

        # Reduces every axis but axis 1 of the second batch element --
        # presumably the image tensor laid out as (N, C, H, W); TODO confirm
        data = next(iter(temp_dl))
        mean = data[1].mean(dim=[0, 2, 3])
        std = data[1].std(dim=[0, 2, 3])

        model.normalizer.set_mean_std(mean, std)

        # Format mean and std for logging
        mean_repr = str([round(x, 3) for x in mean.tolist()])
        std_repr = str([round(x, 3) for x in std.tolist()])
        logger.info(
            "Z-normalization with mean {} and std {}".format(
                mean_repr, std_repr
            )
        )
    elif normalization != "none":
        # Previously ignored silently: a typo in the option value would go
        # unnoticed.  The normalizer is still left as it is.
        logger.warning(
            "Unsupported normalization '%s': normalizer left unchanged",
            normalization,
        )

    # Checkpointer
    checkpointer = Checkpointer(model, optimizer, path=output_folder)

    # Load pretrained weights if needed
    if weight is not None:
        if weight.startswith("http"):
            logger.info(f"Temporarily downloading '{weight}'...")
            # Keep a reference to the returned object while loading, as only
            # its ``name`` is passed on to the checkpointer
            f = download_to_tempfile(weight, progress=True)
            weight_fullpath = os.path.abspath(f.name)
        else:
            weight_fullpath = os.path.abspath(weight)
        checkpointer.load(weight_fullpath, strict=False)

    # Training always (re)starts from epoch 0 here
    arguments = {"epoch": 0, "max_epoch": epochs}

    logger.info("Training for {} epochs".format(arguments["max_epoch"]))
    logger.info("Continuing from epoch {}".format(arguments["epoch"]))

    run(
        model,
        data_loader,
        valid_loader,
        optimizer,
        criterion,
        checkpointer,
        checkpoint_period,
        device,
        arguments,
        output_folder,
        criterion_valid,
    )