Coverage for bob/med/tb/utils/checkpointer.py: 90% (49 statements)

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os

import torch

import logging

logger = logging.getLogger(__name__)


class Checkpointer:
    """A simple PyTorch checkpointer

    Parameters
    ----------

    model : torch.nn.Module
        Network model, possibly loaded from a checkpoint file

    optimizer : :py:mod:`torch.optim`, Optional
        Optimizer

    scheduler : :py:mod:`torch.optim.lr_scheduler`, Optional
        Learning rate scheduler

    path : :py:class:`str`, Optional
        Directory where to save checkpoints.

    """

    def __init__(self, model, optimizer=None, scheduler=None, path="."):

        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.path = os.path.realpath(path)

    def save(self, name, **kwargs):
        """Saves the current state of model, optimizer and scheduler

        Parameters
        ----------

        name : str
            Base name for the checkpoint file (saved as ``<name>.pth``
            under ``self.path``)

        **kwargs
            Extra entries to store in the checkpoint (e.g. the current
            epoch)

        """

        data = {}
        data["model"] = self.model.state_dict()
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        data.update(kwargs)

        name = f"{name}.pth"
        outf = os.path.join(self.path, name)
        logger.info(f"Saving checkpoint to {outf}")
        torch.save(data, outf)

        # records the name of this checkpoint, so it can be found later
        # by ``last_checkpoint()`` without an explicit filename
        with open(self._last_checkpoint_filename, "w") as f:
            f.write(name)

    def load(self, f=None, strict=True):
        """Loads model, optimizer and scheduler from file

        Parameters
        ----------

        f : :py:class:`str`, Optional
            Name of a file (absolute or relative to ``self.path``) that
            contains the checkpoint data to load into the model, and
            optionally into the optimizer and the scheduler.  If not
            specified, loads the last checkpoint recorded under
            ``self.path``, if one exists.

        strict : :py:class:`bool`, Optional
            If True (the default), the model state must match the checkpoint
            exactly, and the optimizer and scheduler states are restored as
            well.  If False, loading is not strict and only the model is
            loaded.

        Returns
        -------

        checkpoint : dict
            The checkpoint entries that were not consumed by the model,
            optimizer or scheduler (e.g. extra values stored by
            :py:meth:`save`), or an empty dictionary if no checkpoint was
            found.

        """

        if f is None:
            f = self.last_checkpoint()

        if f is None:
            # no checkpoint could be found
            logger.warning("No checkpoint found (and none passed)")
            return {}

        # loads file data into memory
        logger.info(f"Loading checkpoint from {f}...")
        checkpoint = torch.load(f, map_location=torch.device("cpu"))

        # converts model entry to model parameters
        self.model.load_state_dict(checkpoint.pop("model"), strict=strict)

        if strict:
            if self.optimizer is not None:
                self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
            if self.scheduler is not None:
                self.scheduler.load_state_dict(checkpoint.pop("scheduler"))

        return checkpoint

    @property
    def _last_checkpoint_filename(self):
        # bookkeeping file recording the name of the latest checkpoint
        return os.path.join(self.path, "last_checkpoint")

    def has_checkpoint(self):
        """Tells whether a previously saved checkpoint exists"""
        return os.path.exists(self._last_checkpoint_filename)

    def last_checkpoint(self):
        """Returns the path to the last saved checkpoint, or ``None``"""
        if self.has_checkpoint():
            with open(self._last_checkpoint_filename, "r") as fobj:
                return os.path.join(self.path, fobj.read().strip())
        return None
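
For illustration, a minimal usage sketch of a save/resume cycle follows. The tiny network, optimizer and epoch bookkeeping are assumptions made for the example; only the Checkpointer calls come from the module above.

import os
import torch

from bob.med.tb.utils.checkpointer import Checkpointer

# hypothetical model and optimizer, purely for demonstration
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

os.makedirs("ckpts", exist_ok=True)
checkpointer = Checkpointer(model, optimizer=optimizer, path="ckpts")

# resume from the last checkpoint, if a previous run left one behind;
# load() returns the unconsumed checkpoint entries ({} if none found)
extra = checkpointer.load()
start_epoch = extra.get("epoch", -1) + 1

for epoch in range(start_epoch, 5):
    # ... forward/backward/step for this epoch would go here ...
    checkpointer.save(f"model_epoch_{epoch}", epoch=epoch)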
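
A second sketch shows a non-strict load, e.g. to warm-start a network from previously saved weights without restoring optimizer state. The checkpoint filename is hypothetical (it matches the one produced by the sketch above).

import torch

from bob.med.tb.utils.checkpointer import Checkpointer

model = torch.nn.Linear(10, 2)  # hypothetical stand-in network
checkpointer = Checkpointer(model)  # no optimizer or scheduler attached

# with strict=False, only the "model" entry is consumed; missing or
# unexpected keys in the state dict are tolerated, and any optimizer or
# scheduler states in the checkpoint are left untouched
leftover = checkpointer.load("ckpts/model_epoch_4.pth", strict=False)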