Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4import csv
5import datetime
6import logging
7import os
8import shutil
9import sys
10import time
12import numpy
13import torch
15from tqdm import tqdm
17from ..utils.measure import SmoothedValue
18from ..utils.resources import cpu_constants, cpu_log, gpu_constants, gpu_log
19from ..utils.summary import summary
20from .trainer import PYTORCH_GE_110, torch_evaluation
22logger = logging.getLogger(__name__)
def sharpen(x, T):
    """Sharpen a per-row probability distribution by temperature scaling.

    Raises every entry to the power ``1/T`` and renormalizes along
    dimension 1 so each row sums to one again.  ``T < 1`` makes the
    distribution peakier; ``T = 1`` leaves it unchanged.
    """
    powered = x.pow(1.0 / T)
    return powered / powered.sum(dim=1, keepdim=True)
30def mix_up(alpha, input, target, unlabelled_input, unlabled_target):
31 """Applies mix up as described in [MIXMATCH_19].
33 Parameters
34 ----------
35 alpha : float
37 input : :py:class:`torch.Tensor`
39 target : :py:class:`torch.Tensor`
41 unlabelled_input : :py:class:`torch.Tensor`
43 unlabled_target : :py:class:`torch.Tensor`
46 Returns
47 -------
49 list
51 """
53 with torch.no_grad():
54 l = numpy.random.beta(alpha, alpha) # Eq (8)
55 l = max(l, 1 - l) # Eq (9)
56 # Shuffle and concat. Alg. 1 Line: 12
57 w_inputs = torch.cat([input, unlabelled_input], 0)
58 w_targets = torch.cat([target, unlabled_target], 0)
59 idx = torch.randperm(w_inputs.size(0)) # get random index
61 # Apply MixUp to labelled data and entries from W. Alg. 1 Line: 13
62 input_mixedup = l * input + (1 - l) * w_inputs[idx[len(input) :]]
63 target_mixedup = l * target + (1 - l) * w_targets[idx[len(target) :]]
65 # Apply MixUp to unlabelled data and entries from W. Alg. 1 Line: 14
66 unlabelled_input_mixedup = (
67 l * unlabelled_input
68 + (1 - l) * w_inputs[idx[: len(unlabelled_input)]]
69 )
70 unlabled_target_mixedup = (
71 l * unlabled_target
72 + (1 - l) * w_targets[idx[: len(unlabled_target)]]
73 )
74 return (
75 input_mixedup,
76 target_mixedup,
77 unlabelled_input_mixedup,
78 unlabled_target_mixedup,
79 )
def square_rampup(current, rampup_length=16):
    """slowly ramp-up ``lambda_u``

    Grows quadratically from 0 to 1 over ``rampup_length`` epochs, then
    stays clamped at 1.

    Parameters
    ----------

    current : int
        current epoch

    rampup_length : :obj:`int`, optional
        how long to ramp up, by default 16

    Returns
    -------

    factor : float
        ramp up factor
    """

    # a zero ramp-up length means "no ramp-up": full weight immediately
    if rampup_length == 0:
        return 1.0
    fraction = (current / float(rampup_length)) ** 2
    return float(numpy.clip(fraction, 0.0, 1.0))
def linear_rampup(current, rampup_length=16):
    """slowly ramp-up ``lambda_u``

    Grows linearly from 0 to 1 over ``rampup_length`` epochs, then stays
    clamped at 1.

    Parameters
    ----------
    current : int
        current epoch

    rampup_length : :obj:`int`, optional
        how long to ramp up, by default 16

    Returns
    -------

    factor: float
        ramp up factor
    """
    # a zero ramp-up length means "no ramp-up": full weight immediately
    if rampup_length == 0:
        return 1.0
    return float(numpy.clip(current / rampup_length, 0.0, 1.0))
def guess_labels(unlabelled_images, model):
    """
    Calculate the average predictions by 2 augmentations: horizontal and vertical flips

    Parameters
    ----------

    unlabelled_images : :py:class:`torch.Tensor`
        ``[n,c,h,w]``

    model : :py:class:`torch.nn.Module`
        network producing per-pixel logits for ``[n,c,h,w]`` input

    Returns
    -------

    shape : :py:class:`torch.Tensor`
        ``[n,c,h,w]``

    """
    with torch.no_grad():
        # plain prediction; unsqueeze(0) adds a leading "augmentation" axis
        # so the three guesses can be concatenated and averaged
        plain = torch.sigmoid(model(unlabelled_images)).unsqueeze(0)
        # predict on input flipped along dim 2; undo the flip on the
        # prediction (that axis becomes dim 3 after unsqueeze)
        flipped_a = torch.sigmoid(model(unlabelled_images.flip(2))).unsqueeze(0)
        undone_a = flipped_a.flip(3)
        # same for dim 3 of the input (dim 4 after unsqueeze)
        flipped_b = torch.sigmoid(model(unlabelled_images.flip(3))).unsqueeze(0)
        undone_b = flipped_b.flip(4)
        # stack the three aligned guesses and average over the
        # augmentation axis
        stacked = torch.cat([plain, undone_a, undone_b], 0)
        return torch.mean(stacked, 0)
def run(
    model,
    data_loader,
    valid_loader,
    optimizer,
    criterion,
    scheduler,
    checkpointer,
    checkpoint_period,
    device,
    arguments,
    output_folder,
    rampup_length,
):
    """
    Fits an FCN model using semi-supervised learning and saves it to disk.

    This method supports periodic checkpointing and the output of a
    CSV-formatted log with the evolution of some figures during training.

    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    data_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to train the model

    valid_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model and enable automatic checkpointing.
        If set to ``None``, then do not validate it.

    optimizer : :py:mod:`torch.optim`

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    scheduler : :py:mod:`torch.optim`
        learning rate scheduler

    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.Checkpointer`
        checkpointer implementation

    checkpoint_period : int
        save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
        not save intermediary checkpoints

    device : str
        device to use ``'cpu'`` or ``cuda:0``

    arguments : dict
        start and end epochs

    output_folder : str
        output path

    rampup_length : int
        rampup epochs

    """

    start_epoch = arguments["epoch"]
    max_epoch = arguments["max_epoch"]

    if device != "cpu":
        # asserts we do have a GPU
        assert bool(gpu_constants()), (
            f"Device set to '{device}', but cannot "
            f"find a GPU (maybe nvidia-smi is not installed?)"
        )

    os.makedirs(output_folder, exist_ok=True)

    # Save model summary
    summary_path = os.path.join(output_folder, "model_summary.txt")
    logger.info(f"Saving model summary at {summary_path}...")
    with open(summary_path, "wt") as f:
        r, n = summary(model)
        logger.info(f"Model has {n} parameters...")
        f.write(r)

    # write static information to a CSV file, backing up any previous copy
    static_logfile_name = os.path.join(output_folder, "constants.csv")
    if os.path.exists(static_logfile_name):
        backup = static_logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(static_logfile_name, backup)
    with open(static_logfile_name, "w", newline="") as f:
        logdata = cpu_constants()
        if device != "cpu":
            logdata += gpu_constants()
        logdata += (("model_size", n),)
        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
        logwriter.writeheader()
        logwriter.writerow(dict(logdata))

    # Log continous information to (another) file
    logfile_name = os.path.join(output_folder, "trainlog.csv")

    # only back up the previous log when starting a fresh run; a resumed
    # run appends to the existing file instead
    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
        backup = logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(logfile_name, backup)

    logfile_fields = (
        "epoch",
        "total_time",
        "eta",
        "average_loss",
        "median_loss",
        "labelled_median_loss",
        "unlabelled_median_loss",
        "learning_rate",
    )
    if valid_loader is not None:
        logfile_fields += ("validation_average_loss", "validation_median_loss")
    logfile_fields += tuple([k[0] for k in cpu_log()])
    if device != "cpu":
        logfile_fields += tuple([k[0] for k in gpu_log()])

    # the lowest validation loss obtained so far - this value is updated only
    # if a validation set is available
    lowest_validation_loss = sys.float_info.max

    with open(logfile_name, "a+", newline="") as logfile:
        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)

        if arguments["epoch"] == 0:
            logwriter.writeheader()

        model.train()  # set training mode

        model.to(device)  # set/cast parameters to device
        # move any optimizer state loaded from a checkpoint to the same
        # device as the parameters it refers to
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

        # Total training timer
        start_training_time = time.time()

        for epoch in tqdm(
            range(start_epoch, max_epoch),
            desc="epoch",
            leave=False,
            disable=None,
        ):
            # older PyTorch expects scheduler.step() before the optimizer
            # runs; newer (>= 1.1.0) expects it after the epoch (see below)
            if not PYTORCH_GE_110:
                scheduler.step()
            losses = SmoothedValue(len(data_loader))
            labelled_loss = SmoothedValue(len(data_loader))
            unlabelled_loss = SmoothedValue(len(data_loader))
            epoch = epoch + 1
            arguments["epoch"] = epoch

            # Epoch time
            start_epoch_time = time.time()

            # progress bar only on interactive jobs
            for samples in tqdm(
                data_loader, desc="batch", leave=False, disable=None
            ):

                # data forwarding on the existing network

                # labelled
                images = samples[1].to(
                    device=device, non_blocking=torch.cuda.is_available()
                )
                ground_truths = samples[2].to(
                    device=device, non_blocking=torch.cuda.is_available()
                )
                unlabelled_images = samples[4].to(
                    device=device, non_blocking=torch.cuda.is_available()
                )
                # labelled outputs
                outputs = model(images)
                unlabelled_outputs = model(unlabelled_images)
                # guessed unlabelled outputs
                unlabelled_ground_truths = guess_labels(
                    unlabelled_images, model
                )

                # loss evaluation and learning (backward step); the
                # unlabelled term is weighted by a ramp-up factor so it
                # only gradually influences training
                ramp_up_factor = square_rampup(
                    epoch, rampup_length=rampup_length
                )

                # note: no support for masks...
                loss, ll, ul = criterion(
                    outputs,
                    ground_truths,
                    unlabelled_outputs,
                    unlabelled_ground_truths,
                    ramp_up_factor,
                )
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses.update(loss)
                labelled_loss.update(ll)
                unlabelled_loss.update(ul)
                logger.debug(f"batch loss: {loss.item()}")

            if PYTORCH_GE_110:
                scheduler.step()

            # calculates the validation loss if necessary
            # note: validation does not comprise "unlabelled" losses
            valid_losses = None
            if valid_loader is not None:

                with torch.no_grad(), torch_evaluation(model):

                    valid_losses = SmoothedValue(len(valid_loader))
                    for samples in tqdm(
                        valid_loader, desc="valid", leave=False, disable=None
                    ):
                        # data forwarding on the existing network
                        images = samples[1].to(
                            device=device,
                            non_blocking=torch.cuda.is_available(),
                        )
                        ground_truths = samples[2].to(
                            device=device,
                            non_blocking=torch.cuda.is_available(),
                        )
                        # datasets without an explicit mask get an
                        # all-ones mask (i.e. no masking)
                        masks = (
                            torch.ones_like(ground_truths)
                            if len(samples) < 4
                            else samples[3].to(
                                device=device,
                                non_blocking=torch.cuda.is_available(),
                            )
                        )

                        outputs = model(images)
                        loss = criterion(outputs, ground_truths, masks)
                        valid_losses.update(loss)

            if checkpoint_period and (epoch % checkpoint_period == 0):
                checkpointer.save(f"model_{epoch:03d}", **arguments)

            if (
                valid_losses is not None
                and valid_losses.avg < lowest_validation_loss
            ):
                lowest_validation_loss = valid_losses.avg
                logger.info(
                    f"Found new low on validation set:"
                    f" {lowest_validation_loss:.6f}"
                )
                checkpointer.save("model_lowest_valid_loss", **arguments)

            if epoch >= max_epoch:
                checkpointer.save("model_final", **arguments)

            # computes ETA (estimated time-of-arrival; end of training) taking
            # into consideration previous epoch performance
            epoch_time = time.time() - start_epoch_time
            eta_seconds = epoch_time * (max_epoch - epoch)
            current_time = time.time() - start_training_time

            logdata = (
                ("epoch", f"{epoch}"),
                (
                    "total_time",
                    f"{datetime.timedelta(seconds=int(current_time))}",
                ),
                ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
                ("average_loss", f"{losses.avg:.6f}"),
                ("median_loss", f"{losses.median:.6f}"),
                ("labelled_median_loss", f"{labelled_loss.median:.6f}"),
                ("unlabelled_median_loss", f"{unlabelled_loss.median:.6f}"),
                ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
            )
            if valid_losses is not None:
                logdata += (
                    ("validation_average_loss", f"{valid_losses.avg:.6f}"),
                    ("validation_median_loss", f"{valid_losses.median:.6f}"),
                )
            logdata += cpu_log()
            # fix: gpu_log() was previously appended twice here, querying
            # the GPU redundantly and duplicating entries in ``logdata``
            if device != "cpu":
                logdata += gpu_log()

            logwriter.writerow(dict(logdata))
            logfile.flush()
            tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))

    total_training_time = time.time() - start_training_time
    logger.info(
        f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)"
    )