#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import os

import torch

logger = logging.getLogger(__name__)


class Checkpointer:
    """A simple PyTorch checkpointer

    Parameters
    ----------

    model : torch.nn.Module
        Network model, possibly loaded from a checkpoint file

    optimizer : :py:mod:`torch.optim`, Optional
        Optimizer

    scheduler : :py:mod:`torch.optim`, Optional
        Learning rate scheduler

    path : :py:class:`str`, Optional
        Directory where to save checkpoints.

    """

    def __init__(self, model, optimizer=None, scheduler=None, path="."):

        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.path = os.path.realpath(path)

    def save(self, name, **kwargs):
        """Saves the current state of model, optimizer and scheduler

        Any extra keyword arguments are stored alongside the state
        dictionaries in the checkpoint file, and the checkpoint file name is
        recorded in ``last_checkpoint`` so :py:meth:`load` can resume from it.
        """

        data = {}
        data["model"] = self.model.state_dict()
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        data.update(kwargs)

        name = f"{name}.pth"
        outf = os.path.join(self.path, name)
        logger.info(f"Saving checkpoint to {outf}")
        torch.save(data, outf)
        with open(self._last_checkpoint_filename, "w") as f:
            f.write(name)

    def load(self, f=None):
        """Loads model, optimizer and scheduler from file

        Parameters
        ----------

        f : :py:class:`str`, Optional
            Name of a file (absolute or relative to ``self.path``) that
            contains the checkpoint data to load into the model, and
            optionally into the optimizer and the scheduler.  If not
            specified, loads data from the last checkpoint recorded under
            ``self.path``, if one exists.


        Returns
        -------

        dict
            The remaining checkpoint data, after the model (and, when set,
            the optimizer and scheduler) states have been restored.  An
            empty dictionary is returned if no checkpoint could be found.

        """

        if f is None:
            f = self.last_checkpoint()

        if f is None:
            # no checkpoint could be found
            logger.warning("No checkpoint found (and none passed)")
            return {}

        # loads file data into memory
        logger.info(f"Loading checkpoint from {f}...")
        checkpoint = torch.load(f, map_location=torch.device("cpu"))

        # converts model entry to model parameters
        self.model.load_state_dict(checkpoint.pop("model"))

        if self.optimizer is not None:
            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
        if self.scheduler is not None:
            self.scheduler.load_state_dict(checkpoint.pop("scheduler"))

        return checkpoint

    @property
    def _last_checkpoint_filename(self):
        return os.path.join(self.path, "last_checkpoint")

    def has_checkpoint(self):
        return os.path.exists(self._last_checkpoint_filename)

    def last_checkpoint(self):
        if self.has_checkpoint():
            with open(self._last_checkpoint_filename, "r") as fobj:
                return os.path.join(self.path, fobj.read().strip())
        return None
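

# A minimal usage sketch (illustration only, not part of the original module):
# the tiny linear model, the SGD settings and the temporary directory below
# are assumptions made for the demonstration, not values used by this package.
if __name__ == "__main__":
    import tempfile

    import torch.nn as nn

    with tempfile.TemporaryDirectory() as tmp:
        model = nn.Linear(4, 2)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        ckpt = Checkpointer(model, optimizer=optimizer, path=tmp)

        # extra keyword arguments are persisted alongside the state dicts
        ckpt.save("model_epoch_1", epoch=1)

        # with no argument, load() resumes from the "last_checkpoint" marker
        # and returns whatever data is left after restoring the states
        residual = ckpt.load()
        assert residual["epoch"] == 1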