Coverage for /scratch/builds/bob/bob.ip.binseg/miniconda/conda-bld/bob.ip.binseg_1673966692152/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_p/lib/python3.10/site-packages/bob/ip/common/test/test_checkpointer.py: 79%

67 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 15:03 +0000

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4import os 

5import unittest 

6 

7from collections import OrderedDict 

8from tempfile import TemporaryDirectory 

9 

10import torch 

11 

12from ..utils.checkpointer import Checkpointer 

13 

14 

class TestCheckpointer(unittest.TestCase):
    """Tests for saving/loading models through ``Checkpointer``."""

    def create_model(self):
        """Return a small ``Sequential`` model made of two linear layers."""
        return torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 1))

    def create_complex_model(self):
        """Return a nested module together with a hand-built state dict.

        The module nests its first layer as ``block1.layer1`` whereas the
        state dict names it ``layer1``; the remaining keys (``layer2.*`` and
        ``res.layer2.*``) match the module's own parameter names.

        Returns:
            A ``(module, state_dict)`` tuple.
        """
        m = torch.nn.Module()
        m.block1 = torch.nn.Module()
        m.block1.layer1 = torch.nn.Linear(2, 3)
        m.layer2 = torch.nn.Linear(3, 2)
        m.res = torch.nn.Module()
        m.res.layer2 = torch.nn.Linear(3, 2)

        # Tensor shapes follow ``Linear.weight`` (out, in) and ``bias`` (out,)
        state_dict = OrderedDict()
        state_dict["layer1.weight"] = torch.rand(3, 2)
        state_dict["layer1.bias"] = torch.rand(3)
        state_dict["layer2.weight"] = torch.rand(2, 3)
        state_dict["layer2.bias"] = torch.rand(2)
        state_dict["res.layer2.weight"] = torch.rand(2, 3)
        state_dict["res.layer2.bias"] = torch.rand(2)

        return m, state_dict

    def _assert_parameters_copied(self, trained_model, fresh_model):
        """Assert both models hold equal values in distinct tensor objects."""
        trained_params = list(trained_model.parameters())
        loaded_params = list(fresh_model.parameters())

        # zip() stops at the shorter input, so a model with missing (or
        # extra) parameters would otherwise go unnoticed
        assert len(trained_params) == len(loaded_params)

        for trained_p, loaded_p in zip(trained_params, loaded_params):
            # different tensor references
            assert trained_p is not loaded_p
            # same content
            assert trained_p.equal(loaded_p)

    def test_from_last_checkpoint_model(self):
        """A checkpointer on the same folder finds and loads the last save."""
        trained_model = self.create_model()
        fresh_model = self.create_model()
        with TemporaryDirectory() as f:
            checkpointer = Checkpointer(trained_model, path=f)
            checkpointer.save("checkpoint_file")

            # in the same folder
            fresh_checkpointer = Checkpointer(fresh_model, path=f)
            assert fresh_checkpointer.has_checkpoint()
            assert fresh_checkpointer.last_checkpoint() == os.path.realpath(
                os.path.join(f, "checkpoint_file.pth")
            )
            fresh_checkpointer.load()

        self._assert_parameters_copied(trained_model, fresh_model)

    def test_from_name_file_model(self):
        """A checkpointer on another folder loads from an explicit file."""
        trained_model = self.create_model()
        fresh_model = self.create_model()
        with TemporaryDirectory() as f:
            checkpointer = Checkpointer(trained_model, path=f)
            checkpointer.save("checkpoint_file")

            # on different folders
            with TemporaryDirectory() as g:
                fresh_checkpointer = Checkpointer(fresh_model, path=g)
                assert not fresh_checkpointer.has_checkpoint()
                assert fresh_checkpointer.last_checkpoint() is None
                fresh_checkpointer.load(os.path.join(f, "checkpoint_file.pth"))

        self._assert_parameters_copied(trained_model, fresh_model)

    def test_checkpointer_process(self):
        """Check the lowest validation loss reported by the trainer helper.

        With a validation loss of 1.5, the returned value is expected to
        stay at 0.001 when that is the current lowest loss, and to become
        1.5 when the current lowest loss is 1000.
        """
        from ...binseg.configs.models.lwnet import model
        from ...binseg.engine.trainer import checkpointer_process

        with TemporaryDirectory() as f:
            checkpointer = Checkpointer(model, path=f)
            checkpoint_period = 0
            valid_loss = 1.5
            # the same dict is deliberately shared by both calls
            arguments = {"epoch": 0, "max_epoch": 10}
            epoch = 1
            max_epoch = 10

            # (lowest loss before the call, lowest loss expected after it)
            cases = ((0.001, 0.001), (1000, 1.5))
            for lowest_validation_loss, expected in cases:
                lowest_validation_loss = checkpointer_process(
                    checkpointer,
                    checkpoint_period,
                    valid_loss,
                    lowest_validation_loss,
                    arguments,
                    epoch,
                    max_epoch,
                )

                assert lowest_validation_loss == expected