1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import logging
5import os
6
7import torch
8
# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(__name__)
10
11
12class Checkpointer:
13 """A simple pytorch checkpointer
14
15 Parameters
16 ----------
17
18 model : torch.nn.Module
19 Network model, eventually loaded from a checkpointed file
20
21 optimizer : :py:mod:`torch.optim`, Optional
22 Optimizer
23
24 scheduler : :py:mod:`torch.optim`, Optional
25 Learning rate scheduler
26
27 path : :py:class:`str`, Optional
28 Directory where to save checkpoints.
29
30 """
31
32 def __init__(self, model, optimizer=None, scheduler=None, path="."):
33
34 self.model = model
35 self.optimizer = optimizer
36 self.scheduler = scheduler
37 self.path = os.path.realpath(path)
38
39 def save(self, name, **kwargs):
40
41 data = {}
42 data["model"] = self.model.state_dict()
43 if self.optimizer is not None:
44 data["optimizer"] = self.optimizer.state_dict()
45 if self.scheduler is not None:
46 data["scheduler"] = self.scheduler.state_dict()
47 data.update(kwargs)
48
49 name = f"{name}.pth"
50 outf = os.path.join(self.path, name)
51 logger.info(f"Saving checkpoint to {outf}")
52 torch.save(data, outf)
53 with open(self._last_checkpoint_filename, "w") as f:
54 f.write(name)
55
56 def load(self, f=None):
57 """Loads model, optimizer and scheduler from file
58
59
60 Parameters
61 ==========
62
63 f : :py:class:`str`, Optional
64 Name of a file (absolute or relative to ``self.path``), that
65 contains the checkpoint data to load into the model, and optionally
66 into the optimizer and the scheduler. If not specified, loads data
67 from current path.
68
69 """
70
71 if f is None:
72 f = self.last_checkpoint()
73
74 if f is None:
75 # no checkpoint could be found
76 logger.warning("No checkpoint found (and none passed)")
77 return {}
78
79 # loads file data into memory
80 logger.info(f"Loading checkpoint from {f}...")
81 checkpoint = torch.load(f, map_location=torch.device("cpu"))
82
83 # converts model entry to model parameters
84 self.model.load_state_dict(checkpoint.pop("model"))
85
86 if self.optimizer is not None:
87 self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
88 if self.scheduler is not None:
89 self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
90
91 return checkpoint
92
93 @property
94 def _last_checkpoint_filename(self):
95 return os.path.join(self.path, "last_checkpoint")
96
97 def has_checkpoint(self):
98 return os.path.exists(self._last_checkpoint_filename)
99
100 def last_checkpoint(self):
101 if self.has_checkpoint():
102 with open(self._last_checkpoint_filename, "r") as fobj:
103 return os.path.join(self.path, fobj.read().strip())
104 return None