Coverage for src/deepdraw/utils/checkpointer.py: 96% (48 statements)
coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import logging
import os

import torch

logger = logging.getLogger(__name__)


class Checkpointer:
14 """A simple pytorch checkpointer. 

15 

16 Parameters 

17 ---------- 

18 

19 model : torch.nn.Module 

20 Network model, eventually loaded from a checkpointed file 

21 

22 optimizer : :py:mod:`torch.optim`, Optional 

23 Optimizer 

24 

25 scheduler : :py:mod:`torch.optim`, Optional 

26 Learning rate scheduler 

27 

28 path : :py:class:`str`, Optional 

29 Directory where to save checkpoints. 
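
    Examples
    --------

    A minimal, illustrative sketch (assuming ``net`` is any
    :py:class:`torch.nn.Module` and ``ckpt`` is an existing directory):

    .. code-block:: python

       checkpointer = Checkpointer(net, path="ckpt")
       checkpointer.save("epoch-0", epoch=0)  # writes ckpt/epoch-0.pth
       extras = checkpointer.load()           # re-loads net's parameters
       assert extras == {"epoch": 0}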

30 """ 

    def __init__(self, model, optimizer=None, scheduler=None, path="."):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.path = os.path.realpath(path)

    def save(self, name, **kwargs):
        """Saves model (and optimizer/scheduler, if any) state to ``<name>.pth``."""
        data = {}
        data["model"] = self.model.state_dict()
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        # any extra keyword arguments (e.g. an epoch counter) are stored
        # verbatim and handed back to the caller by load()
        data.update(kwargs)

        name = f"{name}.pth"
        outf = os.path.join(self.path, name)
        logger.info(f"Saving checkpoint to {outf}")
        torch.save(data, outf)
        # record this checkpoint as the most recent one, so load() can find
        # it later without an explicit filename
        with open(self._last_checkpoint_filename, "w") as f:
            f.write(name)

    def load(self, f=None):
        """Loads model, optimizer and scheduler from file.

        Parameters
        ----------

        f : :py:class:`str`, Optional
            Name of a file (absolute or relative to ``self.path``) that
            contains the checkpoint data to load into the model, and
            optionally into the optimizer and the scheduler.  If not
            specified, loads the last checkpoint recorded under
            ``self.path``, if any.

        Returns
        -------

        checkpoint : dict
            Any data stored in the checkpoint that was not consumed by the
            model, optimizer or scheduler (e.g. extra keyword arguments
            passed to :py:meth:`save`), or an empty dictionary if no
            checkpoint could be found.
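
        Examples
        --------

        A sketch of the resume pattern this enables (``checkpointer`` built
        as in the class example above):

        .. code-block:: python

           extras = checkpointer.load()    # returns {} on a fresh run
           start = extras.get("epoch", 0)  # extra data stored via save()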

65 """ 

        if f is None:
            f = self.last_checkpoint()

        if f is None:
            # no checkpoint could be found
            logger.warning("No checkpoint found (and none passed)")
            return {}

        # loads the checkpoint contents into memory, mapping all tensors to
        # the CPU so checkpoints saved on a GPU machine also load on
        # CPU-only hosts
        logger.info(f"Loading checkpoint from {f}...")
        checkpoint = torch.load(f, map_location=torch.device("cpu"))

        # restores model parameters from the checkpoint's "model" entry
        self.model.load_state_dict(checkpoint.pop("model"))

        if self.optimizer is not None:
            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
        if self.scheduler is not None:
            self.scheduler.load_state_dict(checkpoint.pop("scheduler"))

        # whatever remains (e.g. extras passed to save()) goes back to the
        # caller
        return checkpoint

    @property
    def _last_checkpoint_filename(self):
        # sidecar file recording the name of the last saved checkpoint
        return os.path.join(self.path, "last_checkpoint")

    def has_checkpoint(self):
        """Tells if there is a saved checkpoint to resume from."""
        return os.path.exists(self._last_checkpoint_filename)

    def last_checkpoint(self):
        """Returns the full path to the last saved checkpoint, or ``None``."""
        if self.has_checkpoint():
            with open(self._last_checkpoint_filename) as fobj:
                return os.path.join(self.path, fobj.read().strip())
        return None
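
As a fuller usage sketch (illustrative only: ``make_model`` and ``make_optimizer`` are hypothetical stand-ins for whatever builds the network and its optimizer), a training loop can resume transparently from the last saved state:

    import os

    from deepdraw.utils.checkpointer import Checkpointer

    model = make_model()               # hypothetical: any torch.nn.Module
    optimizer = make_optimizer(model)  # hypothetical: any torch.optim.Optimizer

    # save() does not create the output directory, so ensure it exists
    os.makedirs("checkpoints", exist_ok=True)
    checkpointer = Checkpointer(model, optimizer=optimizer, path="checkpoints")

    extras = checkpointer.load()       # restores state; returns {} on a fresh run
    start_epoch = extras.get("epoch", 0)

    for epoch in range(start_epoch, 10):
        ...  # one epoch of training goes here
        checkpointer.save(f"model_epoch_{epoch}", epoch=epoch + 1)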