Coverage for src/deepdraw/utils/checkpointer.py: 96%
48 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5import logging
6import os
8import torch
# Module-level logger; used by Checkpointer to report checkpoint save/load progress.
logger = logging.getLogger(__name__)
class Checkpointer:
    """A simple pytorch checkpointer.

    Parameters
    ----------

    model : torch.nn.Module
        Network model, eventually loaded from a checkpointed file

    optimizer : :py:mod:`torch.optim`, Optional
        Optimizer

    scheduler : :py:mod:`torch.optim`, Optional
        Learning rate scheduler

    path : :py:class:`str`, Optional
        Directory where to save checkpoints.
    """

    def __init__(self, model, optimizer=None, scheduler=None, path="."):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.path = os.path.realpath(path)

    def save(self, name, **kwargs):
        """Saves the current model (and optimizer/scheduler) state to disk.

        Parameters
        ----------

        name : str
            Basename (without the ``.pth`` extension) of the checkpoint file
            to create inside ``self.path``.

        **kwargs
            Extra entries (e.g. epoch number) stored verbatim in the
            checkpoint dictionary and returned by :py:meth:`load`.
        """
        data = {}
        data["model"] = self.model.state_dict()
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        data.update(kwargs)

        # fix: make sure the destination directory exists before writing
        # (previously, saving into a non-existing path raised FileNotFoundError)
        os.makedirs(self.path, exist_ok=True)

        name = f"{name}.pth"
        outf = os.path.join(self.path, name)
        logger.info(f"Saving checkpoint to {outf}")
        torch.save(data, outf)
        # record the basename so has_checkpoint()/last_checkpoint() can
        # locate the most recent checkpoint later
        with open(self._last_checkpoint_filename, "w") as f:
            f.write(name)

    def load(self, f=None):
        """Loads model, optimizer and scheduler from file.

        Parameters
        ==========

        f : :py:class:`str`, Optional
            Name of a file (absolute or relative to ``self.path``), that
            contains the checkpoint data to load into the model, and optionally
            into the optimizer and the scheduler.  If not specified, loads data
            from current path.

        Returns
        =======

        checkpoint : dict
            The extra entries stored via ``save(**kwargs)``, or an empty
            dictionary when no checkpoint could be found.
        """

        if f is None:
            f = self.last_checkpoint()

        if f is None:
            # no checkpoint could be found
            logger.warning("No checkpoint found (and none passed)")
            return {}

        # fix: honour the documented contract that ``f`` may be relative to
        # ``self.path`` (previously relative names were resolved against the
        # CWD only); an existing CWD-relative path still takes precedence
        if not os.path.isabs(f) and not os.path.exists(f):
            f = os.path.join(self.path, f)

        # loads file data into memory
        logger.info(f"Loading checkpoint from {f}...")
        checkpoint = torch.load(f, map_location=torch.device("cpu"))

        # converts model entry to model parameters
        self.model.load_state_dict(checkpoint.pop("model"))

        # fix: tolerate checkpoints written without optimizer/scheduler state
        # (the original unconditionally popped and raised KeyError)
        if self.optimizer is not None:
            state = checkpoint.pop("optimizer", None)
            if state is not None:
                self.optimizer.load_state_dict(state)
        if self.scheduler is not None:
            state = checkpoint.pop("scheduler", None)
            if state is not None:
                self.scheduler.load_state_dict(state)

        return checkpoint

    @property
    def _last_checkpoint_filename(self):
        # sentinel file recording the basename of the last saved checkpoint
        return os.path.join(self.path, "last_checkpoint")

    def has_checkpoint(self):
        """Tells if a previously saved checkpoint exists at ``self.path``."""
        return os.path.exists(self._last_checkpoint_filename)

    def last_checkpoint(self):
        """Returns the full path of the last saved checkpoint, or ``None``."""
        if self.has_checkpoint():
            with open(self._last_checkpoint_filename) as fobj:
                return os.path.join(self.path, fobj.read().strip())
        return None