Coverage for src/deepdraw/engine/trainer.py: 91% (159 statements)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import contextlib
import csv
import datetime
import logging
import os
import shutil
import sys
import time

import numpy
import torch

from tqdm import tqdm

from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
from ..utils.summary import summary

logger = logging.getLogger(__name__)


@contextlib.contextmanager
def torch_evaluation(model):
    """Context manager to turn ON/OFF model evaluation.

    This context manager will turn evaluation mode ON on entry and turn it
    OFF when exiting the ``with`` statement block.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)


    Yields
    ------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)
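
    Example (an illustrative sketch; ``net`` and ``batch`` are hypothetical
    placeholders)::

        with torch_evaluation(net):
            predictions = net(batch)  # forward pass runs with net.eval() active
        # on exit, the context manager calls net.train() again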
45 """
47 model.eval()
48 yield model
49 model.train()


def check_gpu(device):
    """Check the device type and the availability of a GPU.

    Parameters
    ----------

    device : :py:class:`torch.device`
        device to use
    """
    if device.type == "cuda":
        # asserts we do have a GPU
        assert bool(
            gpu_constants()
        ), f"Device set to '{device}', but nvidia-smi is not installed"


def save_model_summary(output_folder, model):
    """Save a small summary of the model in a text file.

    Parameters
    ----------

    output_folder : str
        output path

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)


    Returns
    -------

    r : str
        The model summary in a text format.

    n : int
        The number of parameters of the model.
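
    Example (a minimal sketch; ``net`` is a hypothetical model instance)::

        text, n_params = save_model_summary("results", net)
        # writes results/model_summary.txt and returns its contents and the
        # model's parameter count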
87 """
88 summary_path = os.path.join(output_folder, "model_summary.txt")
89 logger.info(f"Saving model summary at {summary_path}...")
90 with open(summary_path, "w") as f:
91 r, n = summary(model)
92 logger.info(f"Model has {n} parameters...")
93 f.write(r)
94 return r, n


def static_information_to_csv(static_logfile_name, device, n):
    """Save the static information in a CSV file.

    Parameters
    ----------

    static_logfile_name : str
        The static file name, which is a join between the output folder and
        "constants.csv"

    device : :py:class:`torch.device`
        device to use

    n : int
        The number of parameters of the model
    """
    if os.path.exists(static_logfile_name):
        backup = static_logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(static_logfile_name, backup)
    with open(static_logfile_name, "w", newline="") as f:
        logdata = cpu_constants()
        if device.type == "cuda":
            logdata += gpu_constants()
        logdata += (("model_size", n),)
        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
        logwriter.writeheader()
        logwriter.writerow(dict(k for k in logdata))


def check_exist_logfile(logfile_name, arguments):
    """Check for an existing logfile (trainlog.csv) and back it up.

    If the logfile exists and the current epoch is still 0, the logfile is
    moved to a backup copy (suffixed with ``~``) and will be replaced.

    Parameters
    ----------

    logfile_name : str
        The logfile name, which is a join between the output folder and
        "trainlog.csv"

    arguments : dict
        start and end epochs
    """
    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
        backup = logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(logfile_name, backup)


def create_logfile_fields(valid_loader, extra_valid_loaders, device):
    """Create the logfile fields that will appear in trainlog.csv.

    Parameters
    ----------

    valid_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model and enable automatic checkpointing.
        If set to ``None``, then do not validate.

    extra_valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model, however it **does not affect**
        automatic checkpointing. If set to ``None``, or empty, then nothing
        else is logged. Otherwise, an extra column with the loss of every
        dataset in this list is kept in the final training log.

    device : :py:class:`torch.device`
        device to use


    Returns
    -------

    logfile_fields : tuple
        The fields that will appear in trainlog.csv
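
    Example (illustrative; the exact resource columns come from
    :py:meth:`ResourceMonitor.monitored_keys` and depend on the device)::

        fields = create_logfile_fields(valid_loader, [], torch.device("cpu"))
        # ("epoch", "total_time", "eta", "loss", "learning_rate",
        #  "validation_loss", <resource monitor columns...>)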
165 """
166 logfile_fields = (
167 "epoch",
168 "total_time",
169 "eta",
170 "loss",
171 "learning_rate",
172 )
173 if valid_loader is not None:
174 logfile_fields += ("validation_loss",)
175 if extra_valid_loaders:
176 logfile_fields += ("extra_validation_losses",)
177 logfile_fields += tuple(
178 ResourceMonitor.monitored_keys(device.type == "cuda")
179 )
180 return logfile_fields


def train_epoch(loader, model, optimizer, device, criterion, batch_chunk_count):
    """Trains the model for a single epoch (through all batches).

    Parameters
    ----------

    loader : :py:class:`torch.utils.data.DataLoader`
        To be used to train the model

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    optimizer : :py:mod:`torch.optim`
        optimizer to use

    device : :py:class:`torch.device`
        device to use

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    batch_chunk_count : int
        If this number is different from 1, then each batch will be divided
        into this number of chunks. Gradients will be accumulated to perform
        each mini-batch. This is particularly interesting when one has
        limited RAM on the GPU, but would like to keep training with larger
        batches. The trade-off is longer processing times. To better
        understand gradient accumulation, read
        https://stackoverflow.com/questions/62067400/understanding-accumulated-gradients-in-pytorch.
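
        For example, an effective batch size of 16 on a GPU that only fits
        4 samples at a time can be emulated by setting the data loader's
        batch size to 4 and ``batch_chunk_count`` to 4: each chunk's loss
        is divided by 4 before ``backward()`` and the optimizer steps once
        every 4 chunks, so the accumulated gradient approximates that of a
        single 16-sample batch.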


    Returns
    -------

    loss : float
        A floating-point value corresponding to the weighted average of this
        epoch's loss
    """

    losses_in_epoch = []
    samples_in_epoch = []
    losses_in_batch = []
    samples_in_batch = []

    # progress bar only on interactive jobs
    for idx, samples in enumerate(
        tqdm(loader, desc="train", leave=False, disable=None)
    ):
        images = samples[1].to(
            device=device, non_blocking=torch.cuda.is_available()
        )
        ground_truths = samples[2].to(
            device=device, non_blocking=torch.cuda.is_available()
        )
        masks = (
            torch.ones_like(ground_truths)
            if len(samples) < 4
            else samples[3].to(
                device=device, non_blocking=torch.cuda.is_available()
            )
        )

        # Forward pass on the network
        outputs = model(images)
        loss = criterion(outputs, ground_truths, masks)

        losses_in_batch.append(loss.item())
        samples_in_batch.append(len(samples))

        # Normalize loss to account for batch accumulation
        loss = loss / batch_chunk_count

        # Accumulate gradients - does not update weights just yet...
        loss.backward()

        # Weight update on the network
        if ((idx + 1) % batch_chunk_count == 0) or (idx + 1 == len(loader)):
            # Advances optimizer to the "next" state and applies weight
            # update over the whole model
            optimizer.step()

            # Zeroes gradients for the next batch
            optimizer.zero_grad()

            # Normalize loss for current batch
            batch_loss = numpy.average(
                losses_in_batch, weights=samples_in_batch
            )
            losses_in_epoch.append(batch_loss.item())
            samples_in_epoch.append(len(samples))

            losses_in_batch.clear()
            samples_in_batch.clear()
            logger.debug(f"batch loss: {batch_loss.item()}")

    return numpy.average(losses_in_epoch, weights=samples_in_epoch)


def validate_epoch(loader, model, device, criterion, pbar_desc):
    """Processes input samples and returns the loss (scalar).

    Parameters
    ----------

    loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    device : :py:class:`torch.device`
        device to use

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    pbar_desc : str
        A string for the progress bar descriptor


    Returns
    -------

    loss : float
        A floating-point value corresponding to the weighted average of this
        epoch's loss
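
    Example (a minimal sketch, mirroring the call made by :py:func:`run`)::

        valid_loss = validate_epoch(valid_loader, model, device, criterion, "valid")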
309 """
311 batch_losses = []
312 samples_in_batch = []
314 with torch.no_grad(), torch_evaluation(model):
315 for samples in tqdm(loader, desc=pbar_desc, leave=False, disable=None):
316 images = samples[1].to(
317 device=device,
318 non_blocking=torch.cuda.is_available(),
319 )
320 ground_truths = samples[2].to(
321 device=device,
322 non_blocking=torch.cuda.is_available(),
323 )
324 masks = (
325 torch.ones_like(ground_truths)
326 if len(samples) < 4
327 else samples[3].to(
328 device=device,
329 non_blocking=torch.cuda.is_available(),
330 )
331 )
333 # data forwarding on the existing network
334 outputs = model(images)
335 loss = criterion(outputs, ground_truths, masks)
337 batch_losses.append(loss.item())
338 samples_in_batch.append(len(samples))
340 return numpy.average(batch_losses, weights=samples_in_batch)


def checkpointer_process(
    checkpointer,
    checkpoint_period,
    valid_loss,
    lowest_validation_loss,
    arguments,
    epoch,
    max_epoch,
):
    """Process the checkpointer, save the final model and keep track of the
    best model.

    Parameters
    ----------

    checkpointer : :py:class:`deepdraw.utils.checkpointer.Checkpointer`
        checkpointer implementation

    checkpoint_period : int
        save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then
        do not save intermediary checkpoints

    valid_loss : float
        Current epoch validation loss

    lowest_validation_loss : float
        Keeps track of the best (lowest) validation loss

    arguments : dict
        start and end epochs

    epoch : int
        Current epoch

    max_epoch : int
        End epoch


    Returns
    -------

    lowest_validation_loss : float
        The lowest validation loss currently observed
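
    Example (illustrative of the policy implemented below; ``lowest`` stands
    for the running lowest validation loss)::

        # with checkpoint_period=2, "model_periodic_save" is written at
        # epochs 2, 4, 6, ...; "model_lowest_valid_loss" whenever the
        # validation loss improves; "model_final_epoch" once epoch reaches
        # max_epoch
        lowest = checkpointer_process(
            checkpointer, 2, valid_loss, lowest, arguments, epoch, max_epoch
        )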
382 """
383 if checkpoint_period and (epoch % checkpoint_period == 0):
384 checkpointer.save("model_periodic_save", **arguments)
386 if valid_loss is not None and valid_loss < lowest_validation_loss:
387 lowest_validation_loss = valid_loss
388 logger.info(
389 f"Found new low on validation set:" f" {lowest_validation_loss:.6f}"
390 )
391 checkpointer.save("model_lowest_valid_loss", **arguments)
393 if epoch >= max_epoch:
394 checkpointer.save("model_final_epoch", **arguments)
396 return lowest_validation_loss


def write_log_info(
    epoch,
    current_time,
    eta_seconds,
    loss,
    valid_loss,
    extra_valid_losses,
    optimizer,
    logwriter,
    logfile,
    resource_data,
):
    """Write log info in trainlog.csv.

    Parameters
    ----------

    epoch : int
        Current epoch

    current_time : float
        Current training time

    eta_seconds : float
        Estimated time-of-arrival taking into consideration previous epoch
        performance

    loss : float
        Current epoch's training loss

    valid_loss : :py:class:`float`, None
        Current epoch's validation loss

    extra_valid_losses : :py:class:`list` of :py:class:`float`
        Validation losses from other validation datasets being currently
        tracked

    optimizer : :py:mod:`torch.optim`
        optimizer in use

    logwriter : csv.DictWriter
        Dictionary writer used to write rows to trainlog.csv

    logfile : io.TextIOWrapper
        The file object backing ``logwriter``; flushed after each write

    resource_data : tuple
        Monitored resources at the machine (CPU and GPU)
    """

    logdata = (
        ("epoch", f"{epoch}"),
        (
            "total_time",
            f"{datetime.timedelta(seconds=int(current_time))}",
        ),
        ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
        ("loss", f"{loss:.6f}"),
        ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
    )

    if valid_loss is not None:
        logdata += (("validation_loss", f"{valid_loss:.6f}"),)

    if extra_valid_losses:
        entry = numpy.array_str(
            numpy.array(extra_valid_losses),
            max_line_width=sys.maxsize,
            precision=6,
        )
        logdata += (("extra_validation_losses", entry),)

    logdata += resource_data

    logwriter.writerow(dict(k for k in logdata))
    logfile.flush()
    tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))


def run(
    model,
    data_loader,
    valid_loader,
    extra_valid_loaders,
    optimizer,
    criterion,
    scheduler,
    checkpointer,
    checkpoint_period,
    device,
    arguments,
    output_folder,
    monitoring_interval,
    batch_chunk_count,
):
    """Fits an FCN model using supervised learning and saves it to disk.

    This method supports periodic checkpointing and the output of a
    CSV-formatted log with the evolution of some figures during training.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    data_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to train the model

    valid_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model and enable automatic checkpointing.
        If ``None``, then do not validate.

    extra_valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model, however it **does not affect**
        automatic checkpointing. If empty, then nothing else is logged.
        Otherwise, an extra column with the loss of every dataset in this
        list is kept in the final training log.

    optimizer : :py:mod:`torch.optim`
        optimizer to use

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    scheduler : :py:mod:`torch.optim`
        learning rate scheduler

    checkpointer : :py:class:`deepdraw.utils.checkpointer.Checkpointer`
        checkpointer implementation

    checkpoint_period : int
        save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then
        do not save intermediary checkpoints

    device : :py:class:`torch.device`
        device to use

    arguments : dict
        start and end epochs

    output_folder : str
        output path

    monitoring_interval : int, float
        interval, in seconds (or fractions), through which we should monitor
        resources during training

    batch_chunk_count : int
        If this number is different from 1, then each batch will be divided
        into this number of chunks. Gradients will be accumulated to perform
        each mini-batch. This is particularly interesting when one has
        limited RAM on the GPU, but would like to keep training with larger
        batches. The trade-off is longer processing times.
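
    Example (a schematic sketch only; the model, loaders, optimizer,
    scheduler and checkpointer instances come from the rest of the
    package)::

        arguments = {"epoch": 0, "max_epoch": 100}
        run(
            model, data_loader, valid_loader, [], optimizer, criterion,
            scheduler, checkpointer, checkpoint_period=10,
            device=torch.device("cuda:0"), arguments=arguments,
            output_folder="results", monitoring_interval=5.0,
            batch_chunk_count=1,
        )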
550 """
552 start_epoch = arguments["epoch"]
553 max_epoch = arguments["max_epoch"]
555 check_gpu(device)
557 os.makedirs(output_folder, exist_ok=True)
559 # Save model summary
560 r, n = save_model_summary(output_folder, model)
562 # write static information to a CSV file
563 static_logfile_name = os.path.join(output_folder, "constants.csv")
565 static_information_to_csv(static_logfile_name, device, n)
567 # Log continous information to (another) file
568 logfile_name = os.path.join(output_folder, "trainlog.csv")
570 check_exist_logfile(logfile_name, arguments)
572 logfile_fields = create_logfile_fields(
573 valid_loader, extra_valid_loaders, device
574 )
576 # the lowest validation loss obtained so far - this value is updated only
577 # if a validation set is available
578 lowest_validation_loss = sys.float_info.max
580 with open(logfile_name, "a+", newline="") as logfile:
581 logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
583 if arguments["epoch"] == 0:
584 logwriter.writeheader()
586 model.train() # set training mode
588 model.to(device) # set/cast parameters to device
589 for state in optimizer.state.values():
590 for k, v in state.items():
591 if isinstance(v, torch.Tensor):
592 state[k] = v.to(device)
594 # Total training timer
595 start_training_time = time.time()
597 for epoch in tqdm(
598 range(start_epoch, max_epoch),
599 desc="epoch",
600 leave=False,
601 disable=None,
602 ):
603 with ResourceMonitor(
604 interval=monitoring_interval,
605 has_gpu=(device.type == "cuda"),
606 main_pid=os.getpid(),
607 logging_level=logging.ERROR,
608 ) as resource_monitor:
609 epoch = epoch + 1
610 arguments["epoch"] = epoch
612 # Epoch time
613 start_epoch_time = time.time()
615 train_loss = train_epoch(
616 data_loader,
617 model,
618 optimizer,
619 device,
620 criterion,
621 batch_chunk_count,
622 )
624 scheduler.step()
626 valid_loss = (
627 validate_epoch(
628 valid_loader, model, device, criterion, "valid"
629 )
630 if valid_loader is not None
631 else None
632 )
634 extra_valid_losses = []
635 for pos, extra_valid_loader in enumerate(extra_valid_loaders):
636 loss = validate_epoch(
637 extra_valid_loader,
638 model,
639 device,
640 criterion,
641 f"xval@{pos+1}",
642 )
643 extra_valid_losses.append(loss)
645 lowest_validation_loss = checkpointer_process(
646 checkpointer,
647 checkpoint_period,
648 valid_loss,
649 lowest_validation_loss,
650 arguments,
651 epoch,
652 max_epoch,
653 )
655 # computes ETA (estimated time-of-arrival; end of training) taking
656 # into consideration previous epoch performance
657 epoch_time = time.time() - start_epoch_time
658 eta_seconds = epoch_time * (max_epoch - epoch)
659 current_time = time.time() - start_training_time
661 write_log_info(
662 epoch,
663 current_time,
664 eta_seconds,
665 train_loss,
666 valid_loss,
667 extra_valid_losses,
668 optimizer,
669 logwriter,
670 logfile,
671 resource_monitor.data,
672 )
674 total_training_time = time.time() - start_training_time
675 logger.info(
676 f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)"
677 )