#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import os

import torch

logger = logging.getLogger(__name__)


class Checkpointer:
    """A simple PyTorch checkpointer.

    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network model, eventually loaded from a checkpointed file

    optimizer : :py:class:`torch.optim.Optimizer`, Optional
        Optimizer

    scheduler : :py:mod:`torch.optim.lr_scheduler`, Optional
        Learning rate scheduler

    path : :py:class:`str`, Optional
        Directory where to save checkpoints.

    """

    def __init__(self, model, optimizer=None, scheduler=None, path="."):

        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.path = os.path.realpath(path)

    def save(self, name, **kwargs):
        """Saves model, optimizer and scheduler states to disk.

        ``name`` is the checkpoint basename (a ``.pth`` extension is
        appended); any extra keyword arguments (e.g. the current epoch) are
        stored alongside the state dictionaries.
        """

        data = {}
        data["model"] = self.model.state_dict()
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        data.update(kwargs)

        name = f"{name}.pth"
        outf = os.path.join(self.path, name)
        logger.info(f"Saving checkpoint to {outf}")
        torch.save(data, outf)

        # records the basename of the last saved checkpoint so that load()
        # can find it later without an explicit filename
        with open(self._last_checkpoint_filename, "w") as f:
            f.write(name)

    def load(self, f=None, strict=True):
        """Loads model, optimizer and scheduler states from file.

        Parameters
        ----------

        f : :py:class:`str`, Optional
            Name of a file (absolute or relative to ``self.path``) that
            contains the checkpoint data to load into the model, and
            optionally into the optimizer and the scheduler.  If not
            specified, loads the last checkpoint recorded in ``self.path``.

        strict : :py:class:`bool`, Optional
            If ``True`` (the default), model keys must match exactly and the
            optimizer and scheduler states are restored as well.  If
            ``False``, loading is not strict and only the model is loaded.

        Returns
        -------

        dict
            Any extra entries stored in the checkpoint (e.g. the epoch
            passed to :py:meth:`save`), or an empty dictionary if no
            checkpoint could be found.

        """

        if f is None:
            f = self.last_checkpoint()

        if f is None:
            # no checkpoint could be found
            logger.warning("No checkpoint found (and none passed)")
            return {}

        # resolves relative names with respect to the checkpoint directory
        if not os.path.isabs(f):
            f = os.path.join(self.path, f)

        # loads file data into memory
        logger.info(f"Loading checkpoint from {f}...")
        checkpoint = torch.load(f, map_location=torch.device("cpu"))

        # converts model entry to model parameters
        self.model.load_state_dict(checkpoint.pop("model"), strict=strict)

        if strict:
            if self.optimizer is not None:
                self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
            if self.scheduler is not None:
                self.scheduler.load_state_dict(checkpoint.pop("scheduler"))

        return checkpoint

    @property
    def _last_checkpoint_filename(self):
        return os.path.join(self.path, "last_checkpoint")

    def has_checkpoint(self):
        """Tells whether a ``last_checkpoint`` marker exists in ``self.path``"""
        return os.path.exists(self._last_checkpoint_filename)

    def last_checkpoint(self):
        """Returns the full path of the last saved checkpoint, or ``None``"""
        if self.has_checkpoint():
            with open(self._last_checkpoint_filename, "r") as fobj:
                return os.path.join(self.path, fobj.read().strip())
        return None
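

# Minimal usage sketch (illustrative only): the toy model, the optimizer and
# the "results" directory below are assumptions made for this example and are
# not part of the class above.
if __name__ == "__main__":
    import torch.nn
    import torch.optim

    logging.basicConfig(level=logging.INFO)

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    checkpointer = Checkpointer(model, optimizer=optimizer, path="results")
    os.makedirs(checkpointer.path, exist_ok=True)

    # saves "results/model_epoch_0.pth" together with extra metadata
    checkpointer.save("model_epoch_0", epoch=0)

    # reloads the last saved checkpoint; extra entries are returned as a dict
    extras = checkpointer.load()
    logger.info(f"Restored checkpoint from epoch {extras.get('epoch')}")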