Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4import os 

5import unittest 

6 

7from collections import OrderedDict 

8from tempfile import TemporaryDirectory 

9 

10import torch 

11 

12from ..utils.checkpointer import Checkpointer 

13 

14 

class TestCheckpointer(unittest.TestCase):
    """Tests for Checkpointer save/load round-trips.

    Each test saves a trained model to a temporary directory, loads it into a
    freshly initialized model of the same architecture, and verifies that the
    loaded parameters are equal in value but are distinct tensor objects.
    """

    def create_model(self):
        """Return a small two-layer MLP used as a checkpointing fixture."""
        return torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 1))

    def create_complex_model(self):
        """Return a nested module together with a hand-built state dict.

        The state dict's keys intentionally only partially match the module's
        parameter names (they differ by the ``block1.`` prefix for ``layer1``),
        which lets callers exercise prefix-tolerant state-dict loading.

        Returns:
            tuple: ``(module, state_dict)`` where ``state_dict`` is an
            ``OrderedDict`` of randomly initialized tensors.
        """
        m = torch.nn.Module()
        m.block1 = torch.nn.Module()
        m.block1.layer1 = torch.nn.Linear(2, 3)
        m.layer2 = torch.nn.Linear(3, 2)
        m.res = torch.nn.Module()
        m.res.layer2 = torch.nn.Linear(3, 2)

        state_dict = OrderedDict()
        state_dict["layer1.weight"] = torch.rand(3, 2)
        state_dict["layer1.bias"] = torch.rand(3)
        state_dict["layer2.weight"] = torch.rand(2, 3)
        state_dict["layer2.bias"] = torch.rand(2)
        state_dict["res.layer2.weight"] = torch.rand(2, 3)
        state_dict["res.layer2.bias"] = torch.rand(2)

        return m, state_dict

    def test_from_last_checkpoint_model(self):
        # Saving into a folder and pointing a fresh Checkpointer at the same
        # folder should discover and load the last saved checkpoint.
        trained_model = self.create_model()
        fresh_model = self.create_model()
        with TemporaryDirectory() as f:
            checkpointer = Checkpointer(trained_model, path=f)
            checkpointer.save("checkpoint_file")

            # In the same folder: the checkpoint must be auto-discovered.
            fresh_checkpointer = Checkpointer(fresh_model, path=f)
            self.assertTrue(fresh_checkpointer.has_checkpoint())
            self.assertEqual(
                fresh_checkpointer.last_checkpoint(),
                os.path.realpath(os.path.join(f, "checkpoint_file.pth")),
            )
            _ = fresh_checkpointer.load()

            for trained_p, loaded_p in zip(
                trained_model.parameters(), fresh_model.parameters()
            ):
                # Loading must copy values into new tensors, not alias them.
                self.assertIsNot(trained_p, loaded_p)
                # Same content.
                self.assertTrue(trained_p.equal(loaded_p))

    def test_from_name_file_model(self):
        # Loading by explicit file path must work even when the Checkpointer's
        # own folder contains no checkpoint.
        trained_model = self.create_model()
        fresh_model = self.create_model()
        with TemporaryDirectory() as f:
            checkpointer = Checkpointer(trained_model, path=f)
            checkpointer.save("checkpoint_file")

            # On a different folder: no checkpoint is discoverable there,
            # so we load via the explicit path instead.
            with TemporaryDirectory() as g:
                fresh_checkpointer = Checkpointer(fresh_model, path=g)
                self.assertFalse(fresh_checkpointer.has_checkpoint())
                self.assertIsNone(fresh_checkpointer.last_checkpoint())
                _ = fresh_checkpointer.load(
                    os.path.join(f, "checkpoint_file.pth")
                )

            for trained_p, loaded_p in zip(
                trained_model.parameters(), fresh_model.parameters()
            ):
                # Loading must copy values into new tensors, not alias them.
                self.assertIsNot(trained_p, loaded_p)
                # Same content.
                self.assertTrue(trained_p.equal(loaded_p))