#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import sys
import csv
import time
import shutil
import datetime
import contextlib

import numpy
import torch
from tqdm import tqdm

from ..utils.measure import SmoothedValue
from ..utils.summary import summary
from ..utils.resources import (
    ResourceMonitor,
    cpu_constants,
    gpu_constants,
)

import logging

logger = logging.getLogger(__name__)


@contextlib.contextmanager
def torch_evaluation(model):
    """Context manager to turn ON/OFF model evaluation

    This context manager will turn evaluation mode ON on entry and turn it OFF
    when exiting the ``with`` statement block.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network


    Yields
    ------

    model : :py:class:`torch.nn.Module`
        Network

    """

    model.eval()
    yield model
    model.train()
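
# A minimal usage sketch (hypothetical toy model): evaluation mode is only
# active inside the ``with`` block, and training mode is restored on exit:
#
#     model = torch.nn.Linear(10, 2)
#     with torch_evaluation(model) as m:
#         assert not m.training  # evaluation mode is active here
#     assert model.training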


def check_gpu(device):
    """
    Checks the device type and the availability of a GPU.

    Parameters
    ----------

    device : :py:class:`torch.device`
        device to use

    """
    if device.type == "cuda":
        # asserts we do have a GPU
        assert bool(
            gpu_constants()
        ), f"Device set to '{device}', but nvidia-smi is not installed"


def save_model_summary(output_folder, model):
    """
    Saves a summary of the model in a text file.

    Parameters
    ----------

    output_folder : str
        output path

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    Returns
    -------

    r : str
        The model summary in a text format.

    n : int
        The number of parameters of the model.

    """
    summary_path = os.path.join(output_folder, "model_summary.txt")
    logger.info(f"Saving model summary at {summary_path}...")
    with open(summary_path, "wt") as f:
        r, n = summary(model)
        logger.info(f"Model has {n} parameters...")
        f.write(r)
    return r, n
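
# Usage sketch (hypothetical toy model): writes ``model_summary.txt`` under
# the given folder and returns the summary text plus the parameter count:
#
#     text, n_params = save_model_summary("results", torch.nn.Linear(4, 2))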


def static_information_to_csv(static_logfile_name, device, n):
    """
    Saves the static information in a CSV file.

    Parameters
    ----------

    static_logfile_name : str
        The static file name, which is a join between the output folder and
        "constants.csv"

    device : :py:class:`torch.device`
        device to use

    n : int
        The number of parameters of the model

    """
    if os.path.exists(static_logfile_name):
        backup = static_logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(static_logfile_name, backup)
    with open(static_logfile_name, "w", newline="") as f:
        logdata = cpu_constants()
        if device.type == "cuda":
            logdata += gpu_constants()
        logdata += (("model_size", n),)
        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
        logwriter.writeheader()
        logwriter.writerow(dict(logdata))
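
# Sketch of the tuple-of-pairs -> csv.DictWriter pattern used above
# (hypothetical field names; the real ones come from cpu_constants() and
# gpu_constants()):
#
#     import csv, io
#     logdata = (("cpu_count", 8), ("model_size", 1000000))
#     buf = io.StringIO()
#     w = csv.DictWriter(buf, fieldnames=[k for k, _ in logdata])
#     w.writeheader()
#     w.writerow(dict(logdata))
#     # buf.getvalue() == "cpu_count,model_size\r\n8,1000000\r\n"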


def check_exist_logfile(logfile_name, arguments):
    """
    Checks for the existence of a logfile (trainlog.csv).

    If the logfile exists and training starts from scratch (the start epoch
    is 0), the existing logfile is moved to a backup and will be replaced.

    Parameters
    ----------

    logfile_name : str
        The logfile name, which is a join between the output folder and
        "trainlog.csv"

    arguments : dict
        start and end epochs

    """
    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
        backup = logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(logfile_name, backup)


def create_logfile_fields(valid_loader, extra_valid_loaders, device):
    """
    Creates the fields that will appear in the logfile (trainlog.csv).

    Parameters
    ----------

    valid_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model and enable automatic checkpointing.
        If set to ``None``, then do not validate it.

    extra_valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model, however **does not affect** automatic
        checkpointing.  If set to ``None``, or empty, then does not log anything
        else.  Otherwise, an extra column with the loss of every dataset in
        this list is kept on the final training log.

    device : :py:class:`torch.device`
        device to use

    Returns
    -------

    logfile_fields : tuple
        The fields that will appear in trainlog.csv

    """
    logfile_fields = (
        "epoch",
        "total_time",
        "eta",
        "loss",
        "learning_rate",
    )
    if valid_loader is not None:
        logfile_fields += ("validation_loss",)
    if extra_valid_loaders:
        logfile_fields += ("extra_validation_losses",)
    logfile_fields += tuple(
        ResourceMonitor.monitored_keys(device.type == "cuda")
    )
    return logfile_fields
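
# Illustrative result (the resource column names are whatever
# ResourceMonitor.monitored_keys() returns for the given device type):
#
#     create_logfile_fields(valid_loader, [], torch.device("cpu"))
#     # -> ("epoch", "total_time", "eta", "loss", "learning_rate",
#     #     "validation_loss", <CPU resource keys...>)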


def train_epoch(loader, model, optimizer, device, criterion, batch_chunk_count):
    """Trains the model for a single epoch (through all batches)

    Parameters
    ----------

    loader : :py:class:`torch.utils.data.DataLoader`
        To be used to train the model

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    optimizer : :py:mod:`torch.optim`

    device : :py:class:`torch.device`
        device to use

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    batch_chunk_count : int
        If this number is different from 1, then each batch will be divided in
        this number of chunks.  Gradients will be accumulated to perform each
        mini-batch.  This is particularly interesting when one has limited RAM
        on the GPU, but would like to keep training with larger batches.  The
        trade-off is longer processing times in this case.  To better
        understand gradient accumulation, read
        https://stackoverflow.com/questions/62067400/understanding-accumulated-gradients-in-pytorch.


    Returns
    -------

    loss : float
        A floating-point value corresponding to the weighted average of this
        epoch's loss

    """

    losses_in_epoch = []
    samples_in_epoch = []
    losses_in_batch = []
    samples_in_batch = []

    # progress bar only on interactive jobs
    for idx, samples in enumerate(
        tqdm(loader, desc="train", leave=False, disable=None)
    ):

        images = samples[1].to(
            device=device, non_blocking=torch.cuda.is_available()
        )
        labels = samples[2].to(
            device=device, non_blocking=torch.cuda.is_available()
        )

        # Increase label dimension if too low
        # Allows single and multiclass usage
        if labels.ndim == 1:
            labels = torch.reshape(labels, (labels.shape[0], 1))

        # Forward pass on the network
        outputs = model(images)

        loss = criterion(outputs, labels.double())

        losses_in_batch.append(loss.item())
        # weight by the number of samples in this chunk (first dimension of
        # the image tensor), not by the length of the sample tuple
        samples_in_batch.append(images.shape[0])

        # Normalize loss to account for batch accumulation
        loss = loss / batch_chunk_count

        # Accumulate gradients - does not update weights just yet...
        loss.backward()

        # Weight update on the network
        if ((idx + 1) % batch_chunk_count == 0) or (idx + 1 == len(loader)):
            # Advances optimizer to the "next" state and applies weight update
            # over the whole model
            optimizer.step()

            # Zeroes gradients for the next batch
            optimizer.zero_grad()

            # Normalize loss for current batch
            batch_loss = numpy.average(
                losses_in_batch, weights=samples_in_batch
            )
            losses_in_epoch.append(batch_loss.item())
            # total number of samples in the (accumulated) batch
            samples_in_epoch.append(sum(samples_in_batch))

            losses_in_batch.clear()
            samples_in_batch.clear()
            logger.debug(f"batch loss: {batch_loss.item()}")

    return numpy.average(losses_in_epoch, weights=samples_in_epoch)
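
# Minimal gradient-accumulation sketch (hypothetical one-layer model),
# mirroring the chunking logic above with batch_chunk_count == 2: weights
# are only updated every second chunk, after gradients have accumulated.
#
#     model = torch.nn.Linear(4, 1)
#     opt = torch.optim.SGD(model.parameters(), lr=0.1)
#     chunks = [torch.randn(8, 4) for _ in range(4)]
#     for idx, x in enumerate(chunks):
#         loss = model(x).pow(2).mean() / 2  # normalize by chunk count
#         loss.backward()                    # accumulate gradients
#         if (idx + 1) % 2 == 0:
#             opt.step()
#             opt.zero_grad()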


def validate_epoch(loader, model, device, criterion, pbar_desc):
    """
    Processes input samples and returns the loss (scalar)


    Parameters
    ----------

    loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    device : :py:class:`torch.device`
        device to use

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    pbar_desc : str
        A string for the progress bar descriptor


    Returns
    -------

    loss : float
        A floating-point value corresponding to the weighted average of this
        epoch's loss

    """

    batch_losses = []
    samples_in_batch = []

    with torch.no_grad(), torch_evaluation(model):

        for samples in tqdm(loader, desc=pbar_desc, leave=False, disable=None):
            images = samples[1].to(
                device=device,
                non_blocking=torch.cuda.is_available(),
            )
            labels = samples[2].to(
                device=device,
                non_blocking=torch.cuda.is_available(),
            )

            # Increase label dimension if too low
            # Allows single and multiclass usage
            if labels.ndim == 1:
                labels = torch.reshape(labels, (labels.shape[0], 1))

            # data forwarding on the existing network
            outputs = model(images)
            loss = criterion(outputs, labels.double())

            batch_losses.append(loss.item())
            # weight by the number of samples in this batch
            samples_in_batch.append(images.shape[0])

    return numpy.average(batch_losses, weights=samples_in_batch)
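
# Usage sketch (hypothetical loader and criterion), mirroring the calls
# made from run() below:
#
#     valid_loss = validate_epoch(
#         valid_loader, model, torch.device("cpu"),
#         torch.nn.BCEWithLogitsLoss(), "valid",
#     )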


def checkpointer_process(
    checkpointer,
    checkpoint_period,
    valid_loss,
    lowest_validation_loss,
    arguments,
    epoch,
    max_epoch,
):
    """
    Processes the checkpointer, saves the final model and keeps track of the
    best model.

    Parameters
    ----------

    checkpointer : :py:class:`bob.med.tb.utils.checkpointer.Checkpointer`
        checkpointer implementation

    checkpoint_period : int
        save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
        not save intermediary checkpoints

    valid_loss : float
        Current epoch validation loss

    lowest_validation_loss : float
        Keeps track of the best (lowest) validation loss

    arguments : dict
        start and end epochs

    epoch : int
        Current epoch

    max_epoch : int
        Maximum number of training epochs

    Returns
    -------

    lowest_validation_loss : float
        The lowest validation loss currently observed

    """
    if checkpoint_period and (epoch % checkpoint_period == 0):
        checkpointer.save("model_periodic_save", **arguments)

    if valid_loss is not None and valid_loss < lowest_validation_loss:
        lowest_validation_loss = valid_loss
        logger.info(
            f"Found new low on validation set: {lowest_validation_loss:.6f}"
        )
        checkpointer.save("model_lowest_valid_loss", **arguments)

    if epoch >= max_epoch:
        checkpointer.save("model_final_epoch", **arguments)

    return lowest_validation_loss
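
# Illustrative checkpointing schedule produced by the rules above, with
# checkpoint_period == 2 and max_epoch == 4 (checkpoint names as used in
# the code; "lowest" saves happen whenever the validation loss improves):
#
#     epoch 1: model_lowest_valid_loss (if validation loss improved)
#     epoch 2: model_periodic_save
#     epoch 3: model_lowest_valid_loss (if validation loss improved)
#     epoch 4: model_periodic_save, model_final_epoch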


def write_log_info(
    epoch,
    current_time,
    eta_seconds,
    loss,
    valid_loss,
    extra_valid_losses,
    optimizer,
    logwriter,
    logfile,
    resource_data,
):
    """
    Writes log info into trainlog.csv

    Parameters
    ----------

    epoch : int
        Current epoch

    current_time : float
        Current training time

    eta_seconds : float
        estimated time-of-arrival taking into consideration previous epoch
        performance

    loss : float
        Current epoch's training loss

    valid_loss : :py:class:`float`, None
        Current epoch's validation loss

    extra_valid_losses : :py:class:`list` of :py:class:`float`
        Validation losses from other validation datasets being currently
        tracked

    optimizer : :py:mod:`torch.optim`

    logwriter : csv.DictWriter
        Dictionary writer that allows writing to trainlog.csv

    logfile : io.TextIOWrapper
        file handle of the training log

    resource_data : tuple
        Monitored resources at the machine (CPU and GPU)

    """

    logdata = (
        ("epoch", f"{epoch}"),
        (
            "total_time",
            f"{datetime.timedelta(seconds=int(current_time))}",
        ),
        ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
        ("loss", f"{loss:.6f}"),
        ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
    )

    if valid_loss is not None:
        logdata += (("validation_loss", f"{valid_loss:.6f}"),)

    if extra_valid_losses:
        entry = numpy.array_str(
            numpy.array(extra_valid_losses),
            max_line_width=sys.maxsize,
            precision=6,
        )
        logdata += (("extra_validation_losses", entry),)

    logdata += resource_data

    logwriter.writerow(dict(logdata))
    logfile.flush()
    tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))
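
# Example console line emitted by the tqdm.write() call above (the first
# four logged columns, illustrative values):
#
#     epoch: 10|total_time: 0:03:12|eta: 0:09:36|loss: 0.123456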


def run(
    model,
    data_loader,
    valid_loader,
    extra_valid_loaders,
    optimizer,
    criterion,
    checkpointer,
    checkpoint_period,
    device,
    arguments,
    output_folder,
    monitoring_interval,
    batch_chunk_count,
    criterion_valid,
):
    """
    Fits a CNN model using supervised learning and saves it to disk.

    This method supports periodic checkpointing and the output of a
    CSV-formatted log with the evolution of some figures during training.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    data_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to train the model

    valid_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model and enable automatic checkpointing.
        If ``None``, then do not validate it.

    extra_valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model, however **does not affect** automatic
        checkpointing.  If empty, then does not log anything else.  Otherwise,
        an extra column with the loss of every dataset in this list is kept on
        the final training log.

    optimizer : :py:mod:`torch.optim`

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    checkpointer : :py:class:`bob.med.tb.utils.checkpointer.Checkpointer`
        checkpointer implementation

    checkpoint_period : int
        save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
        not save intermediary checkpoints

    device : :py:class:`torch.device`
        device to use

    arguments : dict
        start and end epochs

    output_folder : str
        output path

    monitoring_interval : int, float
        interval, in seconds (or fractions), through which we should monitor
        resources during training.

    batch_chunk_count : int
        If this number is different from 1, then each batch will be divided in
        this number of chunks.  Gradients will be accumulated to perform each
        mini-batch.  This is particularly interesting when one has limited RAM
        on the GPU, but would like to keep training with larger batches.  The
        trade-off is longer processing times in this case.

    criterion_valid : :py:class:`torch.nn.modules.loss._Loss`
        specific loss function for the validation set

    """

    start_epoch = arguments["epoch"]
    max_epoch = arguments["max_epoch"]

    check_gpu(device)

    os.makedirs(output_folder, exist_ok=True)

    # Save model summary
    r, n = save_model_summary(output_folder, model)

    # write static information to a CSV file
    static_logfile_name = os.path.join(output_folder, "constants.csv")

    static_information_to_csv(static_logfile_name, device, n)

    # Log continuous information to (another) file
    logfile_name = os.path.join(output_folder, "trainlog.csv")

    check_exist_logfile(logfile_name, arguments)

    logfile_fields = create_logfile_fields(
        valid_loader, extra_valid_loaders, device
    )

    # the lowest validation loss obtained so far - this value is updated only
    # if a validation set is available
    lowest_validation_loss = sys.float_info.max

    # set a specific validation criterion if the user has set one
    criterion_valid = criterion_valid or criterion

    with open(logfile_name, "a+", newline="") as logfile:
        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)

        if arguments["epoch"] == 0:
            logwriter.writeheader()

        model.train()  # set training mode

        model.to(device)  # set/cast parameters to device
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

        # Total training timer
        start_training_time = time.time()

        for epoch in tqdm(
            range(start_epoch, max_epoch),
            desc="epoch",
            leave=False,
            disable=None,
        ):

            with ResourceMonitor(
                interval=monitoring_interval,
                has_gpu=(device.type == "cuda"),
                main_pid=os.getpid(),
                logging_level=logging.ERROR,
            ) as resource_monitor:
                epoch = epoch + 1
                arguments["epoch"] = epoch

                # Epoch time
                start_epoch_time = time.time()

                train_loss = train_epoch(
                    data_loader,
                    model,
                    optimizer,
                    device,
                    criterion,
                    batch_chunk_count,
                )

                valid_loss = (
                    validate_epoch(
                        valid_loader, model, device, criterion_valid, "valid"
                    )
                    if valid_loader is not None
                    else None
                )

                extra_valid_losses = []
                for pos, extra_valid_loader in enumerate(extra_valid_loaders):
                    loss = validate_epoch(
                        extra_valid_loader,
                        model,
                        device,
                        criterion_valid,
                        f"xval@{pos+1}",
                    )
                    extra_valid_losses.append(loss)

            lowest_validation_loss = checkpointer_process(
                checkpointer,
                checkpoint_period,
                valid_loss,
                lowest_validation_loss,
                arguments,
                epoch,
                max_epoch,
            )

            # computes ETA (estimated time-of-arrival; end of training) taking
            # into consideration previous epoch performance
            epoch_time = time.time() - start_epoch_time
            eta_seconds = epoch_time * (max_epoch - epoch)
            current_time = time.time() - start_training_time

            write_log_info(
                epoch,
                current_time,
                eta_seconds,
                train_loss,
                valid_loss,
                extra_valid_losses,
                optimizer,
                logwriter,
                logfile,
                resource_monitor.data,
            )

        total_training_time = time.time() - start_training_time
        logger.info(
            f"Total training time: "
            f"{datetime.timedelta(seconds=total_training_time)} "
            f"({(total_training_time/(max_epoch-start_epoch)):.4f}s "
            f"on average per epoch)"
        )
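
# Minimal call sketch (all objects hypothetical; the checkpointer comes from
# bob.med.tb.utils.checkpointer in this code base):
#
#     run(
#         model=model,
#         data_loader=train_loader,
#         valid_loader=valid_loader,
#         extra_valid_loaders=[],
#         optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
#         criterion=torch.nn.BCEWithLogitsLoss(),
#         checkpointer=checkpointer,
#         checkpoint_period=10,
#         device=torch.device("cpu"),
#         arguments={"epoch": 0, "max_epoch": 100},
#         output_folder="results",
#         monitoring_interval=5.0,
#         batch_chunk_count=1,
#         criterion_valid=None,
#     )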