Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4import contextlib
5import csv
6import datetime
7import distutils.version
8import logging
9import os
10import shutil
11import sys
12import time
14import torch
16from tqdm import tqdm
18from ..utils.measure import SmoothedValue
19from ..utils.resources import cpu_constants, cpu_log, gpu_constants, gpu_log
20from ..utils.summary import summary
logger = logging.getLogger(__name__)

# Since PyTorch 1.1.0, ``scheduler.step()`` is expected to be called after the
# optimizer updates of an epoch; older releases expected it at epoch start.
# ``run()`` uses this flag to decide where to step the scheduler.
_TORCH_VERSION = distutils.version.LooseVersion(torch.__version__)
PYTORCH_GE_110 = _TORCH_VERSION >= "1.1.0"
@contextlib.contextmanager
def torch_evaluation(model):
    """Context manager to turn ON/OFF model evaluation

    This context manager puts the model in evaluation mode on entry. When
    the ``with`` statement block exits, it restores the training/evaluation
    mode the model had before entry. Restoration also happens when the
    block raises an exception.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)


    Yields
    ------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    """

    # remember the current mode so exactly that mode is restored on exit
    was_training = model.training
    model.eval()
    try:
        yield model
    finally:
        # without ``finally`` an exception raised inside the ``with`` block
        # would leave the model silently stuck in evaluation mode
        model.train(was_training)
def check_gpu(device):
    """
    Check the device type and the availability of GPU.

    Parameters
    ----------

    device : :py:class:`torch.device`
        device to use


    Raises
    ------

    RuntimeError
        If ``device`` is a CUDA device but ``gpu_constants()`` returns a
        falsy value (reported to the user as nvidia-smi not being installed).

    """
    # an explicit check instead of ``assert``: assertions are stripped when
    # Python runs with optimizations enabled (``python -O``)
    if device.type == "cuda" and not gpu_constants():
        raise RuntimeError(
            f"Device set to '{device}', but nvidia-smi is not installed"
        )
def save_model_summary(output_folder, model):
    """
    Write a short textual summary of ``model`` to ``model_summary.txt``.

    Parameters
    ----------

    output_folder : str
        directory in which ``model_summary.txt`` is (over)written

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    Returns
    -------
    r : str
        The model summary in a text format.

    n : int
        The number of parameters of the model.

    """
    summary_path = os.path.join(output_folder, "model_summary.txt")
    logger.info(f"Saving model summary at {summary_path}...")
    with open(summary_path, "wt") as summary_file:
        text, n_params = summary(model)
        logger.info(f"Model has {n_params} parameters...")
        summary_file.write(text)
    return text, n_params
def static_information_to_csv(static_logfile_name, device, n):
    """
    Save the static information in a csv file.

    If a file already exists at ``static_logfile_name``, it is first moved
    to a backup named ``static_logfile_name + "~"``. Any older backup is
    replaced.

    Parameters
    ----------

    static_logfile_name : str
        The static file name which is a join between the output folder and
        "constants.csv"

    device : :py:class:`torch.device`
        device to use; GPU constants are only recorded for CUDA devices

    n : int
        The number of parameters of the model, stored in the ``model_size``
        column

    """
    if os.path.exists(static_logfile_name):
        # os.replace overwrites an existing backup in a single step
        os.replace(static_logfile_name, static_logfile_name + "~")

    with open(static_logfile_name, "w", newline="") as f:
        logdata = cpu_constants()
        if device.type == "cuda":
            logdata += gpu_constants()
        logdata += (("model_size", n),)
        # logdata is a sequence of (column name, value) pairs
        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
        logwriter.writeheader()
        logwriter.writerow(dict(logdata))
def check_exist_logfile(logfile_name, arguments):
    """
    Check the existence of the logfile (trainlog.csv).

    Suppose the logfile exists and training starts from scratch
    (``arguments["epoch"]`` is 0). In that case the logfile is moved to a
    backup named ``logfile_name + "~"``, replacing any older backup, so a
    fresh log can be started. When resuming (epoch different from 0), the
    logfile is left untouched so new rows are appended to it.

    Parameters
    ----------

    logfile_name : str
        The logfile_name which is a join between the output_folder and trainlog.csv

    arguments : dict
        start and end epochs; only the ``"epoch"`` key is read here

    """
    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
        # os.replace overwrites an existing backup in a single step
        os.replace(logfile_name, logfile_name + "~")
def create_logfile_fields(valid_loader, device):
    """
    Build the ordered column names of trainlog.csv.

    Parameters
    ----------

    valid_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model and enable automatic checkpointing.
        If set to ``None``, then do not validate it.

    device : :py:class:`torch.device`
        device to use

    Returns
    -------

    logfile_fields: tuple
        The fields that will appear in trainlog.csv

    """
    fields = [
        "epoch",
        "total_time",
        "eta",
        "average_loss",
        "median_loss",
        "learning_rate",
    ]
    if valid_loader is not None:
        fields.extend(["validation_average_loss", "validation_median_loss"])

    # the resource loggers return (name, value) entries; only names are columns
    fields.extend(entry[0] for entry in cpu_log())
    if device.type == "cuda":
        fields.extend(entry[0] for entry in gpu_log())

    return tuple(fields)
def train_sample_process(samples, model, optimizer, losses, device, criterion):
    """
    Run one training step on a batch and record its loss.

    The batch images, ground truths and (optional) masks are moved to
    ``device`` and forwarded through ``model``. The loss is then
    back-propagated and the optimizer performs one update step.

    Parameters
    ----------

    samples : list
        batch from the data loader: ``samples[1]`` holds the images,
        ``samples[2]`` the ground truths and, when present, ``samples[3]``
        the masks (otherwise a mask of ones is used)

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    optimizer : :py:mod:`torch.optim`
        optimizer updating the parameters of ``model``

    losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`
        accumulator receiving the (detached) batch loss

    device : :py:class:`torch.device`
        device to use

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function, called as ``criterion(outputs, ground_truths, masks)``

    Returns
    -------

    losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`
        the same accumulator, updated with this batch

    optimizer : :py:mod:`torch.optim`
        the same optimizer, after its update step

    """
    non_blocking = torch.cuda.is_available()
    images = samples[1].to(device=device, non_blocking=non_blocking)
    ground_truths = samples[2].to(device=device, non_blocking=non_blocking)
    if len(samples) < 4:
        masks = torch.ones_like(ground_truths)
    else:
        masks = samples[3].to(device=device, non_blocking=non_blocking)

    outputs = model(images)
    loss = criterion(outputs, ground_truths, masks)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # A tensor that requires grad keeps its autograd graph (and the
    # activations it references) alive while anything holds on to it.
    # Detach before handing it to the accumulator, which may store it.
    losses.update(loss.detach())
    if logger.isEnabledFor(logging.DEBUG):
        # ``.item()`` forces a device synchronization; only pay for it when
        # the debug message is actually emitted
        logger.debug(f"batch loss: {loss.item()}")
    return losses, optimizer
def valid_sample_process(samples, model, valid_losses, device, criterion):
    """
    Compute the loss of a validation batch and record it.

    The batch images, ground truths and (optional) masks are moved to
    ``device`` and forwarded through ``model``. The loss is then added to
    ``valid_losses``. No optimization step is performed.

    Parameters
    ----------

    samples : list
        batch from the data loader: ``samples[1]`` holds the images,
        ``samples[2]`` the ground truths and, when present, ``samples[3]``
        the masks (otherwise a mask of ones is used)

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    valid_losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`
        accumulator receiving the (detached) batch loss

    device : :py:class:`torch.device`
        device to use

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function, called as ``criterion(outputs, ground_truths, masks)``

    Returns
    -------

    valid_losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`
        the same accumulator, updated with this batch

    """
    non_blocking = torch.cuda.is_available()
    images = samples[1].to(device=device, non_blocking=non_blocking)
    ground_truths = samples[2].to(device=device, non_blocking=non_blocking)
    if len(samples) < 4:
        masks = torch.ones_like(ground_truths)
    else:
        masks = samples[3].to(device=device, non_blocking=non_blocking)

    outputs = model(images)
    loss = criterion(outputs, ground_truths, masks)
    # detach so no autograd graph is retained by the accumulator, even when
    # this is called outside of ``torch.no_grad()``
    valid_losses.update(loss.detach())
    return valid_losses
def checkpointer_process(
    checkpointer,
    checkpoint_period,
    valid_losses,
    lowest_validation_loss,
    arguments,
    epoch,
    max_epoch,
):
    """
    Process the checkpointer, save the final model and keep track of the best model.

    Parameters
    ----------

    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.Checkpointer`
        checkpointer implementation

    checkpoint_period : int
        save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
        not save intermediary checkpoints

    valid_losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`
        validation losses of the current epoch, or ``None`` if there is no
        validation set

    lowest_validation_loss : float
        Keep track of the best (lowest) validation loss

    arguments : dict
        start and end epochs; forwarded as keyword arguments to every
        ``checkpointer.save`` call

    epoch : int
        the epoch that was just completed

    max_epoch : int
        end epoch; ``model_final`` is saved once ``epoch`` reaches it

    Returns
    -------

    lowest_validation_loss : float
        the updated lowest validation loss (unchanged when there is no
        validation set or the current loss is not lower)

    """
    if checkpoint_period and (epoch % checkpoint_period == 0):
        checkpointer.save(f"model_{epoch:03d}", **arguments)

    if valid_losses is not None and valid_losses.avg < lowest_validation_loss:
        lowest_validation_loss = valid_losses.avg
        logger.info(
            f"Found new low on validation set:" f" {lowest_validation_loss:.6f}"
        )
        checkpointer.save("model_lowest_valid_loss", **arguments)

    if epoch >= max_epoch:
        checkpointer.save("model_final", **arguments)

    # floats are immutable: rebinding the parameter above is invisible to the
    # caller, so the new value must be returned for it to be tracked
    return lowest_validation_loss
def write_log_info(
    epoch,
    current_time,
    eta_seconds,
    losses,
    valid_losses,
    optimizer,
    logwriter,
    logfile,
    device,
):
    """
    Append one row of per-epoch figures to trainlog.csv and echo a summary.

    The row is flushed to disk right away. Its first four columns (epoch,
    total time, ETA and average loss) are also printed through
    ``tqdm.write``.

    Parameters
    ----------

    epoch : int
        Current epoch

    current_time : float
        Current training time, in seconds

    eta_seconds : float
        estimated time-of-arrival taking into consideration previous epoch performance

    losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`
        training losses of the epoch

    valid_losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`
        validation losses of the epoch, or ``None`` when not validating

    optimizer : :py:mod:`torch.optim`
        optimizer whose first parameter group provides the learning rate

    logwriter : csv.DictWriter
        Dictionary writer that give the ability to write on the trainlog.csv

    logfile: io.TextIOWrapper
        file underlying ``logwriter``; flushed after each row

    device : :py:class:`torch.device`
        device to use; GPU figures are only logged for CUDA devices

    """
    elapsed = datetime.timedelta(seconds=int(current_time))
    remaining = datetime.timedelta(seconds=int(eta_seconds))

    row = (
        ("epoch", f"{epoch}"),
        ("total_time", f"{elapsed}"),
        ("eta", f"{remaining}"),
        ("average_loss", f"{losses.avg:.6f}"),
        ("median_loss", f"{losses.median:.6f}"),
        ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
    )
    if valid_losses is not None:
        row = row + (
            ("validation_average_loss", f"{valid_losses.avg:.6f}"),
            ("validation_median_loss", f"{valid_losses.median:.6f}"),
        )
    row = row + cpu_log()
    if device.type == "cuda":
        row = row + gpu_log()

    logwriter.writerow(dict(row))
    logfile.flush()

    headline = "|".join(f"{name}: {value}" for name, value in row[:4])
    tqdm.write(headline)
def run(
    model,
    data_loader,
    valid_loader,
    optimizer,
    criterion,
    scheduler,
    checkpointer,
    checkpoint_period,
    device,
    arguments,
    output_folder,
):
    """
    Fits an FCN model using supervised learning and save it to disk.

    This method supports periodic checkpointing and the output of a
    CSV-formatted log with the evolution of some figures during training.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    data_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to train the model

    valid_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model and enable automatic checkpointing.
        If set to ``None``, then do not validate it.

    optimizer : :py:mod:`torch.optim`

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    scheduler : :py:mod:`torch.optim`
        learning rate scheduler

    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.Checkpointer`
        checkpointer implementation

    checkpoint_period : int
        save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
        not save intermediary checkpoints

    device : :py:class:`torch.device`
        device to use

    arguments : dict
        start and end epochs (keys ``"epoch"`` and ``"max_epoch"``);
        ``"epoch"`` is updated in place after every completed epoch

    output_folder : str
        output path
    """

    start_epoch = arguments["epoch"]
    max_epoch = arguments["max_epoch"]

    check_gpu(device)

    os.makedirs(output_folder, exist_ok=True)

    # Save model summary (only the parameter count is needed here)
    _, n = save_model_summary(output_folder, model)

    # write static information to a CSV file
    static_logfile_name = os.path.join(output_folder, "constants.csv")
    static_information_to_csv(static_logfile_name, device, n)

    # Log continuous information to (another) file
    logfile_name = os.path.join(output_folder, "trainlog.csv")
    check_exist_logfile(logfile_name, arguments)

    logfile_fields = create_logfile_fields(valid_loader, device)

    # the lowest validation loss obtained so far - this value is updated only
    # if a validation set is available.
    # NOTE(review): when resuming, the best loss of the previous run is not
    # known, so the first validated epoch always counts as a new low.
    lowest_validation_loss = sys.float_info.max

    with open(logfile_name, "a+", newline="") as logfile:
        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)

        if arguments["epoch"] == 0:
            logwriter.writeheader()

        model.train()  # set training mode

        model.to(device)  # set/cast parameters to device
        # tensors in the optimizer state must live on the same device as the
        # parameters they refer to
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

        # Total training timer
        start_training_time = time.time()

        # epochs are numbered from 1: ``epoch`` is the epoch being trained
        for epoch in tqdm(
            range(start_epoch + 1, max_epoch + 1),
            desc="epoch",
            leave=False,
            disable=None,
        ):
            if not PYTORCH_GE_110:
                scheduler.step()
            losses = SmoothedValue(len(data_loader))
            arguments["epoch"] = epoch

            # Epoch time
            start_epoch_time = time.time()

            # progress bar only on interactive jobs
            for samples in tqdm(
                data_loader, desc="batch", leave=False, disable=None
            ):
                # data forwarding on the existing network
                losses, optimizer = train_sample_process(
                    samples, model, optimizer, losses, device, criterion
                )

            if PYTORCH_GE_110:
                scheduler.step()

            # calculates the validation loss if necessary
            valid_losses = None
            if valid_loader is not None:

                with torch.no_grad(), torch_evaluation(model):

                    valid_losses = SmoothedValue(len(valid_loader))
                    for samples in tqdm(
                        valid_loader, desc="valid", leave=False, disable=None
                    ):
                        # data forwarding on the existing network
                        valid_losses = valid_sample_process(
                            samples, model, valid_losses, device, criterion
                        )

            checkpointer_process(
                checkpointer,
                checkpoint_period,
                valid_losses,
                lowest_validation_loss,
                arguments,
                epoch,
                max_epoch,
            )

            # Keep the best-so-far loss up to date with the same rule the
            # checkpointing helper applies. Without this, the value stays at
            # its initial maximum and every epoch is considered a new low.
            if (
                valid_losses is not None
                and valid_losses.avg < lowest_validation_loss
            ):
                lowest_validation_loss = valid_losses.avg

            # computes ETA (estimated time-of-arrival; end of training) taking
            # into consideration previous epoch performance
            epoch_time = time.time() - start_epoch_time
            eta_seconds = epoch_time * (max_epoch - epoch)
            current_time = time.time() - start_training_time

            write_log_info(
                epoch,
                current_time,
                eta_seconds,
                losses,
                valid_losses,
                optimizer,
                logwriter,
                logfile,
                device,
            )

        total_training_time = time.time() - start_training_time
        # average over the epochs actually run in this call (differs from
        # ``max_epoch`` when resuming); guard against zero epochs
        epochs_run = max_epoch - start_epoch
        average_epoch_time = (
            total_training_time / epochs_run if epochs_run > 0 else 0.0
        )
        logger.info(
            f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({average_epoch_time:.4f}s in average per epoch)"
        )